Skip to content

Commit 7841653

Browse files
committed
metal : minor refactor
1 parent bb9d36b commit 7841653

File tree

1 file changed

+38
-33
lines changed

1 file changed

+38
-33
lines changed

ggml/src/ggml-metal.metal

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,17 +1064,18 @@ kernel void kernel_group_norm(
10641064
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
10651065
float d = qb_curr->d;
10661066

1067-
float2 acc = 0.f;
1067+
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
10681068

1069-
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
1069+
device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
10701070

1071-
for (int i = 0; i < 8; i+=2) {
1072-
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
1073-
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
1074-
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
1075-
+ yl[i + 9] * (qs[i / 2] & 0xF000);
1071+
for (int i = 0; i < 8; i += 2) {
1072+
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
1073+
acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
1074+
acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
1075+
acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
10761076
}
1077-
return d * (sumy * -8.f + acc[0] + acc[1]);
1077+
1078+
return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
10781079
}
10791080

10801081
// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1085,17 +1086,18 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
10851086
float d = qb_curr->d;
10861087
float m = qb_curr->m;
10871088

1088-
float2 acc = 0.f;
1089+
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
10891090

1090-
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
1091+
device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
10911092

10921093
for (int i = 0; i < 8; i+=2) {
1093-
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
1094-
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
1095-
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
1096-
+ yl[i + 9] * (qs[i / 2] & 0xF000);
1094+
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
1095+
acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
1096+
acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
1097+
acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
10971098
}
1098-
return d * (acc[0] + acc[1]) + sumy * m;
1099+
1100+
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
10991101
}
11001102

11011103
// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1105,18 +1107,19 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
11051107
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
11061108
float d = qb_curr->d;
11071109

1108-
float2 acc = 0.f;
1110+
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
11091111

11101112
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
11111113
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
11121114

11131115
for (int i = 0; i < 8; i+=2) {
1114-
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
1115-
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
1116-
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
1117-
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
1116+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
1117+
acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
1118+
acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
1119+
acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
11181120
}
1119-
return d * (sumy * -16.f + acc[0] + acc[1]);
1121+
1122+
return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
11201123
}
11211124

11221125
// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1127,18 +1130,19 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
11271130
float d = qb_curr->d;
11281131
float m = qb_curr->m;
11291132

1130-
float2 acc = 0.f;
1133+
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
11311134

11321135
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
11331136
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
11341137

11351138
for (int i = 0; i < 8; i+=2) {
1136-
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
1137-
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
1138-
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
1139-
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
1139+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
1140+
acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
1141+
acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
1142+
acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
11401143
}
1141-
return d * (acc[0] + acc[1]) + sumy * m;
1144+
1145+
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
11421146
}
11431147

11441148
// putting them in the kernel cause a significant performance penalty
@@ -1208,14 +1212,15 @@ void mul_vec_q_n_f32_impl(
12081212
// each thread in a SIMD group deals with half a block.
12091213
for (int ib = ix; ib < nb; ib += nw/2) {
12101214
float sumy = 0;
1215+
12111216
for (int i = 0; i < 8; i += 2) {
1212-
sumy += yb[i] + yb[i+1];
1213-
yl[i+0] = yb[i+ 0];
1214-
yl[i+1] = yb[i+ 1]/256.f;
1217+
sumy += yb[i + 0] + yb[i + 1];
1218+
yl[i + 0] = yb[i + 0];
1219+
yl[i + 1] = yb[i + 1]/256.f;
12151220

1216-
sumy += yb[i+16] + yb[i+17];
1217-
yl[i+8] = yb[i+16]/16.f;
1218-
yl[i+9] = yb[i+17]/4096.f;
1221+
sumy += yb[i + 16] + yb[i + 17];
1222+
yl[i + 8] = yb[i + 16]/16.f;
1223+
yl[i + 9] = yb[i + 17]/4096.f;
12191224
}
12201225

12211226
for (int row = 0; row < nr; row++) {

0 commit comments

Comments
 (0)