Skip to content

Commit e72a7ca

Browse files
committed
remove templates from soft_max_f32_submitter to allow SYCL graph updates
1 parent 9ecf3e6 commit e72a7ca

File tree

1 file changed

+29
-84
lines changed

1 file changed

+29
-84
lines changed

ggml/src/ggml-sycl/softmax.cpp

Lines changed: 29 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
#include "softmax.hpp"
22

3-
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
4-
static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
3+
template <typename T>
4+
static void soft_max_f32(const float * x, const T * mask, float * dst, int ncols,
55
const int nrows_y, const float scale, const float max_bias, const float m0,
6-
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
7-
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
6+
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf,
7+
const bool vals_smem, const bool check_columns_count) {
88

99
const int tid = item_ct1.get_local_id(2);
1010
const int rowx = item_ct1.get_group(2);
1111
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
1212

13-
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
13+
const int block_size = check_columns_count ? item_ct1.get_local_range(2) : std::min(ncols, 1024);
1414

1515
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
1616
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
@@ -35,7 +35,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
3535
for (int col0 = 0; col0 < ncols; col0 += block_size) {
3636
const int col = col0 + tid;
3737

38-
if (ncols_template == 0 && col >= ncols) {
38+
if (check_columns_count && col >= ncols) {
3939
break;
4040
}
4141

@@ -74,7 +74,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
7474
#pragma unroll
7575
for (int col0 = 0; col0 < ncols; col0 += block_size) {
7676
const int col = col0 + tid;
77-
if (ncols_template == 0 && col >= ncols) {
77+
if (check_columns_count && col >= ncols) {
7878
break;
7979
}
8080

@@ -113,7 +113,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
113113
for (int col0 = 0; col0 < ncols; col0 += block_size) {
114114
const int col = col0 + tid;
115115

116-
if (ncols_template == 0 && col >= ncols) {
116+
if (check_columns_count && col >= ncols) {
117117
return;
118118
}
119119

@@ -122,25 +122,6 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
122122
}
123123
}
124124

125-
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
126-
static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
127-
const int nrows_y, const float scale, const float max_bias, const float m0,
128-
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
129-
const size_t n_local_scratch, queue_ptr stream) {
130-
stream->submit([&](sycl::handler &cgh) {
131-
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
132-
133-
cgh.parallel_for(
134-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
135-
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
136-
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
137-
nrows_y, scale, max_bias, m0,
138-
m1, n_head_log2, item_ct1,
139-
get_pointer(local_buf_acc));
140-
});
141-
});
142-
}
143-
144125
template<typename T>
145126
static void soft_max_f32_sycl(const float * x, const T * mask,
146127
float * dst, const int ncols_x, const int nrows_x,
@@ -163,64 +144,28 @@ static void soft_max_f32_sycl(const float * x, const T * mask,
163144
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
164145

165146
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
166-
if (n_local_scratch*sizeof(float) < local_mem_size) {
167-
if (ncols_x > max_block_size) {
168-
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
169-
max_bias, m0, m1, n_head_log2, block_nums,
170-
block_dims, n_local_scratch, stream);
171-
return;
172-
}
173-
switch (ncols_x) {
174-
case 32:
175-
soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
176-
max_bias, m0, m1, n_head_log2, block_nums,
177-
block_dims, n_local_scratch, stream);
178-
break;
179-
case 64:
180-
soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
181-
max_bias, m0, m1, n_head_log2, block_nums,
182-
block_dims, n_local_scratch, stream);
183-
break;
184-
case 128:
185-
soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
186-
max_bias, m0, m1, n_head_log2, block_nums,
187-
block_dims, n_local_scratch, stream);
188-
break;
189-
case 256:
190-
soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
191-
max_bias, m0, m1, n_head_log2, block_nums,
192-
block_dims, n_local_scratch, stream);
193-
break;
194-
case 512:
195-
soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
196-
max_bias, m0, m1, n_head_log2, block_nums,
197-
block_dims, n_local_scratch, stream);
198-
break;
199-
case 1024:
200-
soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
201-
max_bias, m0, m1, n_head_log2, block_nums,
202-
block_dims, n_local_scratch, stream);
203-
break;
204-
case 2048:
205-
soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
206-
max_bias, m0, m1, n_head_log2, block_nums,
207-
block_dims, n_local_scratch, stream);
208-
break;
209-
case 4096:
210-
soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
211-
max_bias, m0, m1, n_head_log2, block_nums,
212-
block_dims, n_local_scratch, stream);
213-
break;
214-
default:
215-
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
216-
max_bias, m0, m1, n_head_log2, block_nums,
217-
block_dims, n_local_scratch, stream);
218-
break;
219-
}
147+
148+
auto soft_max_f32_submit = [=](size_t scratch_size, bool vals_smem, bool check_columns_count) {
149+
stream->submit([=](sycl::handler &cgh) {
150+
sycl::local_accessor<float, 1> local_buf_acc(scratch_size, cgh);
151+
cgh.parallel_for(
152+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
153+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
154+
soft_max_f32(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2,
155+
item_ct1, get_pointer(local_buf_acc), vals_smem, check_columns_count);
156+
});
157+
});
158+
};
159+
160+
if (n_local_scratch*sizeof(float) >= local_mem_size) {
161+
soft_max_f32_submit(WARP_SIZE, false, true);
162+
} else if (ncols_x > max_block_size) {
163+
soft_max_f32_submit(n_local_scratch, true, true);
164+
} else if (ncols_x == 32 || ncols_x == 64 || ncols_x == 128 || ncols_x == 256
165+
|| ncols_x == 512 || ncols_x == 1024 || ncols_x == 2048 || ncols_x == 4096) {
166+
soft_max_f32_submit(n_local_scratch, true, false);
220167
} else {
221-
soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
222-
max_bias, m0, m1, n_head_log2, block_nums,
223-
block_dims, WARP_SIZE, stream);
168+
soft_max_f32_submit(n_local_scratch, true, true);
224169
}
225170
}
226171

0 commit comments

Comments
 (0)