@@ -4372,6 +4372,150 @@ kernel void kernel_concat(
43724372 }
43734373}
43744374
4375+ template <typename block_q, short qqk, void (*dequantize_func)(device const block_q *, device float *)>
4376+ kernel void kernel_cpy_q_f32 (
4377+ constant ggml_metal_kargs_cpy & args,
4378+ device const char * cx [[ buffer(1 ) ]],
4379+ device char * cdst [[ buffer(2 ) ]],
4380+ uint tid [[ thread_position_in_grid ]]
4381+ )
4382+ {
4383+ // Compute the global index multiplied by QK, matching:
4384+ // i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4385+ const int i = int (tid) * qqk;
4386+
4387+ // Bounds check
4388+ if (i >= args.ne ) {
4389+ return ;
4390+ }
4391+
4392+ const int i03 = i/(args.ne00 * args.ne01 * args.ne02 );
4393+ const int i02 = (i - i03*args.ne00 *args.ne01 *args.ne02 )/ (args.ne00 *args.ne01 );
4394+ const int i01 = (i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 ) / args.ne00 ;
4395+ const int i00 = i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 - i01*args.ne00 ;
4396+ const int x_offset = (i00/qqk)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03 ;
4397+
4398+ const int i13 = i/(args.ne0 * args.ne1 * args.ne2 );
4399+ const int i12 = (i - i13*args.ne0 *args.ne1 *args.ne2 ) / (args.ne0 *args.ne1 );
4400+ const int i11 = (i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 ) / args.ne0 ;
4401+ const int i10 = i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 - i11*args.ne0 ;
4402+ const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3 ;
4403+
4404+ device const block_q * src_block = (device const block_q *)(cx + x_offset);
4405+ device float * dst = (device float *)(cdst + dst_offset);
4406+
4407+ dequantize_func (src_block, dst);
4408+ }
4409+
4410+ void dequant_q4_0_f (device const block_q4_0 * src_block, device float * dst) {
4411+ float d = float (src_block->d );
4412+ const float shift = 8 .0f ;
4413+
4414+ // Unpack 2 x 4-bit values per byte.
4415+ #pragma unroll(16)
4416+ for (int j = 0 ; j < QK4_0/2 ; j++) {
4417+ uint8_t q = src_block->qs [j];
4418+ uint8_t q0 = q & 0x0F ;
4419+ uint8_t q1 = (q >> 4 ) & 0x0F ;
4420+ dst[j] = (float (q0) - shift) * d;
4421+ dst[j + QK4_0/2 ] = (float (q1) - shift) * d;
4422+ }
4423+ }
4424+
4425+ void dequant_q4_1_f (device const block_q4_1 * src_block, device float * dst) {
4426+ float d = float (src_block->d );
4427+ float vmin = float (src_block->m );
4428+
4429+ #pragma unroll(16)
4430+ for (int j = 0 ; j < QK4_1/2 ; j++) {
4431+ uint8_t q = src_block->qs [j];
4432+ uint8_t q0 = q & 0x0F ;
4433+ uint8_t q1 = (q >> 4 ) & 0x0F ;
4434+ dst[j] = vmin + d * float (q0);
4435+ dst[j + QK4_1/2 ] = vmin + d * float (q1);
4436+ }
4437+ }
4438+
4439+ void dequant_q5_0_f (device const block_q5_0 * src_block, device float * dst) {
4440+ float d = float (src_block->d );
4441+ const float shift = 16 .f ;
4442+
4443+ // Combine the four qh bytes into a 32-bit value.
4444+ uint32_t qhVal = 0
4445+ | ((uint32_t ) src_block->qh [0 ] << 0 )
4446+ | ((uint32_t ) src_block->qh [1 ] << 8 )
4447+ | ((uint32_t ) src_block->qh [2 ] << 16 )
4448+ | ((uint32_t ) src_block->qh [3 ] << 24 );
4449+
4450+ // First half
4451+ #pragma unroll(16)
4452+ for (int j = 0 ; j < QK5_0/2 ; j++) {
4453+ uint8_t q = src_block->qs [j];
4454+ uint8_t lowNib = q & 0x0F ;
4455+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4456+ uint8_t qVal = (highBit << 4 ) | lowNib;
4457+ dst[j] = (float (qVal) - shift) * d;
4458+ }
4459+ // Second half
4460+ #pragma unroll(16)
4461+ for (int j = QK5_0/2 ; j < QK5_0; j++) {
4462+ int k = j - QK5_0/2 ;
4463+ uint8_t q = src_block->qs [k];
4464+ uint8_t hiNib = (q >> 4 ) & 0x0F ;
4465+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4466+ uint8_t qVal = (highBit << 4 ) | hiNib;
4467+ dst[j] = (float (qVal) - shift) * d;
4468+ }
4469+ }
4470+
4471+ void dequant_q5_1_f (device const block_q5_1 * src_block, device float * dst) {
4472+ float d = float (src_block->d );
4473+ float vmin = float (src_block->m );
4474+
4475+ uint32_t qhVal = 0
4476+ | ((uint32_t ) src_block->qh [0 ] << 0 )
4477+ | ((uint32_t ) src_block->qh [1 ] << 8 )
4478+ | ((uint32_t ) src_block->qh [2 ] << 16 )
4479+ | ((uint32_t ) src_block->qh [3 ] << 24 );
4480+
4481+ // First half
4482+ #pragma unroll(16)
4483+ for (int j = 0 ; j < QK5_1/2 ; j++) {
4484+ uint8_t q = src_block->qs [j];
4485+ uint8_t lowNib = q & 0x0F ;
4486+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4487+ uint8_t qVal = (highBit << 4 ) | lowNib;
4488+ dst[j] = vmin + d * float (qVal);
4489+ }
4490+ // Second half
4491+ #pragma unroll(16)
4492+ for (int j = QK5_1/2 ; j < QK5_1; j++) {
4493+ int k = j - QK5_1/2 ;
4494+ uint8_t q = src_block->qs [k];
4495+ uint8_t hiNib = (q >> 4 ) & 0x0F ;
4496+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4497+ uint8_t qVal = (highBit << 4 ) | hiNib;
4498+ dst[j] = vmin + d * float (qVal);
4499+ }
4500+ }
4501+
4502+ void dequant_q8_0_f (device const block_q8_0 * src_block, device float * dst) {
4503+ const float d = (float )src_block->d ;
4504+
4505+ #pragma unroll(32)
4506+ for (int j = 0 ; j < QK8_0; j++) {
4507+ dst[j] = src_block->qs [j] * d;
4508+ }
4509+ }
4510+
4511+ typedef decltype (kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>) cpy_q_t;
4512+
4513+ template [[host_name(" kernel_cpy_q4_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>;
4514+ template [[host_name(" kernel_cpy_q4_1_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_1, QK4_1, dequant_q4_1_f>;
4515+ template [[host_name(" kernel_cpy_q5_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_0, QK5_0, dequant_q5_0_f>;
4516+ template [[host_name(" kernel_cpy_q5_1_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_1, QK5_1, dequant_q5_1_f>;
4517+ template [[host_name(" kernel_cpy_q8_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q8_0, QK8_0, dequant_q8_0_f>;
4518+
43754519template <typename args_t >
43764520void kernel_mul_mv_q2_K_f32_impl (
43774521 args_t args,
0 commit comments