Skip to content

Commit 8c1b186

Browse files
committed
metal : minor Q4_0 optimization
1 parent 86ed72d commit 8c1b186

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

ggml/src/ggml-metal.metal

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,18 +1441,18 @@ kernel void kernel_group_norm(
14411441
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
14421442
float d = qb_curr->d;
14431443

1444-
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
1444+
float acc = -8.0f*sumy;
14451445

14461446
device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
14471447

1448-
for (int i = 0; i < 8; i += 2) {
1449-
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
1450-
acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
1451-
acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
1452-
acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
1448+
for (short i = 0; i < 4; ++i) {
1449+
acc += yl[2*i + 0] * (qs[i] & 0x000F);
1450+
acc += yl[2*i + 1] * (qs[i] & 0x0F00);
1451+
acc += yl[2*i + 8] * (qs[i] & 0x00F0);
1452+
acc += yl[2*i + 9] * (qs[i] & 0xF000);
14531453
}
14541454

1455-
return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
1455+
return d * acc;
14561456
}
14571457

14581458
// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1567,37 +1567,36 @@ void mul_vec_q_n_f32_impl(
15671567
float yl[16]; // src1 vector cache
15681568
float sumf[nr] = {0.f};
15691569

1570-
const short ix = (tiisg/2);
1571-
const short il = (tiisg%2)*8;
1570+
const short ix = (tiisg%16);
1571+
const short il = (tiisg/16)*8;
15721572

15731573
device const float * yb = y + ix*QK4_0 + il;
15741574

15751575
// each thread in a SIMD group deals with half a block.
15761576
for (int ib = ix; ib < nb; ib += nw/2) {
1577-
float sumy[2] = { 0.f, 0.f };
1577+
float sumy = 0.0f;
15781578

1579-
#pragma unroll
1579+
#pragma unroll(4)
15801580
for (int i = 0; i < 8; i += 2) {
1581-
sumy[0] += yb[i + 0] + yb[i + 1];
1581+
sumy += yb[i + 0] + yb[i + 1] + yb[i + 16] + yb[i + 17];
1582+
15821583
yl[i + 0] = yb[i + 0];
15831584
yl[i + 1] = yb[i + 1]/256.f;
1584-
1585-
sumy[1] += yb[i + 16] + yb[i + 17];
15861585
yl[i + 8] = yb[i + 16]/16.f;
15871586
yl[i + 9] = yb[i + 17]/4096.f;
15881587
}
15891588

1590-
#pragma unroll
1591-
for (int row = 0; row < nr; row++) {
1592-
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
1589+
#pragma unroll(nr)
1590+
for (short row = 0; row < nr; row++) {
1591+
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
15931592
}
15941593

15951594
yb += QK4_0 * 16;
15961595
}
15971596

15981597
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
15991598

1600-
for (int row = 0; row < nr; ++row) {
1599+
for (short row = 0; row < nr; ++row) {
16011600
const float tot = simd_sum(sumf[row]);
16021601

16031602
if (tiisg == 0 && first_row + row < args.ne01) {

0 commit comments

Comments
 (0)