@@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl(
43414341 }
43424342}
43434343
4344+ template <typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread T4x4 &)>
4345+ kernel void kernel_cpy_q_f32 (
4346+ constant ggml_metal_kargs_cpy & args,
4347+ device const char * src0,
4348+ device char * dst,
4349+ uint3 tgpig[[threadgroup_position_in_grid]],
4350+ ushort3 tpitg[[thread_position_in_threadgroup]],
4351+ ushort3 ntg[[threads_per_threadgroup]]) {
4352+ const int i03 = tgpig[2 ];
4353+ const int i02 = tgpig[1 ];
4354+ const int i01 = tgpig[0 ];
4355+
4356+ const int64_t n = i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 ;
4357+
4358+ const int64_t i3 = n/(args.ne2 *args.ne1 *args.ne0 );
4359+ const int64_t i2 = (n - i3*args.ne2 *args.ne1 *args.ne0 )/(args.ne1 *args.ne0 );
4360+ const int64_t i1 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 )/args.ne0 ;
4361+ const int64_t i0 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 - i1*args.ne0 );
4362+
4363+ device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 );
4364+ device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0 );
4365+
4366+ for (int64_t i00 = tpitg.x ; i00 < args.ne00 /16 ; i00 += ntg.x ) {
4367+ T4x4 temp;
4368+ dequantize_func (src_data + i00/nl, i00%nl, temp);
4369+ dst_data[i00] = temp;
4370+ }
4371+ }
4372+
4373+ typedef decltype (kernel_cpy_q_f32<float4x4, block_q4_0, 2 , dequantize_q4_0>) cpy_q_f_t;
4374+
4375+ template [[host_name(" kernel_cpy_q4_0_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2 , dequantize_q4_0>;
4376+ template [[host_name(" kernel_cpy_q4_1_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2 , dequantize_q4_1>;
4377+ template [[host_name(" kernel_cpy_q5_0_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2 , dequantize_q5_0>;
4378+ template [[host_name(" kernel_cpy_q5_1_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2 , dequantize_q5_1>;
4379+ template [[host_name(" kernel_cpy_q8_0_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2 , dequantize_q8_0>;
4380+
4381+ template [[host_name(" kernel_cpy_q4_0_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2 , dequantize_q4_0>;
4382+ template [[host_name(" kernel_cpy_q4_1_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2 , dequantize_q4_1>;
4383+ template [[host_name(" kernel_cpy_q5_0_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2 , dequantize_q5_0>;
4384+ template [[host_name(" kernel_cpy_q5_1_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2 , dequantize_q5_1>;
4385+ template [[host_name(" kernel_cpy_q8_0_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2 , dequantize_q8_0>;
4386+
43444387kernel void kernel_concat (
43454388 constant ggml_metal_kargs_concat & args,
43464389 device const char * src0,
@@ -4372,150 +4415,6 @@ kernel void kernel_concat(
43724415 }
43734416}
43744417
4375- template <typename block_q, short qqk, void (*dequantize_func)(device const block_q *, device float *)>
4376- kernel void kernel_cpy_q_f32 (
4377- constant ggml_metal_kargs_cpy & args,
4378- device const char * cx [[ buffer(1 ) ]],
4379- device char * cdst [[ buffer(2 ) ]],
4380- uint tid [[ thread_position_in_grid ]]
4381- )
4382- {
4383- // Compute the global index multiplied by QK, matching:
4384- // i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4385- const int i = int (tid) * qqk;
4386-
4387- // Bounds check
4388- if (i >= args.ne ) {
4389- return ;
4390- }
4391-
4392- const int i03 = i/(args.ne00 * args.ne01 * args.ne02 );
4393- const int i02 = (i - i03*args.ne00 *args.ne01 *args.ne02 )/ (args.ne00 *args.ne01 );
4394- const int i01 = (i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 ) / args.ne00 ;
4395- const int i00 = i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 - i01*args.ne00 ;
4396- const int x_offset = (i00/qqk)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03 ;
4397-
4398- const int i13 = i/(args.ne0 * args.ne1 * args.ne2 );
4399- const int i12 = (i - i13*args.ne0 *args.ne1 *args.ne2 ) / (args.ne0 *args.ne1 );
4400- const int i11 = (i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 ) / args.ne0 ;
4401- const int i10 = i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 - i11*args.ne0 ;
4402- const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3 ;
4403-
4404- device const block_q * src_block = (device const block_q *)(cx + x_offset);
4405- device float * dst = (device float *)(cdst + dst_offset);
4406-
4407- dequantize_func (src_block, dst);
4408- }
4409-
4410- void dequant_q4_0_f (device const block_q4_0 * src_block, device float * dst) {
4411- float d = float (src_block->d );
4412- const float shift = 8 .0f ;
4413-
4414- // Unpack 2 x 4-bit values per byte.
4415- #pragma unroll(16)
4416- for (int j = 0 ; j < QK4_0/2 ; j++) {
4417- uint8_t q = src_block->qs [j];
4418- uint8_t q0 = q & 0x0F ;
4419- uint8_t q1 = (q >> 4 ) & 0x0F ;
4420- dst[j] = (float (q0) - shift) * d;
4421- dst[j + QK4_0/2 ] = (float (q1) - shift) * d;
4422- }
4423- }
4424-
4425- void dequant_q4_1_f (device const block_q4_1 * src_block, device float * dst) {
4426- float d = float (src_block->d );
4427- float vmin = float (src_block->m );
4428-
4429- #pragma unroll(16)
4430- for (int j = 0 ; j < QK4_1/2 ; j++) {
4431- uint8_t q = src_block->qs [j];
4432- uint8_t q0 = q & 0x0F ;
4433- uint8_t q1 = (q >> 4 ) & 0x0F ;
4434- dst[j] = vmin + d * float (q0);
4435- dst[j + QK4_1/2 ] = vmin + d * float (q1);
4436- }
4437- }
4438-
4439- void dequant_q5_0_f (device const block_q5_0 * src_block, device float * dst) {
4440- float d = float (src_block->d );
4441- const float shift = 16 .f ;
4442-
4443- // Combine the four qh bytes into a 32-bit value.
4444- uint32_t qhVal = 0
4445- | ((uint32_t ) src_block->qh [0 ] << 0 )
4446- | ((uint32_t ) src_block->qh [1 ] << 8 )
4447- | ((uint32_t ) src_block->qh [2 ] << 16 )
4448- | ((uint32_t ) src_block->qh [3 ] << 24 );
4449-
4450- // First half
4451- #pragma unroll(16)
4452- for (int j = 0 ; j < QK5_0/2 ; j++) {
4453- uint8_t q = src_block->qs [j];
4454- uint8_t lowNib = q & 0x0F ;
4455- uint8_t highBit = (qhVal >> j) & 0x1 ;
4456- uint8_t qVal = (highBit << 4 ) | lowNib;
4457- dst[j] = (float (qVal) - shift) * d;
4458- }
4459- // Second half
4460- #pragma unroll(16)
4461- for (int j = QK5_0/2 ; j < QK5_0; j++) {
4462- int k = j - QK5_0/2 ;
4463- uint8_t q = src_block->qs [k];
4464- uint8_t hiNib = (q >> 4 ) & 0x0F ;
4465- uint8_t highBit = (qhVal >> j) & 0x1 ;
4466- uint8_t qVal = (highBit << 4 ) | hiNib;
4467- dst[j] = (float (qVal) - shift) * d;
4468- }
4469- }
4470-
4471- void dequant_q5_1_f (device const block_q5_1 * src_block, device float * dst) {
4472- float d = float (src_block->d );
4473- float vmin = float (src_block->m );
4474-
4475- uint32_t qhVal = 0
4476- | ((uint32_t ) src_block->qh [0 ] << 0 )
4477- | ((uint32_t ) src_block->qh [1 ] << 8 )
4478- | ((uint32_t ) src_block->qh [2 ] << 16 )
4479- | ((uint32_t ) src_block->qh [3 ] << 24 );
4480-
4481- // First half
4482- #pragma unroll(16)
4483- for (int j = 0 ; j < QK5_1/2 ; j++) {
4484- uint8_t q = src_block->qs [j];
4485- uint8_t lowNib = q & 0x0F ;
4486- uint8_t highBit = (qhVal >> j) & 0x1 ;
4487- uint8_t qVal = (highBit << 4 ) | lowNib;
4488- dst[j] = vmin + d * float (qVal);
4489- }
4490- // Second half
4491- #pragma unroll(16)
4492- for (int j = QK5_1/2 ; j < QK5_1; j++) {
4493- int k = j - QK5_1/2 ;
4494- uint8_t q = src_block->qs [k];
4495- uint8_t hiNib = (q >> 4 ) & 0x0F ;
4496- uint8_t highBit = (qhVal >> j) & 0x1 ;
4497- uint8_t qVal = (highBit << 4 ) | hiNib;
4498- dst[j] = vmin + d * float (qVal);
4499- }
4500- }
4501-
4502- void dequant_q8_0_f (device const block_q8_0 * src_block, device float * dst) {
4503- const float d = (float )src_block->d ;
4504-
4505- #pragma unroll(32)
4506- for (int j = 0 ; j < QK8_0; j++) {
4507- dst[j] = src_block->qs [j] * d;
4508- }
4509- }
4510-
4511- typedef decltype (kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>) cpy_q_t;
4512-
4513- template [[host_name(" kernel_cpy_q4_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>;
4514- template [[host_name(" kernel_cpy_q4_1_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_1, QK4_1, dequant_q4_1_f>;
4515- template [[host_name(" kernel_cpy_q5_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_0, QK5_0, dequant_q5_0_f>;
4516- template [[host_name(" kernel_cpy_q5_1_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_1, QK5_1, dequant_q5_1_f>;
4517- template [[host_name(" kernel_cpy_q8_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q8_0, QK8_0, dequant_q8_0_f>;
4518-
45194418template <typename args_t >
45204419void kernel_mul_mv_q2_K_f32_impl (
45214420 args_t args,
0 commit comments