Commit 01d4b59

cuda: Add Q5_1, Q5_0, Q4_1 and Q4_0 to F32 conversion support. (#10976)

1 parent d04e716
File tree

2 files changed: +168 −0 lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 156 additions & 0 deletions
@@ -225,6 +225,91 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
     memcpy(dsti->qh, &qh, sizeof(qh));
 }
 
+static __device__ void cpy_blck_q5_0_f32(const char * cxi, char * cdsti) {
+    const block_q5_0 * xi = (const block_q5_0 *) cxi;
+    float * dst = (float *) cdsti;
+    float d = xi->d; // scale factor (computed as vmax / -16 at quantization time)
+    const float shift = 16.0f;
+
+    // Safely copy the 32-bit qh value to avoid a misaligned access.
+    unsigned int qh;
+    memcpy(&qh, xi->qh, sizeof(qh));
+
+    // First half: the lower nibble of xi->qs[j] stores element j.
+    for (int j = 0; j < QK5_0/2; j++) {
+        uint8_t lower = xi->qs[j] & 0xF;
+        uint8_t high  = (qh >> j) & 1;
+        uint8_t q     = (high << 4) | lower;
+        dst[j] = ((float)q - shift) * d;
+    }
+    // Second half: the upper nibble of xi->qs[j - QK5_0/2] stores element j.
+    for (int j = QK5_0/2; j < QK5_0; j++) {
+        int k = j - QK5_0/2;
+        uint8_t lower = (xi->qs[k] >> 4) & 0xF;
+        uint8_t high  = (qh >> j) & 1;
+        uint8_t q     = (high << 4) | lower;
+        dst[j] = ((float)q - shift) * d;
+    }
+}
+
+static __device__ void cpy_blck_q5_1_f32(const char * cxi, char * cdsti) {
+    const block_q5_1 * xi = (const block_q5_1 *) cxi;
+    float * dst = (float *) cdsti;
+    float d       = xi->dm.x; // scale
+    float min_val = xi->dm.y; // minimum value
+
+    // Safely copy the 32-bit qh value to avoid a misaligned access.
+    unsigned int qh;
+    memcpy(&qh, xi->qh, sizeof(qh));
+
+    // Decode first half: the lower nibble of xi->qs[j] holds element j.
+    for (int j = 0; j < QK5_1/2; j++) {
+        uint8_t lower = xi->qs[j] & 0xF;
+        uint8_t high  = (qh >> j) & 1;
+        uint8_t q     = (high << 4) | lower;
+        dst[j] = min_val + d * (float)q;
+    }
+    // Decode second half: the upper nibble of xi->qs[j - QK5_1/2] holds element j.
+    for (int j = QK5_1/2; j < QK5_1; j++) {
+        int k = j - QK5_1/2;
+        uint8_t lower = (xi->qs[k] >> 4) & 0xF;
+        uint8_t high  = (qh >> j) & 1;
+        uint8_t q     = (high << 4) | lower;
+        dst[j] = min_val + d * (float)q;
+    }
+}
+
+static __device__ void cpy_blck_q4_0_f32(const char * cxi, char * cdsti) {
+    const block_q4_0 * xi = (const block_q4_0 *) cxi;
+    float * dst = (float *) cdsti;
+    float d = xi->d;
+    const float shift = 8.0f;
+
+    // Each byte packs two 4-bit quantized values.
+    for (int j = 0; j < QK4_0/2; j++) {
+        uint8_t q_val = xi->qs[j];
+        uint8_t q0 = q_val & 0x0F;
+        uint8_t q1 = (q_val >> 4) & 0x0F;
+        dst[j]           = ((float)q0 - shift) * d;
+        dst[j + QK4_0/2] = ((float)q1 - shift) * d;
+    }
+}
+
+static __device__ void cpy_blck_q4_1_f32(const char * cxi, char * cdsti) {
+    const block_q4_1 * xi = (const block_q4_1 *) cxi;
+    float * dst = (float *) cdsti;
+    const float d    = xi->dm.x;
+    const float vmin = xi->dm.y;
+
+    // Each byte packs two 4-bit quantized values.
+    for (int j = 0; j < QK4_1/2; ++j) {
+        uint8_t byte_val = xi->qs[j];
+        uint8_t q0 = byte_val & 0x0F;
+        uint8_t q1 = (byte_val >> 4) & 0x0F;
+        dst[j]           = vmin + d * (float)q0;
+        dst[j + QK4_1/2] = vmin + d * (float)q1;
+    }
+}
 
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
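
As a sanity check on the bit layout these kernels decode, the same logic can be run on the host. The sketch below is illustrative only: demo_block_q5_0 merely mirrors the essentials of ggml's block_q5_0 (the real struct stores d as a 16-bit ggml_half, not a float), and all names are local to the example.

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    #define QK5_0 32

    // Illustrative stand-in for block_q5_0; ggml stores d as ggml_half.
    struct demo_block_q5_0 {
        float   d;           // scale (vmax / -16 at quantization time)
        uint8_t qh[4];       // high (5th) bit of each of the 32 elements
        uint8_t qs[QK5_0/2]; // two 4-bit low nibbles per byte
    };

    static void demo_dequant_q5_0(const demo_block_q5_0 * xi, float * dst) {
        unsigned int qh;
        memcpy(&qh, xi->qh, sizeof(qh)); // same misalignment-safe copy as the kernel

        for (int j = 0; j < QK5_0; j++) {
            const int     k     = j % (QK5_0/2);
            const uint8_t lower = j < QK5_0/2 ? (xi->qs[k] & 0xF) : (xi->qs[k] >> 4);
            const uint8_t q     = ((qh >> j) & 1) << 4 | lower; // 5-bit value in [0, 31]
            dst[j] = ((float) q - 16.0f) * xi->d;               // values are centered at -16
        }
    }

    int main() {
        demo_block_q5_0 b = {};
        b.d     = 0.5f;
        b.qs[0] = 0x1; // element 0: low nibble 1
        b.qh[0] = 0x1; // element 0: high bit set -> q = 17
        float out[QK5_0];
        demo_dequant_q5_0(&b, out);
        printf("%f\n", out[0]); // (17 - 16) * 0.5 = 0.5
        return 0;
    }

The q4 variants are the same idea without the qh side channel: the two nibbles of each byte hold elements j and j + QK4_0/2.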
@@ -420,6 +505,58 @@ static void ggml_cpy_f32_q5_1_cuda(
     (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_1_f32_cuda(
+        const char * cx, char * cdst, const int ne,
+        const int ne00, const int ne01, const int ne02,
+        const int nb00, const int nb01, const int nb02,
+        const int nb03, const int ne10, const int ne11, const int ne12,
+        const int nb10, const int nb11, const int nb12, const int nb13,
+        cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q5_1_f32, QK5_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q5_0_f32_cuda(
+        const char * cx, char * cdst, const int ne,
+        const int ne00, const int ne01, const int ne02,
+        const int nb00, const int nb01, const int nb02,
+        const int nb03, const int ne10, const int ne11, const int ne12,
+        const int nb10, const int nb11, const int nb12, const int nb13,
+        cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q5_0_f32, QK5_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q4_1_f32_cuda(
+        const char * cx, char * cdst, const int ne,
+        const int ne00, const int ne01, const int ne02,
+        const int nb00, const int nb01, const int nb02,
+        const int nb03, const int ne10, const int ne11, const int ne12,
+        const int nb10, const int nb11, const int nb12, const int nb13,
+        cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q4_1_f32, QK4_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q4_0_f32_cuda(
+        const char * cx, char * cdst, const int ne,
+        const int ne00, const int ne01, const int ne02,
+        const int nb00, const int nb01, const int nb02,
+        const int nb03, const int ne10, const int ne11, const int ne12,
+        const int nb10, const int nb11, const int nb12, const int nb13,
+        cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q4_0_f32, QK4_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
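
Each wrapper launches ne single-thread blocks (<<<num_blocks, 1, 0, stream>>>), mirroring the existing q8_0-to-F32 path; the per-element index math lives in the shared cpy_q_f32 kernel template, which this diff reuses unchanged. For readers new to the ne*/nb* argument soup, the sketch below illustrates the usual ggml convention (ne* are element counts, nb* are byte strides); it is an assumption-laden simplification with local names, not the template's real body:

    #include <cstdint>

    // Illustrative: map a flat element index i to a source byte offset, given
    // element counts ne00..ne02 and byte strides nb00..nb03. For a quantized
    // source, nb00 is the byte size of one qk-element block, so the innermost
    // coordinate is applied per block (i00/qk).
    static int64_t demo_src_offset(int64_t i, int64_t qk,
                                   int64_t ne00, int64_t ne01, int64_t ne02,
                                   int64_t nb00, int64_t nb01, int64_t nb02, int64_t nb03) {
        const int64_t i03 = i / (ne00*ne01*ne02);
        const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
        const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
        return (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
    }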
@@ -488,14 +625,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
         ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +672,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q4_0_f32, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q4_1_f32, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q5_0_f32, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q5_1_f32, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 12 additions & 0 deletions
@@ -3073,15 +3073,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
             if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
                 return true;
             }
+            if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
+                return true;
+            }
             if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
                 return true;
             }
+            if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
+                return true;
+            }
             if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
                 return true;
             }
+            if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
+                return true;
+            }
             if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
                 return true;
             }
+            if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
+                return true;
+            }
             if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
                 return true;
             }
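
With the kernels, host wrappers, dispatch, and this supports_op gate in place, the new conversions are reachable through the ordinary ggml graph API. A minimal sketch, assuming the usual public entry points (ggml_init, ggml_new_tensor_1d, ggml_cpy, ggml_build_forward_expand) and leaving CUDA backend setup aside:

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // 64 elements = two Q5_0 blocks (QK5_0 == 32).
        struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_Q5_0, 64);
        struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  64);

        // ggml_cpy records a GGML_OP_CPY node; on the CUDA backend this type
        // pair now dispatches to ggml_cpy_q5_0_f32_cuda instead of being
        // rejected by supports_op.
        struct ggml_tensor * out = ggml_cpy(ctx, src, dst);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);

        // ... fill src with quantized data and run the graph on the CUDA
        // backend; llama.cpp's test-backend-ops harness exercises CPY type
        // pairs like this one.

        ggml_free(ctx);
        return 0;
    }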
