Skip to content

Commit 4da9281

Browse files
committed
Add few missing cpy kernels
1 parent e696d06 commit 4da9281

File tree

3 files changed

+328
-2
lines changed

3 files changed

+328
-2
lines changed

ggml/src/ggml-sycl/cpy.cpp

Lines changed: 303 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,28 @@
11
#include "cpy.hpp"
2+
23
#include <float.h>
34

5+
#include "dequantize.hpp"
6+
7+
static __dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) {
8+
if (x <= val[0]) {
9+
return 0;
10+
}
11+
if (x >= val[n - 1]) {
12+
return n - 1;
13+
}
14+
int ml = 0, mu = n - 1;
15+
while (mu - ml > 1) {
16+
int mav = (ml + mu) / 2;
17+
if (x < val[mav]) {
18+
mu = mav;
19+
} else {
20+
ml = mav;
21+
}
22+
}
23+
return x - val[mu - 1] < val[mu] - x ? mu - 1 : mu;
24+
}
25+
426
static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
527
const float * xi = (const float *) cxi;
628
float * dsti = (float *) cdsti;
@@ -94,6 +116,17 @@ static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
94116
}
95117
}
96118

119+
static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
120+
float * cdstf = (float *) (cdsti);
121+
122+
for (int j = 0; j < QK8_0; j += 2) {
123+
dfloat2 dq;
124+
dequantize_q8_0(cxi, 0, j, dq);
125+
*(cdstf + j) = dq.x();
126+
*(cdstf + j + 1) = dq.y();
127+
}
128+
}
129+
97130
static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
98131
const float * xi = (const float *) cxi;
99132
block_q4_0 * dsti = (block_q4_0 *) cdsti;
@@ -162,6 +195,122 @@ static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
162195
}
163196
}
164197

198+
static void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
199+
const float * xi = (const float *) cxi;
200+
block_q5_0 * dsti = (block_q5_0 *) cdsti;
201+
202+
float amax = 0.0f;
203+
float vmax = 0.0f;
204+
205+
for (int j = 0; j < QK5_0; ++j) {
206+
const float v = xi[j];
207+
if (amax < sycl::fabs((float) v)) {
208+
amax = sycl::fabs((float) v);
209+
vmax = v;
210+
}
211+
}
212+
213+
const float d = vmax / -16;
214+
const float id = d ? 1.0f / d : 0.0f;
215+
216+
dsti->d = d;
217+
218+
uint32_t qh = 0;
219+
for (int j = 0; j < QK5_0 / 2; ++j) {
220+
const float x0 = xi[0 + j] * id;
221+
const float x1 = xi[QK5_0 / 2 + j] * id;
222+
223+
const uint8_t xi0 = dpct::min(31, (int8_t) (x0 + 16.5f));
224+
const uint8_t xi1 = dpct::min(31, (int8_t) (x1 + 16.5f));
225+
226+
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
227+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
228+
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0 / 2);
229+
}
230+
memcpy(dsti->qh, &qh, sizeof(qh));
231+
}
232+
233+
static void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
234+
const float * xi = (const float *) cxi;
235+
block_q5_1 * dsti = (block_q5_1 *) cdsti;
236+
237+
float min = xi[0];
238+
float max = xi[0];
239+
240+
for (int j = 1; j < QK5_1; ++j) {
241+
const float v = xi[j];
242+
min = v < min ? v : min;
243+
max = v > max ? v : max;
244+
}
245+
246+
const float d = (max - min) / 31;
247+
const float id = d ? 1.0f / d : 0.0f;
248+
249+
dsti->dm.x() = d;
250+
dsti->dm.y() = min;
251+
252+
uint32_t qh = 0;
253+
for (int j = 0; j < QK5_1 / 2; ++j) {
254+
const float x0 = (xi[0 + j] - min) * id;
255+
const float x1 = (xi[QK5_1 / 2 + j] - min) * id;
256+
257+
const uint8_t xi0 = (uint8_t) (x0 + 0.5f);
258+
const uint8_t xi1 = (uint8_t) (x1 + 0.5f);
259+
260+
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
261+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
262+
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1 / 2);
263+
}
264+
memcpy(dsti->qh, &qh, sizeof(qh));
265+
}
266+
267+
static void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
268+
const float * xi = (const float *) cxi;
269+
block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
270+
271+
float amax = 0.0f;
272+
float vmax = 0.0f;
273+
274+
for (int j = 0; j < QK4_NL; ++j) {
275+
const float v = xi[j];
276+
if (amax < sycl::fabs((float) v)) {
277+
amax = sycl::fabs((float) v);
278+
vmax = v;
279+
}
280+
}
281+
282+
float d = vmax / kvalues_iq4nl[0];
283+
const float id = d ? 1.0f / d : 0.0f;
284+
285+
float sumqx = 0, sumq2 = 0;
286+
for (int j = 0; j < QK4_NL / 2; ++j) {
287+
const float x0 = xi[0 + j] * id;
288+
const float x1 = xi[QK4_NL / 2 + j] * id;
289+
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
290+
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
291+
dsti->qs[j] = xi0 | (xi1 << 4);
292+
const float v0 = kvalues_iq4nl[xi0];
293+
const float v1 = kvalues_iq4nl[xi1];
294+
const float w0 = xi[0 + j] * xi[0 + j];
295+
const float w1 = xi[QK4_NL / 2 + j] * xi[QK4_NL / 2 + j];
296+
sumqx += w0 * v0 * xi[j] + w1 * v1 * xi[QK4_NL / 2 + j];
297+
sumq2 += w0 * v0 * v0 + w1 * v1 * v1;
298+
}
299+
300+
dsti->d = sumq2 > 0 ? sumqx / sumq2 : d;
301+
}
302+
303+
template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const char * cxi, char * cdsti) {
304+
float * cdstf = (float *) (cdsti);
305+
306+
for (int j = 0; j < qk / 2; j++) {
307+
dfloat2 dq;
308+
dequant(cxi, 0, j, dq);
309+
*(cdstf + j) = dq.x();
310+
*(cdstf + j + qk / 2) = dq.y();
311+
}
312+
}
313+
165314
template <cpy_kernel_t cpy_blck, int qk>
166315
static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
167316
const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
@@ -188,6 +337,32 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00
188337
cpy_blck(cx + x_offset, cdst + dst_offset);
189338
}
190339

