
Commit 11fe7ca

CPY: move to a separate file
1 parent 387c5d7 commit 11fe7ca

File tree

4 files changed: +401 -460 lines changed


ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@
 #include "binbcast.hpp"
 #include "argmax.hpp"
 #include "argsort.hpp"
+#include "cpy.hpp"
 #include "gla.hpp"

 #endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/cpy.cpp

Lines changed: 389 additions & 0 deletions
@@ -0,0 +1,389 @@
#include "cpy.hpp"

static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    float * dsti = (float *) cdsti;

    *dsti = *xi;
}

static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    sycl::half * dsti = (sycl::half *) cdsti;

    *dsti = sycl::vec<float, 1>(*xi).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
}

static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
    const sycl::half * xi = (const sycl::half *) cxi;
    sycl::half * dsti = (sycl::half *) cdsti;

    *dsti = *xi;
}

static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
    const sycl::half * xi = (const sycl::half *) cxi;
    float * dsti = (float *) cdsti;

    *dsti = *xi;
}

static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
    const int16_t * xi = (const int16_t *) cxi;
    int16_t * dsti = (int16_t *) cdsti;

    *dsti = *xi;
}

static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
    const int32_t * xi = (const int32_t *) cxi;
    int32_t * dsti = (int32_t *) cdsti;

    *dsti = *xi;
}

template <cpy_kernel_t cpy_1>
static void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
                        const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
                        const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
                        const sycl::nd_item<3> & item_ct1) {
    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);

    if (i >= ne) {
        return;
    }

    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
    // then combine those indices with the corresponding byte offsets to get the total offsets
    const int i03 = i / (ne00 * ne01 * ne02);
    const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
    const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
    const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;

    const int i13 = i / (ne10 * ne11 * ne12);
    const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
    const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
    const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
    const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;

    cpy_1(cx + x_offset, cdst + dst_offset);
}

static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q8_0 * dsti = (block_q8_0 *) cdsti;

    float amax = 0.0f; // absolute max

    for (int j = 0; j < QK8_0; j++) {
        const float v = xi[j];
        amax = sycl::fmax(amax, sycl::fabs((float) v));
    }

    const float d = amax / ((1 << 7) - 1);
    const float id = d ? 1.0f / d : 0.0f;

    dsti->d = d;

    for (int j = 0; j < QK8_0; ++j) {
        const float x0 = xi[j] * id;

        dsti->qs[j] = sycl::round((float) x0);
    }
}

static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q4_0 * dsti = (block_q4_0 *) cdsti;

    float amax = 0.0f;
    float vmax = 0.0f;

    for (int j = 0; j < QK4_0; ++j) {
        const float v = xi[j];
        if (amax < sycl::fabs((float) v)) {
            amax = sycl::fabs((float) v);
            vmax = v;
        }
    }

    const float d = vmax / -8;
    const float id = d ? 1.0f / d : 0.0f;

    dsti->d = d;

    for (int j = 0; j < QK4_0 / 2; ++j) {
        const float x0 = xi[0 + j] * id;
        const float x1 = xi[QK4_0 / 2 + j] * id;

        const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 8.5f));
        const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 8.5f));

        dsti->qs[j] = xi0;
        dsti->qs[j] |= xi1 << 4;
    }
}

static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q4_1 * dsti = (block_q4_1 *) cdsti;

    float vmin = FLT_MAX;
    float vmax = -FLT_MAX;

    for (int j = 0; j < QK4_1; ++j) {
        const float v = xi[j];

        if (v < vmin) {
            vmin = v;
        }
        if (v > vmax) {
            vmax = v;
        }
    }

    const float d = (vmax - vmin) / ((1 << 4) - 1);
    const float id = d ? 1.0f / d : 0.0f;

    dsti->dm.x() = d;
    dsti->dm.y() = vmin;

    for (int j = 0; j < QK4_1 / 2; ++j) {
        const float x0 = (xi[0 + j] - vmin) * id;
        const float x1 = (xi[QK4_1 / 2 + j] - vmin) * id;

        const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 0.5f));
        const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 0.5f));

        dsti->qs[j] = xi0;
        dsti->qs[j] |= xi1 << 4;
    }
}

template <cpy_kernel_t cpy_blck, int qk>
static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
                      const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
                      const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
                      const sycl::nd_item<3> & item_ct1) {
    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;

    if (i >= ne) {
        return;
    }

    const int i03 = i / (ne00 * ne01 * ne02);
    const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
    const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
    const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;

    const int i13 = i / (ne10 * ne11 * ne12);
    const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
    const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
    const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
    const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}

static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                  const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
    {
        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

        stream->parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                           nb10, nb11, nb12, nb13, item_ct1);
            });
    }
}

static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                  const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
    {
        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

        stream->parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                           nb10, nb11, nb12, nb13, item_ct1);
            });
    }
}

static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                  const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
    {
        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

        stream->parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                           nb10, nb11, nb12, nb13, item_ct1);
            });
    }
}

static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    GGML_ASSERT(ne % QK8_0 == 0);
    const int num_blocks = ne / QK8_0;
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                         [=](sycl::nd_item<3> item_ct1) {
                             cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
                         });
}

static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    GGML_ASSERT(ne % QK4_0 == 0);
    const int num_blocks = ne / QK4_0;
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                         [=](sycl::nd_item<3> item_ct1) {
                             cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
                         });
}

static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                   const int nb12, const int nb13, queue_ptr stream) {
    GGML_ASSERT(ne % QK4_1 == 0);
    const int num_blocks = ne / QK4_1;
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                         [=](sycl::nd_item<3> item_ct1) {
                             cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
                         });
}

static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                  const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
    {
        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

        stream->parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                           nb10, nb11, nb12, nb13, item_ct1);
            });
    }
}

static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                  const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
    {
        // dpct::has_capability_or_fail(stream->get_device(),
        //                              {sycl::aspect::fp16});

        stream->parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                           nb10, nb11, nb12, nb13, item_ct1);
            });
    }
}

static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                  const int nb12, const int nb13, queue_ptr stream) {
    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
    {
        // dpct::has_capability_or_fail(stream->get_device(),
        //                              {sycl::aspect::fp16});

        stream->parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                           nb10, nb11, nb12, nb13, item_ct1);
            });
    }
}

void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
    const int64_t ne = ggml_nelements(src0);
    GGML_ASSERT(ne == ggml_nelements(src1));

    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);

    GGML_SYCL_TENSOR_BINARY_OP_CP_LOCALS;

    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    queue_ptr main_stream = ctx.stream();

    char * src0_ddc = (char *) src0->data;
    char * src1_ddc = (char *) src1->data;

    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
        ggml_cpy_f32_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
        ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_f16_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
        ggml_cpy_f16_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
        ggml_cpy_i16_i16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
        ggml_cpy_i32_i32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else {
        GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type),
                       ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
} catch (const sycl::exception & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}
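
Note: the companion header ggml/src/ggml-sycl/cpy.hpp is among the 4 changed files but is not shown in this excerpt. The sketch below is a hypothetical reconstruction of what cpy.cpp appears to rely on from it (the cpy_kernel_t function-pointer type and the ggml_sycl_cpy declaration); the include-guard name and the "common.hpp" include are assumptions, not taken from the commit.

// Hypothetical sketch of ggml/src/ggml-sycl/cpy.hpp, inferred from the usage in cpy.cpp above.
// The guard name and the common.hpp include are assumptions; the real header may differ.
#ifndef GGML_SYCL_CPY_HPP
#define GGML_SYCL_CPY_HPP

#include "common.hpp"  // assumed to provide ggml_backend_sycl_context, queue_ptr, block_q*_0, SYCL_CPY_BLOCK_SIZE

// signature shared by the per-element and per-block copy kernels used as template arguments above
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

// entry point used by the SYCL backend: copies/converts src0 into src1 on the device
void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1);

#endif  // GGML_SYCL_CPY_HPP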
