Skip to content

Commit 30ed6ac

Browse files
committed
metal : adjust constants
ggml-ci
1 parent 8f6f0d5 commit 30ed6ac

File tree

2 files changed

+27
-33
lines changed

2 files changed

+27
-33
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,13 +35,13 @@
3535
#define N_R0_Q3_K 2
3636
#define N_SG_Q3_K 2
3737

38-
#define N_R0_Q4_K 4
38+
#define N_R0_Q4_K 2
3939
#define N_SG_Q4_K 2
4040

4141
#define N_R0_Q5_K 2
4242
#define N_SG_Q5_K 2
4343

44-
#define N_R0_Q6_K 1
44+
#define N_R0_Q6_K 2
4545
#define N_SG_Q6_K 2
4646

4747
#define N_R0_IQ1_S 4

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 25 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -6251,9 +6251,9 @@ void kernel_mul_mv_q4_K_f32_impl(
62516251
ushort sgitg) {
62526252
const short NSG = FC_mul_mv_nsg;
62536253

6254-
const uint16_t kmask1 = 0x3f3f;
6255-
const uint16_t kmask2 = 0x0f0f;
6256-
const uint16_t kmask3 = 0xc0c0;
6254+
constexpr uint16_t kmask1 = 0x3f3f;
6255+
constexpr uint16_t kmask2 = 0x0f0f;
6256+
constexpr uint16_t kmask3 = 0xc0c0;
62576257

62586258
const short ix = tiisg/8; // 0...3
62596259
const short it = tiisg%8; // 0...7
@@ -6312,7 +6312,7 @@ void kernel_mul_mv_q4_K_f32_impl(
63126312
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
63136313
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
63146314

6315-
for (short i = 0; i < 4; ++i) {
6315+
FOR_UNROLL (short i = 0; i < 4; ++i) {
63166316
acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
63176317
acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
63186318
acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
@@ -6323,14 +6323,11 @@ void kernel_mul_mv_q4_K_f32_impl(
63236323
acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
63246324
}
63256325

6326-
float dall = dh[0];
6327-
float dmin = dh[1];
6328-
6329-
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
6330-
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
6331-
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
6332-
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
6333-
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
6326+
sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
6327+
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
6328+
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
6329+
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
6330+
dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
63346331

63356332
q1 += args.nb01/2;
63366333
sc += args.nb01/2;
@@ -6396,9 +6393,9 @@ void kernel_mul_mv_q5_K_f32_impl(
63966393

63976394
float yl[16], yh[16];
63986395

6399-
const uint16_t kmask1 = 0x3f3f;
6400-
const uint16_t kmask2 = 0x0f0f;
6401-
const uint16_t kmask3 = 0xc0c0;
6396+
constexpr uint16_t kmask1 = 0x3f3f;
6397+
constexpr uint16_t kmask2 = 0x0f0f;
6398+
constexpr uint16_t kmask3 = 0xc0c0;
64026399

64036400
const short tid = tiisg/4;
64046401
const short ix = tiisg%4;
@@ -6444,7 +6441,7 @@ void kernel_mul_mv_q5_K_f32_impl(
64446441

64456442
float4 acc1 = {0.f};
64466443
float4 acc2 = {0.f};
6447-
for (short l = 0; l < 8; ++l) {
6444+
FOR_UNROLL (short l = 0; l < 8; ++l) {
64486445
uint8_t h = qh[l];
64496446
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
64506447
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
@@ -6455,13 +6452,12 @@ void kernel_mul_mv_q5_K_f32_impl(
64556452
acc2[2] += h & hm3 ? yh[l+0] : 0.f;
64566453
acc2[3] += h & hm4 ? yh[l+8] : 0.f;
64576454
}
6458-
const float dall = dh[0];
6459-
const float dmin = dh[1];
6460-
sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
6461-
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
6462-
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
6463-
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
6464-
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
6455+
6456+
sumf[row] += dh[0] * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
6457+
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
6458+
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
6459+
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
6460+
dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
64656461

64666462
q1 += args.nb01;
64676463
qh += args.nb01;
@@ -6507,10 +6503,10 @@ void kernel_mul_mv_q6_K_f32_impl(
65076503
ushort sgitg) {
65086504
const short NSG = FC_mul_mv_nsg;
65096505

6510-
const uint8_t kmask1 = 0x03;
6511-
const uint8_t kmask2 = 0x0C;
6512-
const uint8_t kmask3 = 0x30;
6513-
const uint8_t kmask4 = 0xC0;
6506+
constexpr uint8_t kmask1 = 0x03;
6507+
constexpr uint8_t kmask2 = 0x0C;
6508+
constexpr uint8_t kmask3 = 0x30;
6509+
constexpr uint8_t kmask4 = 0xC0;
65146510

65156511
const int nb = args.ne00/QK_K;
65166512

@@ -6561,18 +6557,16 @@ void kernel_mul_mv_q6_K_f32_impl(
65616557
}
65626558

65636559
for (short row = 0; row < nr0; ++row) {
6564-
const float dall = dh[0];
6565-
65666560
float4 sums = {0.f, 0.f, 0.f, 0.f};
65676561

6568-
for (short l = 0; l < 4; ++l) {
6562+
FOR_UNROLL (short l = 0; l < 4; ++l) {
65696563
sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
65706564
sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
65716565
sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
65726566
sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
65736567
}
65746568

6575-
sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
6569+
sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
65766570

65776571
q1 += args.nb01;
65786572
q2 += args.nb01;

0 commit comments

Comments
 (0)