@@ -892,134 +892,6 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
892892}
893893
894894
895- template <bool vals_smem, int ncols_template, int block_size_template>
896- static void soft_max_f32 (const float * x, const float * mask, float * dst, const int ncols_par,
897- const int nrows_y, const float scale, const float max_bias, const float m0,
898- const float m1, uint32_t n_head_log2, const sycl::nd_item<3 > &item_ct1, float *buf) {
899- const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
900-
901- const int tid = item_ct1.get_local_id (2 );
902- const int rowx = item_ct1.get_group (2 );
903- const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
904-
905- const int block_size = block_size_template == 0 ? item_ct1.get_local_range (2 ) : block_size_template;
906-
907- const int warp_id = item_ct1.get_local_id (2 ) / WARP_SIZE;
908- const int lane_id = item_ct1.get_local_id (2 ) % WARP_SIZE;
909- const int nthreads = block_size;
910- const int nwarps = nthreads / WARP_SIZE;
911- int nreduce = nwarps / WARP_SIZE;
912-
913-
914- float slope = 1 .0f ;
915-
916- // ALiBi
917- if (max_bias > 0 .0f ) {
918- const uint32_t h = rowx/nrows_y; // head index
919-
920- const float base = h < n_head_log2 ? m0 : m1;
921- const int exp = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
922-
923- slope = sycl::pow (base, float (exp));
924- }
925-
926- float *vals = vals_smem ? buf + std::max (nwarps, WARP_SIZE) : dst + rowx * ncols;
927- float max_val = -INFINITY;
928-
929- for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
930- const int col = col0 + tid;
931-
932- if (ncols_template == 0 && col >= ncols) {
933- break ;
934- }
935-
936- const int ix = rowx*ncols + col;
937- const int iy = rowy*ncols + col;
938-
939- const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0 .0f );
940-
941- vals[col] = val;
942- max_val = sycl::max (max_val, val);
943- }
944-
945- // find the max value in the block
946- max_val = warp_reduce_max (max_val, item_ct1);
947- if (block_size > WARP_SIZE) {
948- if (warp_id == 0 ) {
949- buf[lane_id] = -INFINITY;
950- for (size_t i = 1 ; i < nreduce; i += 1 )
951- buf[lane_id + i * WARP_SIZE] = -INFINITY;
952-
953- }
954- item_ct1.barrier (sycl::access::fence_space::local_space);
955-
956- if (lane_id == 0 ) {
957- buf[warp_id] = max_val;
958- }
959- item_ct1.barrier (sycl::access::fence_space::local_space);
960-
961- max_val = buf[lane_id];
962- for (size_t i = 1 ; i < nreduce; i += 1 )
963- {
964- max_val = std::max (max_val, buf[lane_id + i * WARP_SIZE]);
965- }
966-
967- max_val = warp_reduce_max (max_val, item_ct1);
968- }
969-
970- float tmp = 0 .f ;
971-
972- #pragma unroll
973- for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
974- const int col = col0 + tid;
975- if (ncols_template == 0 && col >= ncols) {
976- break ;
977- }
978-
979- const float val = sycl::native::exp (vals[col] - max_val);
980- tmp += val;
981- vals[col] = val;
982- }
983-
984- // find the sum of exps in the block
985- tmp = warp_reduce_sum (tmp, item_ct1);
986- if (block_size > WARP_SIZE) {
987- item_ct1.barrier (sycl::access::fence_space::local_space);
988- if (warp_id == 0 ) {
989- buf[lane_id] = 0 .f ;
990- for (size_t i = 1 ; i < nreduce; i += 1 )
991- buf[lane_id + i * WARP_SIZE] = 0 .f ;
992-
993- }
994- item_ct1.barrier (sycl::access::fence_space::local_space);
995-
996- if (lane_id == 0 ) {
997- buf[warp_id] = tmp;
998- }
999- item_ct1.barrier (sycl::access::fence_space::local_space);
1000-
1001- tmp = buf[lane_id];
1002- for (size_t i = 1 ; i < nreduce; i += 1 )
1003- {
1004- tmp += buf[lane_id + i * WARP_SIZE];
1005- }
1006- tmp = warp_reduce_sum (tmp, item_ct1);
1007- }
1008-
1009- const float inv_sum = 1 .f / tmp;
1010-
1011- #pragma unroll
1012- for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
1013- const int col = col0 + tid;
1014-
1015- if (ncols_template == 0 && col >= ncols) {
1016- return ;
1017- }
1018-
1019- const int idst = rowx*ncols + col;
1020- dst[idst] = vals[col] * inv_sum;
1021- }
1022- }
1023895
1024896static void scale_f32 (const float * x, float * dst, const float scale, const int k,
1025897 const sycl::nd_item<3 > &item_ct1) {
@@ -1908,105 +1780,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
19081780 });
19091781}
19101782
1911- template <bool vals_smem, int ncols_template, int block_size_template>
1912- static void soft_max_f32_submitter (const float * x, const float * mask, float * dst, const int ncols_par,
1913- const int nrows_y, const float scale, const float max_bias, const float m0,
1914- const float m1, uint32_t n_head_log2, sycl::range<3 > block_nums, sycl::range<3 > block_dims,
1915- const size_t n_local_scratch, queue_ptr stream) {
1916- stream->submit ([&](sycl::handler &cgh) {
1917- sycl::local_accessor<float , 1 > local_buf_acc (n_local_scratch, cgh);
19181783
1919- cgh.parallel_for (
1920- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
1921- [=](sycl::nd_item<3 > item_ct1) [[intel::reqd_sub_group_size (WARP_SIZE)]] {
1922- soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
1923- nrows_y, scale, max_bias, m0,
1924- m1, n_head_log2, item_ct1,
1925- local_buf_acc.get_pointer ());
1926- });
1927- });
1928- }
1929-
1930- static void soft_max_f32_sycl (const float * x, const float * mask,
1931- float * dst, const int ncols_x, const int nrows_x,
1932- const int nrows_y, const float scale, const float max_bias,
1933- queue_ptr stream, int device_id) {
1934- int nth = WARP_SIZE;
1935- int max_block_size = ggml_sycl_info ().work_group_size (device_id);
1936- while (nth < ncols_x && nth < max_block_size) nth *= 2 ;
1937- if (nth>max_block_size) nth = max_block_size;
1938-
1939- const sycl::range<3 > block_dims (1 , 1 , nth);
1940- const sycl::range<3 > block_nums (1 , 1 , nrows_x);
1941- const size_t n_local_scratch = (GGML_PAD (ncols_x, WARP_SIZE) + WARP_SIZE);
1942-
1943- const uint32_t n_head_kv = nrows_x/nrows_y;
1944- const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head_kv));
1945-
1946- const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
1947- const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
1948-
1949- const size_t local_mem_size = stream->get_device ().get_info <sycl::info::device::local_mem_size>();
1950- if (n_local_scratch*sizeof (float ) < local_mem_size) {
1951- if (ncols_x > max_block_size) {
1952- soft_max_f32_submitter<true , 0 , 0 >(x, mask, dst, ncols_x, nrows_y, scale,
1953- max_bias, m0, m1, n_head_log2, block_nums,
1954- block_dims, n_local_scratch, stream);
1955- return ;
1956- }
1957- switch (ncols_x) {
1958- case 32 :
1959- soft_max_f32_submitter<true , 32 , 32 >(x, mask, dst, ncols_x, nrows_y, scale,
1960- max_bias, m0, m1, n_head_log2, block_nums,
1961- block_dims, n_local_scratch, stream);
1962- break ;
1963- case 64 :
1964- soft_max_f32_submitter<true , 64 , 64 >(x, mask, dst, ncols_x, nrows_y, scale,
1965- max_bias, m0, m1, n_head_log2, block_nums,
1966- block_dims, n_local_scratch, stream);
1967- break ;
1968- case 128 :
1969- soft_max_f32_submitter<true , 128 , 128 >(x, mask, dst, ncols_x, nrows_y, scale,
1970- max_bias, m0, m1, n_head_log2, block_nums,
1971- block_dims, n_local_scratch, stream);
1972- break ;
1973- case 256 :
1974- soft_max_f32_submitter<true , 256 , 256 >(x, mask, dst, ncols_x, nrows_y, scale,
1975- max_bias, m0, m1, n_head_log2, block_nums,
1976- block_dims, n_local_scratch, stream);
1977- break ;
1978- case 512 :
1979- soft_max_f32_submitter<true , 512 , 512 >(x, mask, dst, ncols_x, nrows_y, scale,
1980- max_bias, m0, m1, n_head_log2, block_nums,
1981- block_dims, n_local_scratch, stream);
1982- break ;
1983- case 1024 :
1984- soft_max_f32_submitter<true , 1024 , 1024 >(x, mask, dst, ncols_x, nrows_y, scale,
1985- max_bias, m0, m1, n_head_log2, block_nums,
1986- block_dims, n_local_scratch, stream);
1987- break ;
1988- case 2048 :
1989- soft_max_f32_submitter<true , 2048 , 1024 >(x, mask, dst, ncols_x, nrows_y, scale,
1990- max_bias, m0, m1, n_head_log2, block_nums,
1991- block_dims, n_local_scratch, stream);
1992- break ;
1993- case 4096 :
1994- soft_max_f32_submitter<true , 4096 , 1024 >(x, mask, dst, ncols_x, nrows_y, scale,
1995- max_bias, m0, m1, n_head_log2, block_nums,
1996- block_dims, n_local_scratch, stream);
1997- break ;
1998- default :
1999- soft_max_f32_submitter<true , 0 , 0 >(x, mask, dst, ncols_x, nrows_y, scale,
2000- max_bias, m0, m1, n_head_log2, block_nums,
2001- block_dims, n_local_scratch, stream);
2002- break ;
2003- }
2004- } else {
2005- soft_max_f32_submitter<false , 0 , 0 >(x, mask, dst, ncols_x, nrows_y, scale,
2006- max_bias, m0, m1, n_head_log2, block_nums,
2007- block_dims, WARP_SIZE, stream);
2008- }
2009- }
20101784
20111785template <typename T>
20121786static void im2col_sycl (const float *x, T *dst, int IW, int IH,
@@ -2865,32 +2639,7 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const gg
28652639 (void ) src1_dd;
28662640}
28672641
2868- inline void ggml_sycl_op_soft_max (ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2869- const ggml_tensor *src1, ggml_tensor *dst,
2870- const float *src0_dd, const float *src1_dd,
2871- float *dst_dd,
2872- const queue_ptr &main_stream) {
2873-
2874- GGML_ASSERT (src0->type == GGML_TYPE_F32);
2875- GGML_ASSERT ( dst->type == GGML_TYPE_F32);
2876-
2877- #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
2878- #pragma message("ref: https:// github.com/ggerganov/llama.cpp/pull/5021")
2879- GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
28802642
2881- const int64_t ne00 = src0->ne [0 ];
2882- const int64_t nrows_x = ggml_nrows (src0);
2883- const int64_t nrows_y = src0->ne [1 ];
2884-
2885- float scale = 1 .0f ;
2886- float max_bias = 0 .0f ;
2887-
2888- memcpy (&scale, dst->op_params + 0 , sizeof (float ));
2889- memcpy (&max_bias, dst->op_params + 1 , sizeof (float ));
2890-
2891- soft_max_f32_sycl (src0_dd, src1 ? src1_dd : nullptr , dst_dd, ne00,
2892- nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device );
2893- }
28942643
28952644inline void ggml_sycl_op_scale (ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
28962645 ggml_tensor *dst, const float *src0_dd,
0 commit comments