-#include "norm.hpp"
+#include "softmax.hpp"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         slope = sycl::pow(base, float(exp));
     }
 
-    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
     float max_val = -INFINITY;
 
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
 
         vals[col] = val;
         max_val = sycl::max(max_val, val);
@@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         item_ct1.barrier(sycl::access::fence_space::local_space);
         max_val = buf[lane_id];
         for (size_t i = 1; i < nreduce; i += 1) {
-            max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+            max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
         }
         max_val = warp_reduce_max(max_val, item_ct1);
     }
@@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
     }
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
@@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
     });
 }
 
-static void soft_max_f32_sycl(const float * x, const float * mask,
+template <typename T>
+static void soft_max_f32_sycl(const float * x, const T * mask,
                               float * dst, const int ncols_x, const int nrows_x,
                               const int nrows_y, const float scale, const float max_bias,
                               queue_ptr stream, int device) {
@@ -223,29 +224,38 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
     }
 }
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream) {
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00    = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
+    const int64_t ne00    = dst->src[0]->ne[0];
+    const int64_t nrows_x = ggml_nrows(dst->src[0]);
+    const int64_t nrows_y = dst->src[0]->ne[1];
 
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
     memcpy(&scale,    dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
-                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float *       dst_dd  = static_cast<float *>(dst->data);
+
+    ggml_sycl_set_device(ctx.device);
+    dpct::queue_ptr main_stream = ctx.stream();
+
+    if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
+        const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
+        soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
+                                      main_stream, ctx.device);
+    } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
+        const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
+        soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    } else {
+        /* mask unavailable */
+        soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    }
 }