Skip to content

Commit 857067a

Browse files
committed
SYCL: Initial set_rows kernel implementation
1 parent 6efcd65 commit 857067a

File tree

4 files changed

+157
-1
lines changed

4 files changed

+157
-1
lines changed

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "outprod.hpp"
3131
#include "quants.hpp"
3232
#include "rope.hpp"
33+
#include "set_rows.hpp"
3334
#include "softmax.hpp"
3435
#include "tsembd.hpp"
3536
#include "wkv.hpp"

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "ggml-sycl/element_wise.hpp"
4242
#include "ggml-sycl/presets.hpp"
4343
#include "ggml-sycl/gemm.hpp"
44+
#include "ggml-sycl/set_rows.hpp"
4445
#include "ggml-sycl/sycl_hw.hpp"
4546
#include "ggml-sycl/getrows.hpp"
4647
#include "ggml.h"
@@ -3603,6 +3604,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
36033604
case GGML_OP_GET_ROWS:
36043605
ggml_sycl_get_rows(ctx, dst);
36053606
break;
3607+
case GGML_OP_SET_ROWS:
3608+
ggml_sycl_op_set_rows(ctx, dst);
3609+
break;
36063610
case GGML_OP_DUP:
36073611
ggml_sycl_dup(ctx, dst);
36083612
break;
@@ -4297,7 +4301,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
42974301
{
42984302
// TODO: add support
42994303
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
4300-
return false;
4304+
return (op->type == GGML_TYPE_F32 || (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64));
43014305
} break;
43024306
case GGML_OP_CPY:
43034307
{

ggml/src/ggml-sycl/set_rows.cpp

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#include "set_rows.hpp"
2+
3+
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
4+
5+
static void set_rows_1_f32_f32(const char * src, char * dst) {
6+
const float * src_f = (const float *) src;
7+
float * dst_f = (float *) dst;
8+
*dst_f = *src_f;
9+
}
10+
11+
static void set_rows_1_f32_f16(const char * src, char * dst) {
12+
const float * src_f = (const float *) src;
13+
sycl::half * dst_h = (sycl::half *) dst;
14+
*dst_h = sycl::vec<float, 1>(*src_f).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
15+
}
16+
17+
template<set_rows_kernel_t set_rows_1>
18+
static void k_set_rows(
19+
const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
20+
const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12,
21+
const size_t nb01, const size_t nb02, const size_t nb03,
22+
const size_t nb10, const size_t nb11, const size_t nb12,
23+
const size_t nb1, const size_t nb2, const size_t nb3,
24+
const size_t src_type_size, const size_t dst_type_size,
25+
const sycl::nd_item<3> & item_ct1) {
26+
27+
const int i03 = item_ct1.get_group(0);
28+
const int i02 = item_ct1.get_group(1);
29+
const int i01 = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); // Row index
30+
31+
if (i01 >= ne01) {
32+
return;
33+
}
34+
35+
const int i12 = i03 % ne12;
36+
const int i11 = i02 % ne11;
37+
const int i10 = i01;
38+
39+
const int64_t dst_row = *(const int64_t *)((const char *)src1 + i10*nb10 + i11*nb11 + i12*nb12);
40+
41+
const char * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03;
42+
char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
43+
// Optimize for same-type operations: use collective memory copy
44+
if (src_type_size == dst_type_size) {
45+
// All threads in the work-group cooperatively copy the row
46+
const size_t row_bytes = ne00 * src_type_size;
47+
// Each thread copies a chunk of the row
48+
for (size_t byte_idx = item_ct1.get_local_id(0); byte_idx < row_bytes; byte_idx += item_ct1.get_local_range(0)) {
49+
dst_row_ptr[byte_idx] = src0_row[byte_idx];
50+
}
51+
} else {
52+
// Type conversion required, use element-wise approach
53+
for (int col = item_ct1.get_local_id(0); col < ne00; col += item_ct1.get_local_range(0)) {
54+
const char * src_elem = src0_row + col * src_type_size;
55+
char * dst_elem = dst_row_ptr + col * dst_type_size;
56+
set_rows_1(src_elem, dst_elem);
57+
}
58+
}
59+
}
60+
61+
template<set_rows_kernel_t set_rows_1>
62+
static void set_rows_sycl(
63+
const char * src0_d, const int64_t * src1_d, char * dst_d,
64+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
65+
const int64_t ne11, const int64_t ne12, const size_t nb01, const size_t nb02, const size_t nb03,
66+
const size_t nb10, const size_t nb11, const size_t nb12,
67+
const size_t nb1, const size_t nb2, const size_t nb3,
68+
const size_t src_type_size, const size_t dst_type_size,
69+
queue_ptr stream) {
70+
71+
const int max_threads_per_row = 128; // KEEPING 128 for now
72+
const int threads_per_row = std::min((int)ne00, max_threads_per_row);
73+
74+
const int max_threads_per_block = 128;
75+
const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row);
76+
77+
const sycl::range<3> block_size(1, rows_per_block, threads_per_row);
78+
const sycl::range<3> grid_size(ne03, ne02, (ne01 + rows_per_block - 1) / rows_per_block);
79+
80+
if (ne01 > 0 && ne00 > 0) {
81+
sycl_parallel_for(
82+
stream,
83+
sycl::nd_range<3>(grid_size * block_size, block_size),
84+
[=](sycl::nd_item<3> item_ct1) {
85+
k_set_rows<set_rows_1>(
86+
src0_d, src1_d, dst_d,
87+
ne00, ne01, ne11, ne12,
88+
nb01, nb02, nb03,
89+
nb10, nb11, nb12,
90+
nb1, nb2, nb3,
91+
src_type_size, dst_type_size,
92+
item_ct1
93+
);
94+
}
95+
);
96+
}
97+
}
98+
99+
100+
void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
101+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
102+
const ggml_tensor * src0 = dst->src[0];
103+
const ggml_tensor * src1 = dst->src[1];
104+
105+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
106+
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I64);
107+
108+
GGML_TENSOR_BINARY_OP_LOCALS
109+
110+
const int64_t * src1_dd = static_cast<const int64_t *>(src1->data);
111+
112+
dpct::queue_ptr stream = ctx.stream();
113+
switch (dst->type) {
114+
case GGML_TYPE_F32:
115+
set_rows_sycl<set_rows_1_f32_f32>(
116+
(const char *)dst->src[0]->data, src1_dd, (char *)dst->data,
117+
ne00, ne01, ne02, ne03,
118+
ne11, ne12,
119+
nb01, nb02, nb03,
120+
nb10, nb11, nb12,
121+
nb1, nb2, nb3,
122+
sizeof(float), sizeof(float),
123+
stream
124+
);
125+
break;
126+
case GGML_TYPE_F16:
127+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
128+
set_rows_sycl<set_rows_1_f32_f16>(
129+
(const char *)dst->src[0]->data, src1_dd, (char *)dst->data,
130+
ne00, ne01, ne02, ne03,
131+
ne11, ne12,
132+
nb01, nb02, nb03,
133+
nb10, nb11, nb12,
134+
nb1, nb2, nb3,
135+
sizeof(float), sizeof(sycl::half),
136+
stream
137+
);
138+
break;
139+
default:
140+
GGML_ABORT("Unsupported tensor type!");
141+
break;
142+
}
143+
}

ggml/src/ggml-sycl/set_rows.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef GGML_SYCL_SET_ROWS_HPP
2+
#define GGML_SYCL_SET_ROWS_HPP
3+
4+
#include "common.hpp"
5+
6+
void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7+
8+
#endif // GGML_SYCL_SET_ROWS_HPP

0 commit comments

Comments
 (0)