@@ -4372,6 +4372,256 @@ kernel void kernel_concat(
43724372 }
43734373}
43744374
4375+ kernel void kernel_cpy_q4_0_f32 (
4376+ constant ggml_metal_kargs_cpy & args,
4377+ device const char *cx [[ buffer(1 ) ]],
4378+ device char *cdst [[ buffer(2 ) ]],
4379+ uint tid [[ thread_position_in_grid ]]
4380+ )
4381+ {
4382+ // Compute the global index multiplied by QK, matching:
4383+ // i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4384+ const int i = int (tid) * QK4_0;
4385+
4386+ // Bounds check
4387+ if (i >= args.ne ) {
4388+ return ;
4389+ }
4390+
4391+ const int i03 = i/(args.ne00 * args.ne01 * args.ne02 );
4392+ const int i02 = (i - i03*args.ne00 *args.ne01 *args.ne02 )/ (args.ne00 *args.ne01 );
4393+ const int i01 = (i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 ) / args.ne00 ;
4394+ const int i00 = i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 - i01*args.ne00 ;
4395+ const int x_offset = (i00/QK4_0)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03 ;
4396+
4397+ const int i13 = i/(args.ne0 * args.ne1 * args.ne2 );
4398+ const int i12 = (i - i13*args.ne0 *args.ne1 *args.ne2 ) / (args.ne0 *args.ne1 );
4399+ const int i11 = (i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 ) / args.ne0 ;
4400+ const int i10 = i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 - i11*args.ne0 ;
4401+ const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3 ;
4402+
4403+ device const block_q4_0 * src_block = (device const block_q4_0 *)(cx + x_offset);
4404+ device float * dst = (device float *)(cdst + dst_offset);
4405+
4406+ float d = float (src_block->d );
4407+ const float shift = 8 .0f ;
4408+
4409+ // Unpack 2 x 4-bit values per byte.
4410+ for (int j = 0 ; j < QK4_0/2 ; j++) {
4411+ uint8_t q = src_block->qs [j];
4412+ uint8_t q0 = q & 0x0F ;
4413+ uint8_t q1 = (q >> 4 ) & 0x0F ;
4414+ dst[j] = (float (q0) - shift) * d;
4415+ dst[j + QK4_0/2 ] = (float (q1) - shift) * d;
4416+ }
4417+ }
4418+
4419+ kernel void kernel_cpy_q4_1_f32 (
4420+ constant ggml_metal_kargs_cpy & args,
4421+ device const char *cx [[ buffer(1 ) ]],
4422+ device char *cdst [[ buffer(2 ) ]],
4423+ uint tid [[ thread_position_in_grid ]]
4424+ )
4425+ {
4426+ // Compute the global index multiplied by QK, matching:
4427+ // i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4428+ const int i = int (tid) * QK4_1;
4429+
4430+ // Bounds check
4431+ if (i >= args.ne ) {
4432+ return ;
4433+ }
4434+
4435+ const int i03 = i/(args.ne00 * args.ne01 * args.ne02 );
4436+ const int i02 = (i - i03*args.ne00 *args.ne01 *args.ne02 )/ (args.ne00 *args.ne01 );
4437+ const int i01 = (i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 ) / args.ne00 ;
4438+ const int i00 = i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 - i01*args.ne00 ;
4439+ const int x_offset = (i00/QK4_1)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03 ;
4440+
4441+ const int i13 = i/(args.ne0 * args.ne1 * args.ne2 );
4442+ const int i12 = (i - i13*args.ne0 *args.ne1 *args.ne2 ) / (args.ne0 *args.ne1 );
4443+ const int i11 = (i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 ) / args.ne0 ;
4444+ const int i10 = i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 - i11*args.ne0 ;
4445+ const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3 ;
4446+
4447+ device const block_q4_1 * src_block = (device const block_q4_1 *)(cx + x_offset);
4448+ device float * dst = (device float *)(cdst + dst_offset);
4449+
4450+ float d = float (src_block->d );
4451+ float vmin = float (src_block->m );
4452+
4453+ for (int j = 0 ; j < QK4_1/2 ; j++) {
4454+ uint8_t q = src_block->qs [j];
4455+ uint8_t q0 = q & 0x0F ;
4456+ uint8_t q1 = (q >> 4 ) & 0x0F ;
4457+ dst[j] = vmin + d * float (q0);
4458+ dst[j + QK4_1/2 ] = vmin + d * float (q1);
4459+ }
4460+ }
4461+
4462+
4463+ kernel void kernel_cpy_q5_0_f32 (
4464+ constant ggml_metal_kargs_cpy & args,
4465+ device const char *cx [[ buffer(1 ) ]],
4466+ device char *cdst [[ buffer(2 ) ]],
4467+ uint tid [[ thread_position_in_grid ]]
4468+ )
4469+ {
4470+ // Compute the global index multiplied by QK, matching:
4471+ // i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4472+ const int i = int (tid) * QK5_0;
4473+
4474+ // Bounds check
4475+ if (i >= args.ne ) {
4476+ return ;
4477+ }
4478+
4479+ const int i03 = i/(args.ne00 * args.ne01 * args.ne02 );
4480+ const int i02 = (i - i03*args.ne00 *args.ne01 *args.ne02 )/ (args.ne00 *args.ne01 );
4481+ const int i01 = (i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 ) / args.ne00 ;
4482+ const int i00 = i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 - i01*args.ne00 ;
4483+ const int x_offset = (i00/QK5_0)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03 ;
4484+
4485+ const int i13 = i/(args.ne0 * args.ne1 * args.ne2 );
4486+ const int i12 = (i - i13*args.ne0 *args.ne1 *args.ne2 ) / (args.ne0 *args.ne1 );
4487+ const int i11 = (i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 ) / args.ne0 ;
4488+ const int i10 = i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 - i11*args.ne0 ;
4489+ const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3 ;
4490+
4491+ device const block_q5_0 * src_block = (device const block_q5_0 *)(cx + x_offset);
4492+ device float * dst = (device float *)(cdst + dst_offset);
4493+
4494+ float d = float (src_block->d );
4495+ const float shift = 16 .f ;
4496+
4497+ // Combine the four qh bytes into a 32-bit value.
4498+ uint32_t qhVal = 0
4499+ | ((uint32_t ) src_block->qh [0 ] << 0 )
4500+ | ((uint32_t ) src_block->qh [1 ] << 8 )
4501+ | ((uint32_t ) src_block->qh [2 ] << 16 )
4502+ | ((uint32_t ) src_block->qh [3 ] << 24 );
4503+
4504+ // First half
4505+ for (int j = 0 ; j < QK5_0/2 ; j++) {
4506+ uint8_t q = src_block->qs [j];
4507+ uint8_t lowNib = q & 0x0F ;
4508+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4509+ uint8_t qVal = (highBit << 4 ) | lowNib;
4510+ dst[j] = (float (qVal) - shift) * d;
4511+ }
4512+ // Second half
4513+ for (int j = QK5_0/2 ; j < QK5_0; j++) {
4514+ int k = j - QK5_0/2 ;
4515+ uint8_t q = src_block->qs [k];
4516+ uint8_t hiNib = (q >> 4 ) & 0x0F ;
4517+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4518+ uint8_t qVal = (highBit << 4 ) | hiNib;
4519+ dst[j] = (float (qVal) - shift) * d;
4520+ }
4521+ }
4522+
4523+
4524+ kernel void kernel_cpy_q5_1_f32 (
4525+ constant ggml_metal_kargs_cpy & args,
4526+ device const char *cx [[ buffer(1 ) ]],
4527+ device char *cdst [[ buffer(2 ) ]],
4528+ uint tid [[ thread_position_in_grid ]]
4529+ )
4530+ {
4531+ // Compute the global index multiplied by QK, matching:
4532+ // i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4533+ const int i = int (tid) * QK5_1;
4534+
4535+ // Bounds check
4536+ if (i >= args.ne ) {
4537+ return ;
4538+ }
4539+
4540+ const int i03 = i/(args.ne00 * args.ne01 * args.ne02 );
4541+ const int i02 = (i - i03*args.ne00 *args.ne01 *args.ne02 )/ (args.ne00 *args.ne01 );
4542+ const int i01 = (i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 ) / args.ne00 ;
4543+ const int i00 = i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 - i01*args.ne00 ;
4544+ const int x_offset = (i00/QK5_1)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03 ;
4545+
4546+ const int i13 = i/(args.ne0 * args.ne1 * args.ne2 );
4547+ const int i12 = (i - i13*args.ne0 *args.ne1 *args.ne2 ) / (args.ne0 *args.ne1 );
4548+ const int i11 = (i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 ) / args.ne0 ;
4549+ const int i10 = i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 - i11*args.ne0 ;
4550+ const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3 ;
4551+
4552+ device const block_q5_1 * src_block = (device const block_q5_1 *)(cx + x_offset);
4553+ device float * dst = (device float *)(cdst + dst_offset);
4554+
4555+ float d = float (src_block->d );
4556+ float vmin = float (src_block->m );
4557+
4558+ uint32_t qhVal = 0
4559+ | ((uint32_t ) src_block->qh [0 ] << 0 )
4560+ | ((uint32_t ) src_block->qh [1 ] << 8 )
4561+ | ((uint32_t ) src_block->qh [2 ] << 16 )
4562+ | ((uint32_t ) src_block->qh [3 ] << 24 );
4563+
4564+ // First half
4565+ for (int j = 0 ; j < QK5_1/2 ; j++) {
4566+ uint8_t q = src_block->qs [j];
4567+ uint8_t lowNib = q & 0x0F ;
4568+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4569+ uint8_t qVal = (highBit << 4 ) | lowNib;
4570+ dst[j] = vmin + d * float (qVal);
4571+ }
4572+ // Second half
4573+ for (int j = QK5_1/2 ; j < QK5_1; j++) {
4574+ int k = j - QK5_1/2 ;
4575+ uint8_t q = src_block->qs [k];
4576+ uint8_t hiNib = (q >> 4 ) & 0x0F ;
4577+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4578+ uint8_t qVal = (highBit << 4 ) | hiNib;
4579+ dst[j] = vmin + d * float (qVal);
4580+ }
4581+ }
4582+
4583+ kernel void kernel_cpy_q8_0_f32 (
4584+ constant ggml_metal_kargs_cpy &args [[ buffer(0 ) ]],
4585+ device const char *cx [[ buffer(1 ) ]],
4586+ device char *cdst [[ buffer(2 ) ]],
4587+ uint tid [[ thread_position_in_grid ]]
4588+ ) {
4589+ // Compute the global index multiplied by QK, matching:
4590+ // i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4591+ const int i = int (tid) * QK8_0;
4592+
4593+ // Bounds check
4594+ if (i >= args.ne ) {
4595+ return ;
4596+ }
4597+
4598+ const int i03 = i/(args.ne00 * args.ne01 * args.ne02 );
4599+ const int i02 = (i - i03*args.ne00 *args.ne01 *args.ne02 )/ (args.ne00 *args.ne01 );
4600+ const int i01 = (i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 ) / args.ne00 ;
4601+ const int i00 = i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 - i01*args.ne00 ;
4602+ const int x_offset = (i00/QK8_0)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03 ;
4603+
4604+ const int i13 = i/(args.ne0 * args.ne1 * args.ne2 );
4605+ const int i12 = (i - i13*args.ne0 *args.ne1 *args.ne2 ) / (args.ne0 *args.ne1 );
4606+ const int i11 = (i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 ) / args.ne0 ;
4607+ const int i10 = i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 - i11*args.ne0 ;
4608+ const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3 ;
4609+
4610+ // Call the device function that performs the copy/dequantization.
4611+ // cpy_blck(cx + x_offset, cdst + dst_offset);
4612+ device const char * src_block = cx + x_offset;
4613+ device char * dst = cdst + dst_offset;
4614+
4615+ const device block_q8_0 * xi = (device const block_q8_0 *) src_block;
4616+ device float * dsti = (device float *) dst;
4617+
4618+ const float d = (float )xi->d ;
4619+
4620+ for (int j = 0 ; j < QK8_0; j++) {
4621+ dsti[j] = xi->qs [j] * d;
4622+ }
4623+ }
4624+
43754625template <typename args_t >
43764626void kernel_mul_mv_q2_K_f32_impl (
43774627 args_t args,
0 commit comments