Skip to content

Commit 704d90c

Browse files
Revert "sycl: add usage of enqueue_functions extension (ggml-org#14244)" (ggml-org#15910)
* Revert "sycl: add usage of enqueue_functions extension (ggml-org#14244)". This reverts commit 8308f98.
* Fix missed revert code, format the code.
1 parent 360d653 commit 704d90c

File tree

20 files changed

+844
-673
lines changed

20 files changed

+844
-673
lines changed

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -225,9 +225,9 @@ struct bin_bcast_sycl {
225225
dpct::has_capability_or_fail(stream->get_device(),
226226
{sycl::aspect::fp16});
227227

228-
sycl_parallel_for(
229-
stream,
230-
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size),
228+
stream->parallel_for(
229+
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
230+
sycl::range<3>(1, 1, block_size),
231231
sycl::range<3>(1, 1, block_size)),
232232
[=](sycl::nd_item<3> item_ct1) {
233233
k_bin_bcast_unravel<bin_op>(
@@ -246,8 +246,9 @@ struct bin_bcast_sycl {
246246
dpct::has_capability_or_fail(stream->get_device(),
247247
{sycl::aspect::fp16});
248248

249-
sycl_parallel_for(
250-
stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
249+
stream->parallel_for(
250+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
251+
[=](sycl::nd_item<3> item_ct1) {
251252
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
252253
ne2, ne3, ne10, ne11, ne12, ne13,
253254
s1, s2, s3, s01, s02, s03, s11, s12, s13,

ggml/src/ggml-sycl/concat.cpp

Lines changed: 25 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -89,24 +89,33 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
8989
sycl::range<3> gridDim(ne2, ne1, num_blocks);
9090
switch (dim) {
9191
case 0:
92-
sycl_parallel_for(stream,
93-
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
94-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
95-
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); });
96-
break;
92+
stream->parallel_for(
93+
sycl::nd_range<3>(gridDim *
94+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
95+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
96+
[=](sycl::nd_item<3> item_ct1) {
97+
concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
98+
});
99+
break;
97100
case 1:
98-
sycl_parallel_for(stream,
99-
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
100-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
101-
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); });
102-
break;
101+
stream->parallel_for(
102+
sycl::nd_range<3>(gridDim *
103+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
104+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
105+
[=](sycl::nd_item<3> item_ct1) {
106+
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
107+
});
108+
break;
103109
// dim >=2 will be dispatched to the default path
104110
default:
105-
sycl_parallel_for(stream,
106-
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
107-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
108-
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); });
109-
break;
111+
stream->parallel_for(
112+
sycl::nd_range<3>(gridDim *
113+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
114+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
115+
[=](sycl::nd_item<3> item_ct1) {
116+
concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
117+
});
118+
break;
110119
}
111120
}
112121

@@ -120,7 +129,7 @@ static void concat_f32_sycl_non_cont(
120129
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
121130
uint64_t nb3, int32_t dim) {
122131
sycl::range<3> gridDim(ne3, ne2, ne1);
123-
sycl_parallel_for(stream, sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
132+
stream->parallel_for(sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
124133
int64_t i3 = item_ct1.get_group(0);
125134
int64_t i2 = item_ct1.get_group(1);
126135
int64_t i1 = item_ct1.get_group(2);

ggml/src/ggml-sycl/conv.cpp

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -59,10 +59,16 @@ static void conv_transpose_1d_f32_f32_sycl(
5959
const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
6060
const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
6161
const sycl::range<3> block_nums(1, 1, num_blocks);
62-
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
63-
conv_transpose_1d_kernel(s0, output_size, src0_ne0, src0_ne1, src0_ne2, src1_ne0, dst_ne0, src0, src1, dst,
64-
item_ct1);
65-
});
62+
stream->parallel_for(
63+
sycl::nd_range<3>(
64+
block_nums * block_dims, block_dims),
65+
[=](sycl::nd_item<3> item_ct1) {
66+
conv_transpose_1d_kernel(
67+
s0, output_size,
68+
src0_ne0, src0_ne1, src0_ne2,
69+
src1_ne0, dst_ne0,
70+
src0, src1, dst, item_ct1);
71+
});
6672
}
6773

6874
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {

0 commit comments

Comments
 (0)