Skip to content

Commit fe12e20

Browse files
committed
metal : mv q6_K support nr0 > 1
ggml-ci
1 parent 51dea76 commit fe12e20

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4975,9 +4975,9 @@ void kernel_mul_mv_q6_K_f32_impl(
49754975
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
49764976
device const float * yy = (device const float *) (src1 + offset1);
49774977

4978-
// TODO: support nr0 > 1
4979-
static_assert(nr0 == 1, "nr0 > 1 not supported");
4980-
float sumf[1] = { 0.f };
4978+
float sumf[nr0] = { 0.f };
4979+
4980+
float yl[16];
49814981

49824982
const short tid = tiisg/2;
49834983
const short ix = tiisg%2;
@@ -4995,22 +4995,37 @@ void kernel_mul_mv_q6_K_f32_impl(
49954995
device const uint8_t * q2 = q1 + 32;
49964996
device const uint8_t * qh = x[i].qh + q_offset_h;
49974997
device const int8_t * sc = x[i].scales + is;
4998+
device const half * dh = &x[i].d;
49984999

49995000
device const float * y = yy + i * QK_K + y_offset;
50005001

5001-
const float dall = x[i].d;
5002-
5003-
float4 sums = {0.f, 0.f, 0.f, 0.f};
5004-
5005-
#pragma unroll(4)
50065002
for (short l = 0; l < 4; ++l) {
5007-
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
5008-
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
5009-
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
5010-
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
5003+
yl[4*l + 0] = y[l + 0];
5004+
yl[4*l + 1] = y[l + 32];
5005+
yl[4*l + 2] = y[l + 64];
5006+
yl[4*l + 3] = y[l + 96];
50115007
}
50125008

5013-
sumf[0] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
5009+
for (short row = 0; row < nr0; ++row) {
5010+
const float dall = dh[0];
5011+
5012+
float4 sums = {0.f, 0.f, 0.f, 0.f};
5013+
5014+
for (short l = 0; l < 4; ++l) {
5015+
sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
5016+
sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
5017+
sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
5018+
sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
5019+
}
5020+
5021+
sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
5022+
5023+
q1 += args.nb01;
5024+
q2 += args.nb01;
5025+
qh += args.nb01;
5026+
sc += args.nb01;
5027+
dh += args.nb01/2;
5028+
}
50145029
}
50155030

50165031
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

0 commit comments

Comments
 (0)