@@ -8126,23 +8126,51 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
81268126 dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
81278127}
81288128
8129- static void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale,
8130- const sycl::nd_item<3> &item_ct1, float *buf) {
8129+
// Row-wise softmax kernel body (the added lines are the new templated version).
//
// The row is selected by item_ct1.get_group(2); each work-item visits columns
// tid, tid + block_size, ... in three passes:
//   1. val = x*scale + mask + slope*pos is stored into `vals` while the running
//      maximum is tracked;
//   2. vals[col] becomes native exp(vals[col] - max) and the values are summed;
//   3. dst[rowx*ncols + col] = vals[col] * (1 / sum).
// After passes 1 and 2 the per-work-item partial result is combined through
// buf[] and warp_reduce_max / warp_reduce_sum.
// NOTE(review): part of each reduction (the statements just before the
// `if (warp_id == 0)` lines) lies outside the visible hunks.
//
// Template parameters:
//   vals_smem           - true: intermediate values are kept at buf + WARP_SIZE;
//                         false: they are kept in the dst row itself.
//   ncols_template      - when non-zero, used as the column count instead of ncols_par.
//   block_size_template - when non-zero, used as the column stride instead of
//                         item_ct1.get_local_range(2).
// Parameters:
//   x           - input, read as x[rowx*ncols + col].
//   mask        - optional additive term (may be null); its row is rowx % nrows_y.
//   pos         - optional per-column term (may be null), multiplied by the ALiBi slope.
//   dst         - output, written as dst[rowx*ncols + col].
//   scale       - multiplier applied to x before mask and pos are added.
//   max_bias, m0, m1, n_head_log2 - ALiBi inputs; slope stays 0 unless max_bias > 0.
//   buf         - scratch; buf[lane_id] / buf[warp_id] are the reduction slots.
8130+ template <bool vals_smem, int ncols_template, int block_size_template>
8131+ static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
8132+ const int nrows_y, const float scale, const float max_bias, const float m0,
8133+ const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
8134+ const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
8135+
81318136 const int tid = item_ct1.get_local_id(2);
81328137 const int rowx = item_ct1.get_group(2);
81338138 const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
81348139
8135- const int block_size = item_ct1.get_local_range(2);
8140+ const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template ;
81368141
81378142 const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
81388143 const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
81398144
8145+ float slope = 0.0f;
8146+
8147+ // ALiBi
8148+ if (max_bias > 0.0f) {
8149+ const uint32_t h = rowx/nrows_y; // head index
8150+
// Heads below n_head_log2 use base m0 with exponent h + 1; the remaining
// heads use base m1 with the odd exponents 1, 3, 5, ...
8151+ const float base = h < n_head_log2 ? m0 : m1;
8152+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
8153+
8154+ slope = sycl::pow(base, float(exp));
8155+ }
8156+
// Intermediate values live either in buf past its first WARP_SIZE entries
// (which stay reserved for the reductions) or directly in the output row.
8157+ float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
81408158 float max_val = -INFINITY;
81418159
8142- for (int col = tid; col < ncols; col += block_size) {
8160+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
8161+ const int col = col0 + tid;
8162+
// The bounds check is compiled in only for a run-time column count; the
// fixed sizes chosen by the dispatcher are multiples of their block size.
8163+ if (ncols_template == 0 && col >= ncols) {
8164+ break;
8165+ }
8166+
81438167 const int ix = rowx*ncols + col;
81448168 const int iy = rowy*ncols + col;
8145- max_val = sycl::max(max_val, x[ix] * scale + (y ? y[iy] : 0.0f));
8169+
8170+ const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
8171+
8172+ vals[col] = val;
8173+ max_val = sycl::max(max_val, val);
81468174 }
81478175
81488176 // find the max value in the block
@@ -8151,44 +8179,29 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
81518179 if (warp_id == 0) {
81528180 buf[lane_id] = -INFINITY;
81538181 }
8154- /*
8155- DPCT1118:12: SYCL group functions and algorithms must be encountered in
8156- converged control flow. You may need to adjust the code.
8157- */
8158- /*
8159- DPCT1065:60: Consider replacing sycl::nd_item::barrier() with
8160- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
8161- better performance if there is no access to global memory.
8162- */
8163- item_ct1.barrier();
// The barrier now requests a local-space fence, as the removed DPCT1065
// note suggested; the same substitution is made at every barrier below.
8182+ item_ct1.barrier(sycl::access::fence_space::local_space);
81648183
81658184 if (lane_id == 0) {
81668185 buf[warp_id] = max_val;
81678186 }
8168- /*
8169- DPCT1118:13: SYCL group functions and algorithms must be encountered in
8170- converged control flow. You may need to adjust the code.
8171- */
8172- /*
8173- DPCT1065:61: Consider replacing sycl::nd_item::barrier() with
8174- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
8175- better performance if there is no access to global memory.
8176- */
8177- item_ct1.barrier();
8187+ item_ct1.barrier(sycl::access::fence_space::local_space);
81788188
81798189 max_val = buf[lane_id];
81808190 max_val = warp_reduce_max(max_val, item_ct1);
81818191 }
81828192
81838193 float tmp = 0.f;
81848194
8185- for (int col = tid; col < ncols; col += block_size) {
8186- const int ix = rowx*ncols + col;
8187- const int iy = rowy*ncols + col;
8188- const float val =
8189- sycl::native::exp((x[ix] * scale + (y ? y[iy] : 0.0f)) - max_val);
// Second pass: reuse the values stored in `vals` instead of recomputing
// x*scale + mask as the removed loop did.
8195+ #pragma unroll
8196+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
8197+ const int col = col0 + tid;
8198+ if (ncols_template == 0 && col >= ncols) {
8199+ break;
8200+ }
8201+
8202+ const float val = sycl::native::exp(vals[col] - max_val);
81908203 tmp += val;
8191- dst[ix ] = val;
8204+ vals[col ] = val;
81928205 }
81938206
81948207 // find the sum of exps in the block
@@ -8197,40 +8210,29 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
81978210 if (warp_id == 0) {
81988211 buf[lane_id] = 0.f;
81998212 }
8200- /*
8201- DPCT1118:14: SYCL group functions and algorithms must be encountered in
8202- converged control flow. You may need to adjust the code.
8203- */
8204- /*
8205- DPCT1065:62: Consider replacing sycl::nd_item::barrier() with
8206- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
8207- better performance if there is no access to global memory.
8208- */
8209- item_ct1.barrier();
8213+ item_ct1.barrier(sycl::access::fence_space::local_space);
82108214
82118215 if (lane_id == 0) {
82128216 buf[warp_id] = tmp;
82138217 }
8214- /*
8215- DPCT1118:15: SYCL group functions and algorithms must be encountered in
8216- converged control flow. You may need to adjust the code.
8217- */
8218- /*
8219- DPCT1065:63: Consider replacing sycl::nd_item::barrier() with
8220- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
8221- better performance if there is no access to global memory.
8222- */
8223- item_ct1.barrier();
8218+ item_ct1.barrier(sycl::access::fence_space::local_space);
82248219
82258220 tmp = buf[lane_id];
82268221 tmp = warp_reduce_sum(tmp, item_ct1);
82278222 }
82288223
8229- const float inv_tmp = 1.f / tmp;
8224+ const float inv_sum = 1.f / tmp;
82308225
8231- for (int col = tid; col < ncols; col += block_size) {
8232- const int i = rowx*ncols + col;
8233- dst[i] *= inv_tmp;
// Third pass: normalise and write the final result to dst.
8226+ #pragma unroll
8227+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
8228+ const int col = col0 + tid;
8229+
// Nothing follows this loop, so returning here has the same effect as
// the `break` used in the two earlier loops.
8230+ if (ncols_template == 0 && col >= ncols) {
8231+ return;
8232+ }
8233+
8234+ const int idst = rowx*ncols + col;
8235+ dst[idst] = vals[col] * inv_sum;
82348236 }
82358237}
82368238
@@ -10867,37 +10869,98 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
1086710869 });
1086810870}
1086910871
10870- static void soft_max_f32_sycl(const float *x, const float *y, float *dst,
10871- const int ncols_x, const int nrows_x,
10872- const int nrows_y, const float scale,
10873- dpct::queue_ptr stream) {
10874- int nth = WARP_SIZE;
10875- while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
10876- const sycl::range<3> block_dims(1, 1, nth);
10877- const sycl::range<3> block_nums(1, 1, nrows_x);
10878- /*
10879- DPCT1049:46: The work-group size passed to the SYCL kernel may exceed the
10880- limit. To get the device limit, query info::device::max_work_group_size.
10881- Adjust the work-group size if needed.
10882- */
// Enqueues one soft_max_f32<vals_smem, ncols_template, block_size_template>
// launch on `stream`.
//
// A local accessor of n_local_scratch floats is created for the command group
// and its pointer is passed to the kernel as `buf`. The launch covers
// nd_range(block_nums * block_dims, block_dims) and requires a sub-group size
// of 32. Every other argument is forwarded to the kernel unchanged.
10872+ template <bool vals_smem, int ncols_template, int block_size_template>
10873+ static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
10874+ const int nrows_y, const float scale, const float max_bias, const float m0,
10875+ const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
10876+ const size_t n_local_scratch, dpct::queue_ptr stream) {
1088310877 stream->submit([&](sycl::handler &cgh) {
10884- /*
10885- DPCT1101:96: 'SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE' expression was
10886- replaced with a value. Modify the code to use the original expression,
10887- provided in comments, if it is correct.
10888- */
10889- sycl::local_accessor<float, 1> buf_acc_ct1(
10890- sycl::range<1>(32 /*SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE*/), cgh);
// The scratch size is now chosen by the caller instead of the fixed 32 floats
// used by the removed accessor.
10878+ sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
1089110879
1089210880 cgh.parallel_for(
1089310881 sycl::nd_range<3>(block_nums * block_dims, block_dims),
1089410882 [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
10895- soft_max_f32(x, y, dst, ncols_x, nrows_y, scale, item_ct1,
10896- buf_acc_ct1.get_pointer());
10883+ soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
10884+ nrows_y, scale, max_bias, m0,
10885+ m1, n_head_log2, item_ct1,
10886+ local_buf_acc.get_pointer());
1089710887 });
1089810888 });
1089910889}
1090010890
// Host-side dispatcher that picks a soft_max_f32 instantiation and submits it.
//
// Launch shape: nrows_x work-groups (one per row of x) of nth work-items each,
// where nth starts at WARP_SIZE and doubles until it reaches ncols_x or
// SYCL_SOFT_MAX_BLOCK_SIZE.
// ALiBi inputs: n_head_log2 is the largest power of two not above
// nrows_x / nrows_y; m0 = 2^(-max_bias / n_head_log2) and
// m1 = 2^(-(max_bias / 2) / n_head_log2).
// Scratch: n_local_scratch is GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE floats
// (GGML_PAD presumably rounds up to a multiple of WARP_SIZE - confirm against
// the macro). The kernel keeps its reduction slots in the first WARP_SIZE
// floats and its per-column values after them.
//
// Dispatch: when the scratch is strictly smaller than the device's
// local_mem_size, the kernel runs with vals_smem = true. It uses compile-time
// column count and block size for ncols_x in {32, 64, ..., 4096}, and the
// run-time path <true, 0, 0> otherwise. When the scratch does not fit, the
// kernel runs as <false, 0, 0> with only WARP_SIZE floats of scratch.
10891+ static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
10892+ float * dst, const int ncols_x, const int nrows_x,
10893+ const int nrows_y, const float scale, const float max_bias,
10894+ dpct::queue_ptr stream) {
10895+ int nth = WARP_SIZE;
10896+ while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
10897+ const sycl::range<3> block_dims(1, 1, nth);
10898+ const sycl::range<3> block_nums(1, 1, nrows_x);
10899+ const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
// The 2048 and 4096 cases below hard-code a block size of 1024, which is the
// value this assertion pins.
10900+ static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
10901+
10902+ const uint32_t n_head_kv = nrows_x/nrows_y;
10903+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
10904+
10905+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
10906+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
10907+
10908+ const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
10909+ if (n_local_scratch*sizeof(float) < local_mem_size) {
// NOTE(review): each block_size template argument is expected to equal the nth
// computed above for that ncols_x (this holds if WARP_SIZE is 32 - confirm).
10910+ switch (ncols_x) {
10911+ case 32:
10912+ soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
10913+ max_bias, m0, m1, n_head_log2, block_nums,
10914+ block_dims, n_local_scratch, stream);
10915+ break;
10916+ case 64:
10917+ soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
10918+ max_bias, m0, m1, n_head_log2, block_nums,
10919+ block_dims, n_local_scratch, stream);
10920+ break;
10921+ case 128:
10922+ soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
10923+ max_bias, m0, m1, n_head_log2, block_nums,
10924+ block_dims, n_local_scratch, stream);
10925+ break;
10926+ case 256:
10927+ soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
10928+ max_bias, m0, m1, n_head_log2, block_nums,
10929+ block_dims, n_local_scratch, stream);
10930+ break;
10931+ case 512:
10932+ soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
10933+ max_bias, m0, m1, n_head_log2, block_nums,
10934+ block_dims, n_local_scratch, stream);
10935+ break;
10936+ case 1024:
10937+ soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
10938+ max_bias, m0, m1, n_head_log2, block_nums,
10939+ block_dims, n_local_scratch, stream);
10940+ break;
10941+ case 2048:
10942+ soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
10943+ max_bias, m0, m1, n_head_log2, block_nums,
10944+ block_dims, n_local_scratch, stream);
10945+ break;
10946+ case 4096:
10947+ soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
10948+ max_bias, m0, m1, n_head_log2, block_nums,
10949+ block_dims, n_local_scratch, stream);
10950+ break;
10951+ default:
10952+ soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
10953+ max_bias, m0, m1, n_head_log2, block_nums,
10954+ block_dims, n_local_scratch, stream);
10955+ break;
10956+ }
10957+ } else {
// The row does not fit in local memory: the kernel keeps its intermediate
// values in dst, so only the WARP_SIZE reduction slots are requested.
10958+ soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
10959+ max_bias, m0, m1, n_head_log2, block_nums,
10960+ block_dims, WARP_SIZE, stream);
10961+ }
10962+ }
10963+
1090110964template <typename T>
1090210965static void im2col_sycl(const float *x, T *dst, int IW, int IH,
1090310966 int OW, int OH, int KW, int KH, int IC,
@@ -12435,14 +12498,35 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
1243512498
1243612499 const int64_t ne00 = src0->ne[0];
1243712500 const int64_t nrows_x = ggml_nrows(src0);
12438- const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1 ;
12501+ const int64_t nrows_y = src0->ne[1] ;
1243912502
1244012503 float scale = 1.0f;
12441- memcpy(&scale, dst->op_params, sizeof(float)) ;
12504+ float max_bias = 0.0f ;
1244212505
12443- soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
12506+ memcpy(&scale, dst->op_params + 0, sizeof(float));
12507+ memcpy(&max_bias, dst->op_params + 1, sizeof(float));
1244412508
12445- (void) dst;
12509+ // positions tensor
12510+ float * src2_dd = nullptr;
12511+ sycl_pool_alloc<float> src2_f;
12512+
12513+ ggml_tensor * src2 = dst->src[2];
12514+ const bool use_src2 = src2 != nullptr;
12515+
12516+ if (use_src2) {
12517+ const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
12518+
12519+ if (src2_on_device) {
12520+ ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
12521+ src2_dd = (float *) src2_extra->data_device[g_main_device];
12522+ } else {
12523+ src2_dd = src2_f.alloc(ggml_nelements(src2));
12524+ SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
12525+ }
12526+ }
12527+
12528+ soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
12529+ nrows_x, nrows_y, scale, max_bias, main_stream);
1244612530}
1244712531
1244812532inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1,
0 commit comments