Skip to content

Commit 5f0d371

Browse files
committed
4x4 -> 4x
1 parent 59b33b9 commit 5f0d371

File tree

2 files changed

+34
-19
lines changed

2 files changed

+34
-19
lines changed

ggml/src/ggml-metal.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1943,7 +1943,7 @@ static void ggml_metal_encode_node(
19431943
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline;
19441944

19451945
const int nsg = 2;
1946-
const int r0pt = 2;
1946+
const int r0pt = 1;
19471947
const int r1pt = 1;
19481948
const int nxpsg = ne11 > 1 ? 8 : 32;
19491949
const int nypsg = 32/nxpsg;

ggml/src/ggml-metal.metal

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,26 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
165165
reg = (type4x4) reg_f;
166166
}
167167

168+
template <typename type4>
169+
void dequantize_q4_0x(device const block_q4_0 *xb, short il, thread type4 & reg) {
170+
device const int8_t * qs = ((device const int8_t *)xb->qs);
171+
const half d = xb->d;
172+
173+
for (int i = 0; i < 4; i++) {
174+
reg[i] = qs[0];
175+
}
176+
}
177+
178+
template <typename type4>
179+
void dequantize_q8_0x(device const block_q8_0 *xb, short il, thread type4 & reg) {
180+
device const int8_t * qs = ((device const int8_t *)xb->qs);
181+
const half d = xb->d;
182+
183+
for (int i = 0; i < 4; i++) {
184+
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
185+
}
186+
}
187+
168188
template <typename type4x4>
169189
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
170190
const float d = xb->d;
@@ -1749,8 +1769,8 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
17491769
ushort3 ntg[[threads_per_threadgroup]],
17501770
ushort tiisg[[thread_index_in_simdgroup]],
17511771
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1752-
const short chpt = 1;
1753-
const short r0pt = 2;
1772+
const short chpt = 4;
1773+
const short r0pt = 1;
17541774

17551775
//const short nxpsg = (32);
17561776
const short nypsg = (32/nxpsg)*r0pt;
@@ -1771,36 +1791,31 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
17711791
device const block_q8_0 * xq[r0pt];
17721792

17731793
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
1774-
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/2 : (device const block_q8_0 *) src0;
1794+
//xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0;
1795+
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0;
17751796
}
17761797

1777-
device const float4x4 * y4x4 = (device const float4x4 *) (src1 + offset1) + chpt*tx;
1798+
//device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
1799+
device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
17781800

17791801
float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
17801802

1781-
for (int iib = 0; (16*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
1782-
float4x4 lx;
1783-
1784-
#pragma unroll(2)
1803+
for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
17851804
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
1786-
#pragma unroll
1805+
#pragma unroll(4)
17871806
for (short ch = 0; ch < chpt; ++ch) {
1788-
dequantize_q8_0(xq[ir0] + ch/2, (chpt*tx + ch)%2, lx);
1807+
float4 lx;
17891808

1790-
const float4x4 ly = y4x4[ch];
1809+
dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx);
17911810

1792-
sumf[ir0] +=
1793-
dot(lx[0], ly[0]) +
1794-
dot(lx[1], ly[1]) +
1795-
dot(lx[2], ly[2]) +
1796-
dot(lx[3], ly[3]);
1811+
sumf[ir0] += dot(lx, y4[ch*nxpsg]);
17971812
}
17981813
}
17991814

1800-
y4x4 += ((16*chpt)*nxpsg)/16;
1815+
y4 += ((4*chpt)*nxpsg)/4;
18011816

18021817
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
1803-
xq[ir0] += ((16*chpt)*nxpsg)/32;
1818+
xq[ir0] += ((4*chpt)*nxpsg)/32;
18041819
}
18051820
}
18061821

0 commit comments

Comments
 (0)