@@ -197,6 +197,7 @@ struct sycl_device_info {
197197 int cc; // compute capability
198198 // int nsm; // number of streaming multiprocessors
199199 // size_t smpb; // max. shared memory per block
200+ size_t smpbo; // max. shared memory per block (with opt-in)
200201 bool vmm; // virtual memory support
201202 size_t total_vram;
202203 // sycl_hw_info hw_info; // device id and arch, currently not used
@@ -416,13 +417,6 @@ static __dpct_inline__ float warp_reduce_sum(float x,
416417 const sycl::nd_item<3 >& item_ct1) {
417418#pragma unroll
418419 for (int mask = WARP_SIZE / 2 ; mask > 0 ; mask >>= 1 ) {
419- /*
420- DPCT1096:98: The right-most dimension of the work-group used in the SYCL
421- kernel that calls this function may be less than "32". The function
422- "dpct::permute_sub_group_by_xor" may return an unexpected result on the
423- CPU device. Modify the size of the work-group to ensure that the value
424- of the right-most dimension is a multiple of "32".
425- */
426420 x += dpct::permute_sub_group_by_xor (item_ct1.get_sub_group (), x, mask);
427421 }
428422 return x;
@@ -440,17 +434,67 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
440434 return a;
441435}
442436
437+ template <int width = WARP_SIZE>
438+ static __dpct_inline__ int warp_reduce_sum (int x) {
439+ return sycl::reduce_over_group (
440+ sycl::ext::oneapi::this_work_item::get_sub_group (), x, sycl::plus<>());
441+ }
442+
443+ template <int width = WARP_SIZE>
444+ static __dpct_inline__ float warp_reduce_sum (float x) {
445+ #pragma unroll
446+ for (int offset = width / 2 ; offset > 0 ; offset >>= 1 ) {
447+ x += dpct::permute_sub_group_by_xor (
448+ sycl::ext::oneapi::this_work_item::get_sub_group (), x, offset, width);
449+ }
450+ return x;
451+ }
452+
453+ template <int width = WARP_SIZE>
454+ static __dpct_inline__ sycl::float2 warp_reduce_sum (sycl::float2 a) {
455+ #pragma unroll
456+ for (int offset = width / 2 ; offset > 0 ; offset >>= 1 ) {
457+ a.x () += dpct::permute_sub_group_by_xor (
458+ sycl::ext::oneapi::this_work_item::get_sub_group (), a.x (), offset,
459+ width);
460+ a.y () += dpct::permute_sub_group_by_xor (
461+ sycl::ext::oneapi::this_work_item::get_sub_group (), a.y (), offset,
462+ width);
463+ }
464+ return a;
465+ }
466+
467+ template <int width = WARP_SIZE>
468+ static __dpct_inline__ sycl::half2 warp_reduce_sum (sycl::half2 a) {
469+ #pragma unroll
470+ for (int offset = width / 2 ; offset > 0 ; offset >>= 1 ) {
471+ a = a + dpct::permute_sub_group_by_xor (
472+ sycl::ext::oneapi::this_work_item::get_sub_group (), a, offset,
473+ width);
474+ }
475+ return a;
476+ }
477+
478+ static constexpr int ggml_sycl_get_physical_warp_size () {
479+ // todo: for old iGPU + dGPU case, need to be changed.
480+ return WARP_SIZE;
481+ }
482+
483+ template <int width = WARP_SIZE>
484+ static __dpct_inline__ float warp_reduce_max (float x) {
485+ #pragma unroll
486+ for (int offset = width / 2 ; offset > 0 ; offset >>= 1 ) {
487+ x = sycl::fmax (x, dpct::permute_sub_group_by_xor (
488+ sycl::ext::oneapi::this_work_item::get_sub_group (), x,
489+ offset, width));
490+ }
491+ return x;
492+ }
493+
443494static __dpct_inline__ float warp_reduce_max (float x,
444495 const sycl::nd_item<3 >& item_ct1) {
445496#pragma unroll
446497 for (int mask = WARP_SIZE / 2 ; mask > 0 ; mask >>= 1 ) {
447- /*
448- DPCT1096:97: The right-most dimension of the work-group used in the SYCL
449- kernel that calls this function may be less than "32". The function
450- "dpct::permute_sub_group_by_xor" may return an unexpected result on the
451- CPU device. Modify the size of the work-group to ensure that the value
452- of the right-most dimension is a multiple of "32".
453- */
454498 x = sycl::fmax (x, dpct::permute_sub_group_by_xor (
455499 item_ct1.get_sub_group (), x, mask));
456500 }
@@ -558,4 +602,18 @@ struct scope_op_debug_print {
558602 std::string_view func_suffix;
559603};
560604
605+ static __dpct_inline__ float get_alibi_slope (const float max_bias,
606+ const uint32_t h,
607+ const uint32_t n_head_log2,
608+ const float m0,
609+ const float m1) {
610+ if (max_bias <= 0 .0f ) {
611+ return 1 .0f ;
612+ }
613+ const float base = h < n_head_log2 ? m0 : m1;
614+ const int exph = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
615+
616+ return dpct::pow (base, exph);
617+ }
618+
561619#endif // GGML_SYCL_COMMON_HPP
0 commit comments