Skip to content

Commit 22e3df0

Browse files
qnixsynapseggerganov
authored andcommitted
SYCL : SOFTMAX F16 mask support and other fixes (llama/11261)
Implemented ggml_sycl_op_soft_max() F16 src1(mask) support for which a pragma deprecation warning was added during #5021. To do this, had to decouple it from ggml_sycl_op_flatten which always considered src1 to be of fp32 type(many OP functions are dependent on it). * SYCL: SOFTMAX F16 mask support and other fixes * test-backend-ops: Add F16 mask test cases
1 parent 028511d commit 22e3df0

File tree

3 files changed

+35
-33
lines changed

3 files changed

+35
-33
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
38783878
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
38793879
}
38803880

3881-
static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3882-
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
3883-
}
3884-
38853881
static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
38863882
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
38873883
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
@@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
40904086
ggml_sycl_diag_mask_inf(ctx, dst);
40914087
break;
40924088
case GGML_OP_SOFT_MAX:
4093-
ggml_sycl_soft_max(ctx, dst);
4089+
ggml_sycl_op_soft_max(ctx, dst);
40944090
break;
40954091
case GGML_OP_ROPE:
40964092
ggml_sycl_rope(ctx, dst);

ggml/src/ggml-sycl/softmax.cpp

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
#include "norm.hpp"
1+
#include "softmax.hpp"
22

3-
template <bool vals_smem, int ncols_template, int block_size_template>
4-
static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
3+
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
4+
static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
55
const int nrows_y, const float scale, const float max_bias, const float m0,
66
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
77
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
2929
slope = sycl::pow(base, float(exp));
3030
}
3131

32-
float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
32+
float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
3333
float max_val = -INFINITY;
3434

3535
for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
4242
const int ix = rowx*ncols + col;
4343
const int iy = rowy*ncols + col;
4444

45-
const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
45+
const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
4646

4747
vals[col] = val;
4848
max_val = sycl::max(max_val, val);
@@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
6565
item_ct1.barrier(sycl::access::fence_space::local_space);
6666
max_val = buf[lane_id];
6767
for (size_t i = 1; i < nreduce; i += 1) {
68-
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
68+
max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
6969
}
7070
max_val = warp_reduce_max(max_val, item_ct1);
7171
}
@@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
122122
}
123123
}
124124

125-
template <bool vals_smem, int ncols_template, int block_size_template>
126-
static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
125+
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
126+
static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
127127
const int nrows_y, const float scale, const float max_bias, const float m0,
128128
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
129129
const size_t n_local_scratch, queue_ptr stream) {
@@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
141141
});
142142
}
143143

144-
static void soft_max_f32_sycl(const float * x, const float * mask,
144+
template<typename T>
145+
static void soft_max_f32_sycl(const float * x, const T * mask,
145146
float * dst, const int ncols_x, const int nrows_x,
146147
const int nrows_y, const float scale, const float max_bias,
147148
queue_ptr stream, int device) {
@@ -223,29 +224,38 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
223224
}
224225
}
225226

226-
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
227-
const ggml_tensor *src1, ggml_tensor *dst,
228-
const float *src0_dd, const float *src1_dd,
229-
float *dst_dd,
230-
const queue_ptr &main_stream) {
227+
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
231228

232-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
229+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
233230
GGML_ASSERT( dst->type == GGML_TYPE_F32);
234231

235-
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
236-
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
237-
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
232+
GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
238233

239-
const int64_t ne00 = src0->ne[0];
240-
const int64_t nrows_x = ggml_nrows(src0);
241-
const int64_t nrows_y = src0->ne[1];
234+
const int64_t ne00 = dst->src[0]->ne[0];
235+
const int64_t nrows_x = ggml_nrows(dst->src[0]);
236+
const int64_t nrows_y = dst->src[0]->ne[1];
242237

243238
float scale = 1.0f;
244239
float max_bias = 0.0f;
245240

246241
memcpy(&scale, dst->op_params + 0, sizeof(float));
247242
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
248243

249-
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
250-
nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
244+
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
245+
float * dst_dd = static_cast<float *>(dst->data);
246+
247+
ggml_sycl_set_device(ctx.device);
248+
dpct::queue_ptr main_stream = ctx.stream();
249+
250+
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
251+
const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
252+
soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
253+
main_stream, ctx.device);
254+
} else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
255+
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
256+
soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
257+
} else {
258+
/* mask unavailable */
259+
soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
260+
}
251261
}

ggml/src/ggml-sycl/softmax.hpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515

1616
#include "common.hpp"
1717

18-
void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
19-
const ggml_tensor *src1, ggml_tensor *dst,
20-
const float *src0_dd, const float *src1_dd,
21-
float *dst_dd,
22-
const queue_ptr &main_stream);
18+
void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
2319

2420
#endif // GGML_SYCL_SOFTMAX_HPP

0 commit comments

Comments
 (0)