Skip to content

Commit 71cef96

Browse files
committed
metal: Copy kernels for quant to F32 conversions (#10976).
1 parent 01d4b59 commit 71cef96

File tree

3 files changed

+302
-5
lines changed

3 files changed

+302
-5
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ typedef struct {
8484
} ggml_metal_kargs_repeat;
8585

8686
typedef struct {
87+
int64_t ne;
8788
int64_t ne00;
8889
int64_t ne01;
8990
int64_t ne02;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
407407
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
408408
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
409409
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
410+
GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
411+
GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
412+
GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
413+
GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
414+
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
410415
GGML_METAL_KERNEL_TYPE_CONCAT,
411416
GGML_METAL_KERNEL_TYPE_SQR,
412417
GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1017,11 @@ @implementation GGMLMetalClass
10121017
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
10131018
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
10141019
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
1020+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
1021+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
1022+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
1023+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
1024+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
10151025
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
10161026
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
10171027
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
@@ -1287,6 +1297,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
12871297
default:
12881298
return false;
12891299
}
1300+
case GGML_TYPE_Q4_0:
1301+
case GGML_TYPE_Q4_1:
1302+
case GGML_TYPE_Q5_0:
1303+
case GGML_TYPE_Q5_1:
1304+
case GGML_TYPE_Q8_0:
1305+
return (op->type == GGML_TYPE_F32);
12901306
default:
12911307
return false;
12921308
};
@@ -1615,7 +1631,10 @@ static void ggml_metal_encode_node(
16151631

16161632
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
16171633

1634+
const int64_t ne = ggml_nelements(src0);
1635+
16181636
ggml_metal_kargs_cpy args = {
1637+
/*.ne =*/ ne,
16191638
/*.ne00 =*/ ne00,
16201639
/*.ne01 =*/ ne01,
16211640
/*.ne02 =*/ ne02,
@@ -3899,10 +3918,7 @@ static void ggml_metal_encode_node(
38993918
case GGML_OP_CPY:
39003919
case GGML_OP_CONT:
39013920
{
3902-
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
3903-
3904-
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
3905-
3921+
const int64_t ne = ggml_nelements(src0);
39063922
id<MTLComputePipelineState> pipeline = nil;
39073923

39083924
switch (src0t) {
@@ -3936,13 +3952,33 @@ static void ggml_metal_encode_node(
39363952
switch (dstt) {
39373953
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
39383954
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
3939-
default: GGML_ASSERT(false && "not implemented");
3955+
default: GGML_ABORT("not implemented");
39403956
};
39413957
} break;
3958+
case GGML_TYPE_Q4_0:
3959+
case GGML_TYPE_Q4_1:
3960+
case GGML_TYPE_Q5_0:
3961+
case GGML_TYPE_Q5_1:
3962+
case GGML_TYPE_Q8_0:
3963+
{
3964+
if (dstt == GGML_TYPE_F32) {
3965+
switch (src0t) {
3966+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
3967+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
3968+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
3969+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
3970+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
3971+
default: GGML_ABORT("not implemented");
3972+
}
3973+
} else {
3974+
GGML_ABORT("not implemented");
3975+
}
3976+
} break;
39423977
default: GGML_ABORT("not implemented");
39433978
}
39443979

39453980
ggml_metal_kargs_cpy args = {
3981+
/*.ne =*/ ne,
39463982
/*.ne00 =*/ ne00,
39473983
/*.ne01 =*/ ne01,
39483984
/*.ne02 =*/ ne02,
@@ -3966,7 +4002,17 @@ static void ggml_metal_encode_node(
39664002
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
39674003
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
39684004

4005+
int nth;
4006+
4007+
if (src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
4008+
GGML_ASSERT(dstt == GGML_TYPE_F32);
4009+
nth = MIN(1024, ne);
4010+
} else {
4011+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4012+
nth = MIN(1024, ne00/ggml_blck_size(src0->type));
4013+
}
39694014
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4015+
39704016
} break;
39714017
case GGML_OP_SET:
39724018
{

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4372,6 +4372,256 @@ kernel void kernel_concat(
43724372
}
43734373
}
43744374

4375+
kernel void kernel_cpy_q4_0_f32(
4376+
constant ggml_metal_kargs_cpy & args,
4377+
device const char *cx [[ buffer(1) ]],
4378+
device char *cdst [[ buffer(2) ]],
4379+
uint tid [[ thread_position_in_grid ]]
4380+
)
4381+
{
4382+
// Compute the global index multiplied by QK, matching:
4383+
// i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4384+
const int i = int(tid) * QK4_0;
4385+
4386+
// Bounds check
4387+
if (i >= args.ne) {
4388+
return;
4389+
}
4390+
4391+
const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
4392+
const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
4393+
const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
4394+
const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
4395+
const int x_offset = (i00/QK4_0)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
4396+
4397+
const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
4398+
const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
4399+
const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
4400+
const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
4401+
const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
4402+
4403+
device const block_q4_0 * src_block = (device const block_q4_0 *)(cx + x_offset);
4404+
device float * dst = (device float *)(cdst + dst_offset);
4405+
4406+
float d = float(src_block->d);
4407+
const float shift = 8.0f;
4408+
4409+
// Unpack 2 x 4-bit values per byte.
4410+
for (int j = 0; j < QK4_0/2; j++) {
4411+
uint8_t q = src_block->qs[j];
4412+
uint8_t q0 = q & 0x0F;
4413+
uint8_t q1 = (q >> 4) & 0x0F;
4414+
dst[j] = (float(q0) - shift) * d;
4415+
dst[j + QK4_0/2] = (float(q1) - shift) * d;
4416+
}
4417+
}
4418+
4419+
kernel void kernel_cpy_q4_1_f32(
4420+
constant ggml_metal_kargs_cpy & args,
4421+
device const char *cx [[ buffer(1) ]],
4422+
device char *cdst [[ buffer(2) ]],
4423+
uint tid [[ thread_position_in_grid ]]
4424+
)
4425+
{
4426+
// Compute the global index multiplied by QK, matching:
4427+
// i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4428+
const int i = int(tid) * QK4_1;
4429+
4430+
// Bounds check
4431+
if (i >= args.ne) {
4432+
return;
4433+
}
4434+
4435+
const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
4436+
const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
4437+
const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
4438+
const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
4439+
const int x_offset = (i00/QK4_1)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
4440+
4441+
const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
4442+
const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
4443+
const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
4444+
const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
4445+
const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
4446+
4447+
device const block_q4_1 * src_block = (device const block_q4_1 *)(cx + x_offset);
4448+
device float * dst = (device float *)(cdst + dst_offset);
4449+
4450+
float d = float(src_block->d);
4451+
float vmin = float(src_block->m);
4452+
4453+
for (int j = 0; j < QK4_1/2; j++) {
4454+
uint8_t q = src_block->qs[j];
4455+
uint8_t q0 = q & 0x0F;
4456+
uint8_t q1 = (q >> 4) & 0x0F;
4457+
dst[j] = vmin + d * float(q0);
4458+
dst[j + QK4_1/2] = vmin + d * float(q1);
4459+
}
4460+
}
4461+
4462+
4463+
kernel void kernel_cpy_q5_0_f32(
4464+
constant ggml_metal_kargs_cpy & args,
4465+
device const char *cx [[ buffer(1) ]],
4466+
device char *cdst [[ buffer(2) ]],
4467+
uint tid [[ thread_position_in_grid ]]
4468+
)
4469+
{
4470+
// Compute the global index multiplied by QK, matching:
4471+
// i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4472+
const int i = int(tid) * QK5_0;
4473+
4474+
// Bounds check
4475+
if (i >= args.ne) {
4476+
return;
4477+
}
4478+
4479+
const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
4480+
const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
4481+
const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
4482+
const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
4483+
const int x_offset = (i00/QK5_0)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
4484+
4485+
const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
4486+
const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
4487+
const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
4488+
const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
4489+
const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
4490+
4491+
device const block_q5_0 * src_block = (device const block_q5_0 *)(cx + x_offset);
4492+
device float * dst = (device float *)(cdst + dst_offset);
4493+
4494+
float d = float(src_block->d);
4495+
const float shift = 16.f;
4496+
4497+
// Combine the four qh bytes into a 32-bit value.
4498+
uint32_t qhVal = 0
4499+
| ((uint32_t) src_block->qh[0] << 0)
4500+
| ((uint32_t) src_block->qh[1] << 8)
4501+
| ((uint32_t) src_block->qh[2] << 16)
4502+
| ((uint32_t) src_block->qh[3] << 24);
4503+
4504+
// First half
4505+
for (int j = 0; j < QK5_0/2; j++) {
4506+
uint8_t q = src_block->qs[j];
4507+
uint8_t lowNib = q & 0x0F;
4508+
uint8_t highBit = (qhVal >> j) & 0x1;
4509+
uint8_t qVal = (highBit << 4) | lowNib;
4510+
dst[j] = (float(qVal) - shift) * d;
4511+
}
4512+
// Second half
4513+
for (int j = QK5_0/2; j < QK5_0; j++) {
4514+
int k = j - QK5_0/2;
4515+
uint8_t q = src_block->qs[k];
4516+
uint8_t hiNib = (q >> 4) & 0x0F;
4517+
uint8_t highBit = (qhVal >> j) & 0x1;
4518+
uint8_t qVal = (highBit << 4) | hiNib;
4519+
dst[j] = (float(qVal) - shift) * d;
4520+
}
4521+
}
4522+
4523+
4524+
kernel void kernel_cpy_q5_1_f32(
4525+
constant ggml_metal_kargs_cpy & args,
4526+
device const char *cx [[ buffer(1) ]],
4527+
device char *cdst [[ buffer(2) ]],
4528+
uint tid [[ thread_position_in_grid ]]
4529+
)
4530+
{
4531+
// Compute the global index multiplied by QK, matching:
4532+
// i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4533+
const int i = int(tid) * QK5_1;
4534+
4535+
// Bounds check
4536+
if (i >= args.ne) {
4537+
return;
4538+
}
4539+
4540+
const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
4541+
const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
4542+
const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
4543+
const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
4544+
const int x_offset = (i00/QK5_1)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
4545+
4546+
const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
4547+
const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
4548+
const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
4549+
const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
4550+
const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
4551+
4552+
device const block_q5_1 * src_block = (device const block_q5_1 *)(cx + x_offset);
4553+
device float * dst = (device float *)(cdst + dst_offset);
4554+
4555+
float d = float(src_block->d);
4556+
float vmin = float(src_block->m);
4557+
4558+
uint32_t qhVal = 0
4559+
| ((uint32_t) src_block->qh[0] << 0)
4560+
| ((uint32_t) src_block->qh[1] << 8)
4561+
| ((uint32_t) src_block->qh[2] << 16)
4562+
| ((uint32_t) src_block->qh[3] << 24);
4563+
4564+
// First half
4565+
for (int j = 0; j < QK5_1/2; j++) {
4566+
uint8_t q = src_block->qs[j];
4567+
uint8_t lowNib = q & 0x0F;
4568+
uint8_t highBit = (qhVal >> j) & 0x1;
4569+
uint8_t qVal = (highBit << 4) | lowNib;
4570+
dst[j] = vmin + d * float(qVal);
4571+
}
4572+
// Second half
4573+
for (int j = QK5_1/2; j < QK5_1; j++) {
4574+
int k = j - QK5_1/2;
4575+
uint8_t q = src_block->qs[k];
4576+
uint8_t hiNib = (q >> 4) & 0x0F;
4577+
uint8_t highBit = (qhVal >> j) & 0x1;
4578+
uint8_t qVal = (highBit << 4) | hiNib;
4579+
dst[j] = vmin + d * float(qVal);
4580+
}
4581+
}
4582+
4583+
kernel void kernel_cpy_q8_0_f32(
4584+
constant ggml_metal_kargs_cpy &args [[ buffer(0) ]],
4585+
device const char *cx [[ buffer(1) ]],
4586+
device char *cdst [[ buffer(2) ]],
4587+
uint tid [[ thread_position_in_grid ]]
4588+
) {
4589+
// Compute the global index multiplied by QK, matching:
4590+
// i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4591+
const int i = int(tid) * QK8_0;
4592+
4593+
// Bounds check
4594+
if (i >= args.ne) {
4595+
return;
4596+
}
4597+
4598+
const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
4599+
const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
4600+
const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
4601+
const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
4602+
const int x_offset = (i00/QK8_0)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
4603+
4604+
const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
4605+
const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
4606+
const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
4607+
const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
4608+
const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
4609+
4610+
// Call the device function that performs the copy/dequantization.
4611+
// cpy_blck(cx + x_offset, cdst + dst_offset);
4612+
device const char * src_block = cx + x_offset;
4613+
device char * dst = cdst + dst_offset;
4614+
4615+
const device block_q8_0 * xi = (device const block_q8_0 *) src_block;
4616+
device float * dsti = (device float *) dst;
4617+
4618+
const float d = (float)xi->d;
4619+
4620+
for (int j = 0; j < QK8_0; j++) {
4621+
dsti[j] = xi->qs[j] * d;
4622+
}
4623+
}
4624+
43754625
template<typename args_t>
43764626
void kernel_mul_mv_q2_K_f32_impl(
43774627
args_t args,

0 commit comments

Comments
 (0)