-#include "norm.hpp"
+#include "softmax.hpp"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         slope = sycl::pow(base, float(exp));
     }
 
-    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
     float max_val = -INFINITY;
 
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
 
         vals[col] = val;
         max_val = sycl::max(max_val, val);
@@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         item_ct1.barrier(sycl::access::fence_space::local_space);
         max_val = buf[lane_id];
         for (size_t i = 1; i < nreduce; i += 1) {
-            max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+            max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
         }
         max_val = warp_reduce_max(max_val, item_ct1);
     }
@@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
     }
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
@@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
     });
 }
 
-static void soft_max_f32_sycl(const float * x, const float * mask,
+template <typename T>
+static void soft_max_f32_sycl(const float * x, const T * mask,
                               float * dst, const int ncols_x, const int nrows_x,
                               const int nrows_y, const float scale, const float max_bias,
                               queue_ptr stream, int device) {
@@ -223,29 +224,38 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
     }
 }
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream) {
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
+    const int64_t ne00 = dst->src[0]->ne[0];
+    const int64_t nrows_x = ggml_nrows(dst->src[0]);
+    const int64_t nrows_y = dst->src[0]->ne[1];
 
     float scale = 1.0f;
     float max_bias = 0.0f;
 
     memcpy(&scale, dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
-                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
+
+    ggml_sycl_set_device(ctx.device);
+    dpct::queue_ptr main_stream = ctx.stream();
+
+    if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
+        const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
+        soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
+                                      main_stream, ctx.device);
+    } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
+        const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
+        soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    } else {
+        /* mask unavailable */
+        soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    }
 }
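
The structure of the change: the mask element type is threaded through soft_max_f32, soft_max_f32_submitter, and soft_max_f32_sycl as a template parameter, the runtime branch on dst->src[1]->type happens exactly once at the outermost call site, and inside the kernel the only type-sensitive operation is the static_cast<float>(mask[iy]) widening. Below is a minimal, self-contained CPU sketch of that dispatch pattern; the names f16 and masked_softmax are hypothetical stand-ins, not part of this commit.

// Standalone illustration of the templated-mask pattern above.
// 'f16' and 'masked_softmax' are hypothetical, not commit code.
#include <cmath>
#include <cstdio>

struct f16 {                                  // stand-in for sycl::half
    float v;
    operator float() const { return v; }      // enables static_cast<float>
};

// One kernel body serves both mask types: each mask element is widened
// to float at its single point of use, mirroring soft_max_f32 above.
template <typename T>
void masked_softmax(const float * x, const T * mask, float * dst, int n,
                    float scale, float slope) {
    float max_val = -INFINITY;
    for (int i = 0; i < n; ++i) {
        dst[i] = x[i]*scale + (mask ? slope*static_cast<float>(mask[i]) : 0.0f);
        max_val = std::fmax(max_val, dst[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) { dst[i] = std::exp(dst[i] - max_val); sum += dst[i]; }
    for (int i = 0; i < n; ++i) { dst[i] /= sum; }
}

int main() {
    const float x[4]   = {1.0f, 2.0f, 3.0f, 4.0f};
    const f16  mask[4] = {{0.0f}, {0.0f}, {-INFINITY}, {0.0f}}; // masked position -> 0
    float dst[4];
    // The runtime type branch happens once, as in ggml_sycl_op_soft_max():
    masked_softmax<f16>(x, mask, dst, 4, 1.0f, 1.0f);      // F16 mask path
    for (float v : dst) printf("%.4f ", v);
    printf("\n");
    masked_softmax<float>(x, nullptr, dst, 4, 1.0f, 1.0f); // no-mask path
    return 0;
}

Keeping the branch at the edge avoids any per-element type check in the hot loop, which is presumably why the commit templates the whole submitter chain rather than converting an F16 mask to F32 in a separate pass.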