@@ -1,7 +1,15 @@
-#include "norm.hpp"
+#include "softmax.hpp"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+template <typename T> static inline float t2f32(T val) {
+    return static_cast<float>(val);
+}
+
+template <> inline float t2f32<sycl::half>(sycl::half val) {
+    return static_cast<float>(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
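The `t2f32` helper added above is what makes the templated mask work: whatever the mask's element type, it is promoted to `float` at the point of use, with a `sycl::half` specialization keeping the half-precision path explicit. A minimal sketch of the same pattern in plain C++ (the `Half` struct is a hypothetical stand-in for `sycl::half`, which needs a SYCL toolchain):

```cpp
#include <cstdio>

// Hypothetical stand-in for sycl::half: convertible to float.
struct Half {
    float v;
    operator float() const { return v; }
};

// Generic promotion to float, mirroring the t2f32 template in the patch.
template <typename T> static inline float t2f32(T val) {
    return static_cast<float>(val);
}

int main() {
    float f = 0.5f;
    Half  h { 0.25f };
    printf("%f %f\n", t2f32(f), t2f32(h)); // both promote to float
    return 0;
}
```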
@@ -29,9 +37,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         slope = sycl::pow(base, float(exp));
     }
 
-    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
     float max_val = -INFINITY;
 
+#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
         const int col = col0 + tid;
 
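For context, the `slope` computed above is the ALiBi (Attention with Linear Biases) per-head factor that gets applied to the mask in the next hunk. Assuming the usual ggml convention for `m0`, `m1`, and `n_head_log2` (their setup sits just outside this hunk, so this is an assumption), the slope for head h is:

```latex
% Assumed ggml ALiBi convention, with n_head_log2 the largest power of two <= n_head,
% m_0 = 2^{-max\_bias / n\_head\_log2} and m_1 = 2^{-(max\_bias/2) / n\_head\_log2}:
\mathrm{slope}(h) =
  \begin{cases}
    m_0^{\,h+1}                       & h < n\_head\_log2 \\
    m_1^{\,2(h - n\_head\_log2) + 1}  & \text{otherwise}
  \end{cases}
```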
@@ -42,7 +51,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
 
         vals[col] = val;
         max_val = sycl::max(max_val, val);
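The hunk above is the only place the mask's element type matters inside the kernel; everything else is the standard max-subtracted softmax over scaled, mask-biased logits. A scalar reference of the per-row computation (a hypothetical helper for illustration, not part of the patch):

```cpp
#include <algorithm>
#include <cmath>

// Scalar reference for one row: dst[i] = exp(x[i]*scale + slope*mask[i] - max) / sum.
// The kernel computes exactly this, parallelized across one work-group per row.
static void soft_max_row_ref(const float * x, const float * mask, float * dst,
                             int ncols, float scale, float slope) {
    float max_val = -INFINITY;
    for (int i = 0; i < ncols; ++i) {   // pass 1: biased logits and row maximum
        dst[i]  = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
        max_val = std::max(max_val, dst[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {   // pass 2: exponentiate and accumulate
        dst[i] = std::exp(dst[i] - max_val);
        sum   += dst[i];
    }
    for (int i = 0; i < ncols; ++i) {   // pass 3: normalize
        dst[i] /= sum;
    }
}
```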
@@ -65,7 +74,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
             item_ct1.barrier(sycl::access::fence_space::local_space);
             max_val = buf[lane_id];
             for (size_t i = 1; i < nreduce; i += 1) {
-                max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+                max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
             }
             max_val = warp_reduce_max(max_val, item_ct1);
         }
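`warp_reduce_max` is ggml's sub-group-level helper; the hunk above only touches the second reduction stage, where the per-warp partial maxima parked in the shared `buf` are folded together (switching to `sycl::max`, presumably for device-code portability). For reference, the first, sub-group stage can be expressed in plain SYCL 2020 with a group reduction; a sketch of the equivalent idea, not ggml's implementation:

```cpp
#include <sycl/sycl.hpp>

// Sub-group (warp) maximum in plain SYCL 2020: every work-item in the
// sub-group receives the maximum over all lanes' values.
static float warp_max(const sycl::nd_item<3> & it, float val) {
    return sycl::reduce_over_group(it.get_sub_group(), val, sycl::maximum<float>());
}
```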
@@ -122,8 +131,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
122131 }
123132}
124133
125- template <bool vals_smem, int ncols_template, int block_size_template>
126- static void soft_max_f32_submitter (const float * x, const float * mask, float * dst, const int ncols_par,
134+ template <bool vals_smem, int ncols_template, int block_size_template, typename T >
135+ static void soft_max_f32_submitter (const float * x, const T * mask, float * dst, const int ncols_par,
127136 const int nrows_y, const float scale, const float max_bias, const float m0,
128137 const float m1, uint32_t n_head_log2, sycl::range<3 > block_nums, sycl::range<3 > block_dims,
129138 const size_t n_local_scratch, queue_ptr stream) {
@@ -133,15 +142,16 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
133142 cgh.parallel_for (
134143 sycl::nd_range<3 >(block_nums * block_dims, block_dims),
135144 [=](sycl::nd_item<3 > item_ct1) [[intel::reqd_sub_group_size (WARP_SIZE)]] {
136- soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
145+ soft_max_f32<vals_smem, ncols_template, block_size_template, T >(x, mask, dst, ncols_par,
137146 nrows_y, scale, max_bias, m0,
138147 m1, n_head_log2, item_ct1,
139148 get_pointer (local_buf_acc));
140149 });
141150 });
142151}
143152
144- static void soft_max_f32_sycl (const float * x, const float * mask,
153+ template <typename T>
154+ static void soft_max_f32_sycl (const float * x, const T * mask,
145155 float * dst, const int ncols_x, const int nrows_x,
146156 const int nrows_y, const float scale, const float max_bias,
147157 queue_ptr stream, int device) {
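The submitter follows the standard SYCL 2020 launch pattern: local scratch is created with a `local_accessor` inside the command group and handed into the kernel alongside the data pointers. A standalone sketch of that pattern (names here are hypothetical, not from the patch):

```cpp
#include <sycl/sycl.hpp>

// Standalone sketch: allocate local (work-group shared) scratch via a
// local_accessor, then launch an nd_range kernel that uses it.
int main() {
    sycl::queue q;
    float * data = sycl::malloc_device<float>(1024, q);
    q.fill(data, 1.0f, 1024).wait();

    q.submit([&](sycl::handler & cgh) {
        sycl::local_accessor<float, 1> local_buf(sycl::range<1>(32), cgh);
        cgh.parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 32)),
            [=](sycl::nd_item<3> it) {
                // each work-item stages a value into the work-group's scratch
                local_buf[it.get_local_id(2)] = data[it.get_global_id(2)];
            });
    }).wait();

    sycl::free(data, q);
    return 0;
}
```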
@@ -164,88 +174,99 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
     const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
     if (n_local_scratch*sizeof(float) < local_mem_size) {
         if (ncols_x > max_block_size) {
-            soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+            soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                max_bias, m0, m1, n_head_log2, block_nums,
                                                block_dims, n_local_scratch, stream);
             return;
         }
         switch (ncols_x) {
             case 32:
-                soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 32, 32, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                      max_bias, m0, m1, n_head_log2, block_nums,
                                                      block_dims, n_local_scratch, stream);
                 break;
             case 64:
-                soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 64, 64, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                      max_bias, m0, m1, n_head_log2, block_nums,
                                                      block_dims, n_local_scratch, stream);
                 break;
             case 128:
-                soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 128, 128, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 256:
-                soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 256, 256, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 512:
-                soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 512, 512, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 1024:
-                soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 1024, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             case 2048:
-                soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 2048, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             case 4096:
-                soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 4096, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             default:
-                soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                    max_bias, m0, m1, n_head_log2, block_nums,
                                                    block_dims, n_local_scratch, stream);
                 break;
         }
     } else {
-        soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+        soft_max_f32_submitter<false, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                             max_bias, m0, m1, n_head_log2, block_nums,
                                             block_dims, WARP_SIZE, stream);
     }
 }
 
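Note how the launcher specializes the kernel for the common power-of-two column counts and caps the block size at 1024 for the 2048 and 4096 cases, falling back to the generic `<0, 0>` instantiation (runtime column count) otherwise. The shape of that dispatch, condensed with a hypothetical `run<NCOLS, BLOCK>()` standing in for the templated submission:

```cpp
#include <cstdio>

// NCOLS == 0 means "generic": the kernel reads the column count at runtime.
template <int NCOLS, int BLOCK> static void run() {
    printf("ncols=%d block=%d\n", NCOLS, BLOCK);
}

static void dispatch(int ncols_x) {
    switch (ncols_x) {
        case   32: run<  32,   32>(); break;
        case   64: run<  64,   64>(); break;
        case  128: run< 128,  128>(); break;
        case  256: run< 256,  256>(); break;
        case  512: run< 512,  512>(); break;
        case 1024: run<1024, 1024>(); break;
        case 2048: run<2048, 1024>(); break; // block size capped at 1024
        case 4096: run<4096, 1024>(); break;
        default:   run<   0,    0>(); break; // generic fallback
    }
}
```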
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream) {
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00    = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
+    const int64_t ne00    = dst->src[0]->ne[0];
+    const int64_t nrows_x = ggml_nrows(dst->src[0]);
+    const int64_t nrows_y = dst->src[0]->ne[1];
 
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
     memcpy(&scale,    dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
-                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
+
+    ggml_sycl_set_device(ctx.device);
+    dpct::queue_ptr main_stream = ctx.stream();
+
+    if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
+        // printf("%s: fp16 mask\n", __func__);
+        const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
+        soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
+                                      main_stream, ctx.device);
+    } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
+        // printf("%s: fp32 mask\n", __func__);
+        const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
+        soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    } else {
+        /* mask unavailable */
+        soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    }
 }