Skip to content

Commit 83a7499

Browse files
committed
metal : rework mat-mat multiplication
1 parent 280d97b commit 83a7499

File tree

2 files changed

+98
-51
lines changed

2 files changed

+98
-51
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 96 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8145,17 +8145,24 @@ kernel void kernel_mul_mm(
81458145
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
81468146
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
81478147

8148-
const int r0 = tgpig.y;
8149-
const int r1 = tgpig.x;
8148+
constexpr int NR0 = 64;
8149+
constexpr int NR1 = 32;
8150+
8151+
constexpr int NK = 32;
8152+
constexpr int NL0 = NK/16;
8153+
constexpr int NL1 = NK/8;
8154+
81508155
const int im = tgpig.z;
8156+
const int r0 = tgpig.y*NR0;
8157+
const int r1 = tgpig.x*NR1;
81518158

81528159
// if this block is of 64x32 shape or smaller
8153-
const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
8154-
const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
8160+
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
8161+
const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
81558162

81568163
// a thread shouldn't load data outside of the matrix
8157-
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
8158-
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
8164+
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
8165+
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
81598166

81608167
S0_8x8 ma[4];
81618168
S1_8x8 mb[2];
@@ -8166,35 +8173,44 @@ kernel void kernel_mul_mm(
81668173
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
81678174
}
81688175

8169-
short il = (tiitg % THREAD_PER_ROW);
8176+
const short il0 = (tiitg % NL0);
8177+
8178+
short il = il0;
81708179

81718180
const int i12 = im%args.ne12;
81728181
const int i13 = im/args.ne12;
81738182

81748183
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8175-
const short offset1 = il/nl;
8184+
const short offset1 = il0/nl;
81768185

8177-
device const block_q * x = (device const block_q *)(src0
8178-
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
8186+
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
81798187

8180-
const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
8188+
const short iy = 8*(tiitg % NL1);
81818189

81828190
device const T1 * y = (device const T1 *)(src1
81838191
+ args.nb13*i13
81848192
+ args.nb12*i12
8185-
+ args.nb11*(r1*BLOCK_SIZE_N + thread_col)
8193+
+ args.nb11*(r1 + lr1)
81868194
+ args.nb10*iy);
81878195

8188-
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
8196+
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
81898197
// load data and store to threadgroup memory
81908198
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
81918199
threadgroup_barrier(mem_flags::mem_threadgroup);
81928200

81938201
// no need for dequantization
81948202
for (short i = 0; i < 16; i++) {
8195-
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
8196-
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
8197-
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
8203+
const short sx = 2*il0 + i/8;
8204+
const short sy = (tiitg/NL0)/8;
8205+
8206+
//const short lx = i%8;
8207+
//const short ly = (tiitg/NL0)%8;
8208+
const short lx = (tiitg/NL0)%8;
8209+
const short ly = i%8;
8210+
8211+
const short ib = 8*sx + sy;
8212+
8213+
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
81988214
}
81998215
} else {
82008216
S0_4x4 temp_a;
@@ -8203,91 +8219,122 @@ kernel void kernel_mul_mm(
82038219
threadgroup_barrier(mem_flags::mem_threadgroup);
82048220

82058221
FOR_UNROLL (short i = 0; i < 16; i++) {
8206-
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
8207-
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
8208-
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
8222+
const short sx = 2*il0 + i/8;
8223+
const short sy = (tiitg/NL0)/8;
8224+
8225+
//const short lx = i%8;
8226+
//const short ly = (tiitg/NL0)%8;
8227+
const short lx = (tiitg/NL0)%8;
8228+
const short ly = i%8;
8229+
8230+
const short ib = 8*sx + sy;
8231+
8232+
// NOTE: this is massively slower.. WTF?
8233+
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
8234+
8235+
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
82098236
}
82108237
}
82118238

82128239
if (FC_mul_mm_bc_inp) {
82138240
for (short i = 0; i < 8; ++i) {
8214-
sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
8241+
const short sx = (tiitg%NL1);
8242+
const short sy = (tiitg/NL1)/8;
8243+
8244+
const short lx = i;
8245+
const short ly = (tiitg/NL1)%8;
8246+
//const short lx = (tiitg/NL1)%8;
8247+
//const short ly = i;
8248+
8249+
const short ib = 4*sx + sy;
8250+
8251+
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
82158252
}
82168253
} else {
8217-
*(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
8254+
const short sx = (tiitg%NL1);
8255+
const short sy = (tiitg/NL1)/8;
8256+
8257+
const short dx = sx;
8258+
const short dy = sy;
8259+
8260+
const short ly = (tiitg/NL1)%8;
8261+
8262+
const short ib = 4*sx + sy;
8263+
8264+
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
82188265
}
82198266

82208267
il = (il + 2 < nl) ? il + 2 : il % 2;
82218268
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
8222-
y += BLOCK_SIZE_K;
82238269

8224-
threadgroup_barrier(mem_flags::mem_threadgroup);
8270+
y += NK;
82258271

82268272
// load matrices from threadgroup memory and conduct outer products
8227-
threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
8228-
threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
8273+
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
8274+
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
82298275

8230-
#pragma unroll(4)
8231-
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
8276+
threadgroup_barrier(mem_flags::mem_threadgroup);
8277+
8278+
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
82328279
simdgroup_barrier(mem_flags::mem_none);
82338280

8234-
#pragma unroll(4)
8235-
for (short i = 0; i < 4; i++) {
8236-
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
8281+
FOR_UNROLL (short i = 0; i < 4; i++) {
8282+
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
82378283
}
82388284

8239-
#pragma unroll(2)
8240-
for (short i = 0; i < 2; i++) {
8241-
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
8285+
simdgroup_barrier(mem_flags::mem_none);
8286+
8287+
FOR_UNROLL (short i = 0; i < 2; i++) {
8288+
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
82428289
}
82438290

82448291
simdgroup_barrier(mem_flags::mem_none);
82458292

8246-
#pragma unroll(8)
8247-
for (short i = 0; i < 8; i++){
8293+
FOR_UNROLL (short i = 0; i < 8; i++){
82488294
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
82498295
}
82508296

8251-
lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
8252-
lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
8297+
lsma += 8*64;
8298+
lsmb += 4*64;
82538299
}
82548300
}
82558301

8256-
if (!FC_mul_mm_bc_out || ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1)) {
8302+
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
82578303
// if no bounds checks on the output are needed, we can directly write to device memory
82588304
device float * C = (device float *) dst +
8259-
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
8260-
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
8305+
(r0 + 32*(sgitg & 1)) + \
8306+
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
82618307

82628308
for (short i = 0; i < 8; i++) {
8263-
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
8309+
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0, 0, false);
82648310
}
82658311
} else {
82668312
// block is smaller than 64x32, we should avoid writing data outside of the matrix
82678313
threadgroup_barrier(mem_flags::mem_threadgroup);
8268-
threadgroup float * temp_str = ((threadgroup float *) shmem) \
8269-
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
8314+
8315+
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
8316+
82708317
for (short i = 0; i < 8; i++) {
8271-
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
8318+
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
82728319
}
82738320

82748321
threadgroup_barrier(mem_flags::mem_threadgroup);
82758322

82768323
if (sgitg == 0) {
8277-
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
8278-
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
8324+
for (int j = tiitg; j < nr1; j += NR1) {
8325+
device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
82798326
device float4 * D4 = (device float4 *) D;
82808327

8281-
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
8328+
threadgroup float * C = temp_str + (j*NR0);
82828329
threadgroup float4 * C4 = (threadgroup float4 *) C;
82838330

82848331
int i = 0;
8285-
for (; i < n_rows/4; i++) {
8332+
for (; i < nr0/4; i++) {
82868333
*(D4 + i) = *(C4 + i);
82878334
}
82888335

82898336
i *= 4;
8290-
for (; i < n_rows; i++) {
8337+
for (; i < nr0; i++) {
82918338
*(D + i) = *(C + i);
82928339
}
82938340
}

tests/test-backend-ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,7 +1284,7 @@ struct test_case {
12841284
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
12851285
//}
12861286
//printf("\n");
1287-
//exit(1);
1287+
exit(1);
12881288
ud->ok = false;
12891289
}
12901290
return true;
@@ -6761,7 +6761,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
67616761
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, true, 3));
67626762
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
67636763

6764-
#if 0
6764+
#if 1
67656765
// test the mat-mat path for Metal
67666766
for (int k = 1; k < 512; ++k) {
67676767
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 127, k, {12,1}, {1,1}));

0 commit comments

Comments
 (0)