@@ -244,26 +244,24 @@ static void launch_soft_max_kernels(const float * x,
244244 sycl::range<1 >(nbytes_shared), cgh);
245245
246246 cgh.parallel_for (
247- sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=
248- ](sycl::nd_item<3 > item_ct1) [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
249- soft_max_f32<true , 0 , 0 >(
250- x, mask, sinks, dst, p,
251- dpct_local_acc_ct1
252- .get_multi_ptr <sycl::access::decorated::no>()
253- .get ());
254- GGML_UNUSED (item_ct1);
255- });
247+ sycl::nd_range<3 >(block_nums * block_dims, block_dims),
248+ [=](sycl::nd_item<3 > item_ct1)
249+ [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
250+ soft_max_f32<true , 0 , 0 >(
251+ x, mask, sinks, dst, p,
252+ dpct_local_acc_ct1
253+ .get_multi_ptr <sycl::access::decorated::no>()
254+ .get ());
255+ GGML_UNUSED (item_ct1);
256+ });
256257 });
257258}
258259
259260template <typename T>
260- static void soft_max_f32_sycl (const float * x,
261- const T * mask,
262- const float * sinks,
263- float * dst,
264- const soft_max_params & params,
265- dpct::queue_ptr stream,
266- int device) {
261+ static void soft_max_f32_sycl (const float *x, const T *mask,
262+ const float *sinks, float *dst,
263+ const soft_max_params ¶ms,
264+ dpct::queue_ptr stream, int device) {
267265 int nth = WARP_SIZE;
268266 int max_block_size = ggml_sycl_info ().max_work_group_sizes [device];
269267 const int64_t ncols_x = params.ncols ;
@@ -273,7 +271,8 @@ static void soft_max_f32_sycl(const float * x,
273271
274272 const dpct::dim3 block_dims (nth, 1 , 1 );
275273 const dpct::dim3 block_nums (params.ne01 , params.ne02 , params.ne03 );
276- const size_t nbytes_shared = (GGML_PAD (ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof (float );
274+ const size_t nbytes_shared =
275+ (GGML_PAD (ncols_x, WARP_SIZE) + WARP_SIZE) * sizeof (float );
277276
278277 const int id = get_current_device_id ();
279278 const size_t smpbo = ggml_sycl_info ().devices [id].smpbo ;
0 commit comments