@@ -4372,6 +4372,143 @@ kernel void kernel_concat(
43724372 }
43734373}
43744374
// Copies a quantized tensor into an F32 tensor, dequantizing one block per thread.
//
// Template parameters:
//   block_q         - quantized block type read from the source buffer
//   qqk             - number of elements stored in one block_q
//   dequantize_func - expands one block_q into qqk consecutive floats
//
// Arguments:
//   args - element counts (ne*) and byte strides (nb*) of source and destination,
//          plus the total element count args.ne
//   cx   - source bytes
//   cdst - destination bytes
//   tid  - thread position in the grid; the thread owns the block that starts at
//          flat element index tid*qqk
//
// All index and offset math is done in 64 bits so that tensors with 2^31 or more
// elements, or spanning 2 GiB or more, are addressed correctly.
//
// NOTE(review): dequantize_func stores qqk floats contiguously starting at the
// computed destination address, so this presumably requires ne0 to be a multiple
// of qqk and nb0 == sizeof(float) - confirm against the host-side dispatch.
template <typename block_q, short qqk, void (*dequantize_func)(device const block_q *, device float *)>
kernel void kernel_cpy_q_f32(
        constant ggml_metal_kargs_cpy & args,
        device const char * cx   [[buffer(1)]],
        device       char * cdst [[buffer(2)]],
        uint tid [[thread_position_in_grid]])
{
    // flat index of the first element handled by this thread
    const int64_t i = int64_t(tid) * qqk;

    // threads past the end of the tensor have nothing to do
    if (i >= args.ne) {
        return;
    }

    // decompose the flat index against the source extents
    const int64_t ne00 = args.ne00;
    const int64_t ne01 = args.ne01;
    const int64_t ne02 = args.ne02;

    int64_t rem = i;
    const int64_t i03 = rem / (ne00*ne01*ne02);
    rem -= i03*ne00*ne01*ne02;
    const int64_t i02 = rem / (ne00*ne01);
    rem -= i02*ne00*ne01;
    const int64_t i01 = rem / ne00;
    const int64_t i00 = rem - i01*ne00;

    // nb00 is applied per block, hence the division of the element index by qqk
    const uint64_t x_offset = uint64_t(i00/qqk)*args.nb00
                            + uint64_t(i01)*args.nb01
                            + uint64_t(i02)*args.nb02
                            + uint64_t(i03)*args.nb03;

    // decompose the same flat index against the destination extents
    const int64_t ne0 = args.ne0;
    const int64_t ne1 = args.ne1;
    const int64_t ne2 = args.ne2;

    rem = i;
    const int64_t i13 = rem / (ne0*ne1*ne2);
    rem -= i13*ne0*ne1*ne2;
    const int64_t i12 = rem / (ne0*ne1);
    rem -= i12*ne0*ne1;
    const int64_t i11 = rem / ne0;
    const int64_t i10 = rem - i11*ne0;

    const uint64_t dst_offset = uint64_t(i10)*args.nb0
                              + uint64_t(i11)*args.nb1
                              + uint64_t(i12)*args.nb2
                              + uint64_t(i13)*args.nb3;

    device const block_q * src_block = (device const block_q *)(cx + x_offset);
    device float         * dst       = (device float *)(cdst + dst_offset);

    dequantize_func(src_block, dst);
}
4409+
// Expands one Q4_0 block into QK4_0 floats written to dst.
//
// Byte n of qs packs two 4-bit codes: the low nibble yields element n and the
// high nibble yields element n + QK4_0/2. A code c decodes to (c - 8) * d.
void dequant_q4_0_f(device const block_q4_0 * src_block, device float * dst) {
    const float scale      = float(src_block->d);
    const float zero_point = 8.0f;

    device float * lo_out = dst;
    device float * hi_out = dst + QK4_0/2;

    for (int n = 0; n < QK4_0/2; ++n) {
        const uint8_t packed  = src_block->qs[n];
        const uint8_t lo_code = packed & 0x0F;
        const uint8_t hi_code = (packed >> 4) & 0x0F;

        lo_out[n] = (float(lo_code) - zero_point) * scale;
        hi_out[n] = (float(hi_code) - zero_point) * scale;
    }
}
4423+
// Expands one Q4_1 block into QK4_1 floats written to dst.
//
// Byte n of qs packs two 4-bit codes: the low nibble yields element n and the
// high nibble yields element n + QK4_1/2. A code c decodes to m + d * c.
void dequant_q4_1_f(device const block_q4_1 * src_block, device float * dst) {
    const float scale = float(src_block->d);
    const float base  = float(src_block->m);

    device float * lo_out = dst;
    device float * hi_out = dst + QK4_1/2;

    for (int n = 0; n < QK4_1/2; ++n) {
        const uint8_t packed  = src_block->qs[n];
        const uint8_t lo_code = packed & 0x0F;
        const uint8_t hi_code = (packed >> 4) & 0x0F;

        lo_out[n] = base + scale * float(lo_code);
        hi_out[n] = base + scale * float(hi_code);
    }
}
4436+
// Expands one Q5_0 block into QK5_0 floats written to dst.
//
// Each element is a 5-bit code: its low 4 bits come from a nibble of qs and its
// fifth bit comes from qh, where bit e of the assembled 32-bit word belongs to
// element e. A code c decodes to (c - 16) * d.
void dequant_q5_0_f(device const block_q5_0 * src_block, device float * dst) {
    const float scale      = float(src_block->d);
    const float zero_point = 16.f;
    const int   half_n     = QK5_0/2;

    // assemble the four qh bytes, lowest byte first, into one 32-bit word
    uint32_t fifth_bits = 0;
    for (int b = 0; b < 4; ++b) {
        fifth_bits |= uint32_t(src_block->qh[b]) << (8*b);
    }

    // elements [0, half_n): low nibbles of qs
    for (int n = 0; n < half_n; ++n) {
        const uint8_t nib  = src_block->qs[n] & 0x0F;
        const uint8_t top  = (fifth_bits >> n) & 0x1;
        const uint8_t code = (top << 4) | nib;
        dst[n] = (float(code) - zero_point) * scale;
    }

    // elements [half_n, QK5_0): high nibbles of qs
    for (int n = 0; n < half_n; ++n) {
        const int     e    = n + half_n;
        const uint8_t nib  = (src_block->qs[n] >> 4) & 0x0F;
        const uint8_t top  = (fifth_bits >> e) & 0x1;
        const uint8_t code = (top << 4) | nib;
        dst[e] = (float(code) - zero_point) * scale;
    }
}
4466+
// Expands one Q5_1 block into QK5_1 floats written to dst.
//
// Each element is a 5-bit code: its low 4 bits come from a nibble of qs and its
// fifth bit comes from qh, where bit e of the assembled 32-bit word belongs to
// element e. A code c decodes to m + d * c.
void dequant_q5_1_f(device const block_q5_1 * src_block, device float * dst) {
    const float scale  = float(src_block->d);
    const float base   = float(src_block->m);
    const int   half_n = QK5_1/2;

    // assemble the four qh bytes, lowest byte first, into one 32-bit word
    uint32_t fifth_bits = 0;
    for (int b = 0; b < 4; ++b) {
        fifth_bits |= uint32_t(src_block->qh[b]) << (8*b);
    }

    // elements [0, half_n): low nibbles of qs
    for (int n = 0; n < half_n; ++n) {
        const uint8_t nib  = src_block->qs[n] & 0x0F;
        const uint8_t top  = (fifth_bits >> n) & 0x1;
        const uint8_t code = (top << 4) | nib;
        dst[n] = base + scale * float(code);
    }

    // elements [half_n, QK5_1): high nibbles of qs
    for (int n = 0; n < half_n; ++n) {
        const int     e    = n + half_n;
        const uint8_t nib  = (src_block->qs[n] >> 4) & 0x0F;
        const uint8_t top  = (fifth_bits >> e) & 0x1;
        const uint8_t code = (top << 4) | nib;
        dst[e] = base + scale * float(code);
    }
}
4495+
// Expands one Q8_0 block into QK8_0 floats written to dst.
// Element n is the stored code qs[n] multiplied by the block scale d.
void dequant_q8_0_f(device const block_q8_0 * src_block, device float * dst) {
    const float scale = float(src_block->d);

    int n = 0;
    while (n < QK8_0) {
        dst[n] = src_block->qs[n] * scale;
        ++n;
    }
}
4503+
// Function type shared by all kernel_cpy_q_f32 instantiations. None of the kernel's
// parameters mention the template arguments, so the type taken from the Q4_0
// instantiation is reused for the others.
typedef decltype(kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>) cpy_q_t;

// One explicit instantiation per quantized source type. host_name sets the name each
// kernel is exported under; the names carry no leading or trailing whitespace.
// NOTE(review): the host-side pipeline lookup presumably uses these exact strings - confirm.
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>;
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_1, QK4_1, dequant_q4_1_f>;
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_0, QK5_0, dequant_q5_0_f>;
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_1, QK5_1, dequant_q5_1_f>;
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q8_0, QK8_0, dequant_q8_0_f>;
4511+
43754512template <typename args_t >
43764513void kernel_mul_mv_q2_K_f32_impl (
43774514 args_t args,
0 commit comments