340+
template <cpy_kernel_t cpy_blck, int qk>
341+
static void cpy_q_f32(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
342+
const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
343+
const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
344+
const sycl::nd_item<3> & item_ct1) {
345+
const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;
346+
347+
if (i >= ne) {
348+
return;
349+
}
350+
351+
const int i03 = i / (ne00 * ne01 * ne02);
352+
const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
353+
const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
354+
const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
355+
const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
356+
357+
const int i13 = i / (ne10 * ne11 * ne12);
358+
const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
359+
const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
360+
const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
361+
const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
362+
363+
cpy_blck(cx + x_offset, cdst + dst_offset);
364+
}
365+
191366
static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
192367
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
193368
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
@@ -255,6 +430,18 @@ static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, c
255430
});
256431
}
257432

433+
static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
434+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
435+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
436+
const int nb12, const int nb13, queue_ptr stream) {
437+
const int num_blocks = ne;
438+
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
439+
[=](sycl::nd_item<3> item_ct1) {
440+
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
441+
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
442+
});
443+
}
444+
258445
static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
259446
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
260447
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
@@ -268,6 +455,19 @@ static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, c
268455
});
269456
}
270457

458+
static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
459+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
460+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
461+
const int nb12, const int nb13, queue_ptr stream) {
462+
const int num_blocks = ne;
463+
stream->parallel_for(
464+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
465+
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
466+
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
467+
item_ct1);
468+
});
469+
}
470+
271471
static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
272472
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
273473
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
@@ -281,6 +481,84 @@ static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, c
281481
});
282482
}
283483

484+
static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
485+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
486+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
487+
const int nb12, const int nb13, queue_ptr stream) {
488+
const int num_blocks = ne;
489+
stream->parallel_for(
490+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
491+
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
492+
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
493+
item_ct1);
494+
});
495+
}
496+
497+
static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
498+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
499+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
500+
const int nb12, const int nb13, queue_ptr stream) {
501+
GGML_ASSERT(ne % QK5_0 == 0);
502+
const int num_blocks = ne / QK5_0;
503+
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
504+
[=](sycl::nd_item<3> item_ct1) {
505+
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
506+
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
507+
});
508+
}
509+
510+
static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
511+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
512+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
513+
const int nb12, const int nb13, queue_ptr stream) {
514+
const int num_blocks = ne;
515+
stream->parallel_for(
516+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
517+
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
518+
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
519+
item_ct1);
520+
});
521+
}
522+
523+
static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
524+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
525+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
526+
const int nb12, const int nb13, queue_ptr stream) {
527+
GGML_ASSERT(ne % QK5_1 == 0);
528+
const int num_blocks = ne / QK5_1;
529+
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
530+
[=](sycl::nd_item<3> item_ct1) {
531+
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
532+
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
533+
});
534+
}
535+
536+
static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
537+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
538+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
539+
const int nb12, const int nb13, queue_ptr stream) {
540+
const int num_blocks = ne;
541+
stream->parallel_for(
542+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
543+
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
544+
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
545+
item_ct1);
546+
});
547+
}
548+
549+
static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
550+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
551+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
552+
const int nb12, const int nb13, queue_ptr stream) {
553+
GGML_ASSERT(ne % QK4_NL == 0);
554+
const int num_blocks = ne / QK4_NL;
555+
stream->parallel_for(
556+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
557+
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
558+
ne12, nb10, nb11, nb12, nb13, item_ct1);
559+
});
560+
}
561+
284562
static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
285563
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
286564
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
@@ -379,6 +657,30 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
379657
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
380658
ggml_cpy_i32_i32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
381659
nb11, nb12, nb13, main_stream);
660+
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
661+
ggml_cpy_q4_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
662+
nb11, nb12, nb13, main_stream);
663+
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
664+
ggml_cpy_q4_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
665+
nb11, nb12, nb13, main_stream);
666+
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
667+
ggml_cpy_q8_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
668+
nb11, nb12, nb13, main_stream);
669+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
670+
ggml_cpy_f32_q5_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
671+
nb11, nb12, nb13, main_stream);
672+
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
673+
ggml_cpy_q5_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
674+
nb11, nb12, nb13, main_stream);
675+
} else if(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
676+
ggml_cpy_f32_q5_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
677+
nb11, nb12, nb13, main_stream);
678+
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
679+
ggml_cpy_q5_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
680+
nb11, nb12, nb13, main_stream);
681+
} else if(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
682+
ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
683+
nb10, nb11, nb12, nb13, main_stream);
382684
} else {
383685
GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type),
384686
ggml_type_name(src1->type));
@@ -392,4 +694,4 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
392694
void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
393695
// TODO: why do we pass dst as src1 here?
394696
ggml_sycl_cpy(ctx, dst->src[0], dst);
395-
}
697+
}

ggml/src/ggml-sycl/cpy.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
88
void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1);
99
void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
1010

11-
#endif // GGML_SYCL_CPY_HPP
11+
#endif // GGML_SYCL_CPY_HPP

0 commit comments

Comments
 (0)