Commit 49c1ac0

metal : support tensors in mul_mm_id

1 parent b99e72d
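This commit extends the GGML_METAL_HAS_TENSOR code path, previously used only by kernel_mul_mm, to the indirect (expert-routed) matrix-multiplication kernel kernel_mul_mm_id. The hard-coded tile parameters (BLOCK_SIZE_M/N/K, THREAD_PER_ROW/COL, SG_MAT_SIZE) give way to the named constants NR0 = 64, NR1 = 32, NK = 32, and the threadgroup tile loads are rewritten with one layout per backend: 8x8 subtile blocks for the simdgroup_multiply_accumulate path, and a plain row-major layout for the new mpp::tensor_ops::matmul2d cooperative-tensor path.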


ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 195 additions & 63 deletions
@@ -8204,12 +8204,12 @@ kernel void kernel_mul_mm(
         mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
     }
 #else
-    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
-    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
+    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
+    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
 
-    constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);
-
-    mpp::tensor_ops::matmul2d<desc, execution_simdgroups<4>> mm;
+    mpp::tensor_ops::matmul2d<
+        mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
+        execution_simdgroups<4>> mm;
 
     auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
 #endif
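The change to kernel_mul_mm itself is a restyling: the matmul2d descriptor, previously bound to a named constexpr variable, is now spelled inline in the template argument list. Both spellings rely on C++20 class-type non-type template parameters. A minimal, self-contained sketch of that pattern in plain C++20, with a hypothetical desc_t standing in for mpp::tensor_ops::matmul2d_descriptor:

#include <cstdio>

// Hypothetical stand-in for mpp::tensor_ops::matmul2d_descriptor: a
// "structural" type (public members, no custom operator==) so that a
// constexpr instance can be used as a template argument.
struct desc_t {
    int  m, n, k;
    bool transpose_a, transpose_b;
};

// C++20 class-type non-type template parameter, as used by matmul2d.
template <desc_t D>
struct matmul {
    static void print() { std::printf("m=%d n=%d k=%d\n", D.m, D.n, D.k); }
};

int main() {
    // Old spelling in the hunk above: bind the descriptor to a name first.
    constexpr desc_t desc{32, 64, 32, false, true};
    matmul<desc>::print();

    // New spelling: construct the descriptor inline in the argument list.
    matmul<desc_t{32, 64, 32, false, true}>::print();
    return 0;
}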
@@ -8522,72 +8522,169 @@ kernel void kernel_mul_mm_id(
         ushort tiitg[[thread_index_in_threadgroup]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
-    const int r0 = tgpig.y;
-    const int r1 = tgpig.x;
+    threadgroup float * sc = (threadgroup float *)(shmem);
+
+    constexpr int NR0 = 64;
+    constexpr int NR1 = 32;
+
+    constexpr int NK  = 32;
+    constexpr int NL0 = NK/16;
+    constexpr int NL1 = NK/8;
+
     const int im = tgpig.z; // expert
+    const int r0 = tgpig.y*NR0;
+    const int r1 = tgpig.x*NR1;
 
     device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
     device const int32_t  * ids_i32 = (device const int32_t  *) (hids);
 
     const int32_t neh1 = tpe_u32[im];
 
-    if (r1*BLOCK_SIZE_N >= neh1) {
+    if (r1 >= neh1) {
         return;
     }
 
     // if this block is of 64x32 shape or smaller
-    const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
+    const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
 
     // a thread shouldn't load data outside of the matrix
-    const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
-    const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
-
-    S0_8x8 ma[4];
-    S1_8x8 mb[2];
-
-    simdgroup_float8x8 mc[8];
+    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
+    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
 
-    for (short i = 0; i < 8; i++){
-        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
-    }
+    const short il0 = (tiitg % NL0);
 
-    short il = (tiitg % THREAD_PER_ROW);
+    short il = il0;
 
-    const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
+    const int id = ids_i32[im*args.ne21 + r1 + lr1];
 
     const short i11 = (id % args.ne20) % args.ne11;
     const short i12 = (id / args.ne20);
     const short i13 = 0;
 
     const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
-    const short    offset1 = il/nl;
+    const short    offset1 = il0/nl;
 
-    device const block_q * x = (device const block_q *)(src0
-        + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
+    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
 
-    const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
+    const short iy = 8*(tiitg % NL1);
 
     device const T1 * y = (device const T1 *)(src1
         + args.nb13*i13
         + args.nb12*i12
         + args.nb11*i11
         + args.nb10*iy);
 
-    for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
+#ifndef GGML_METAL_HAS_TENSOR
+    S0_8x8 ma[4];
+    S1_8x8 mb[2];
+
+    simdgroup_float8x8 mc[8];
+
+    for (short i = 0; i < 8; i++){
+        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+#else
+    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
+    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
+
+    mpp::tensor_ops::matmul2d<
+        mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
+        execution_simdgroups<4>> mm;
+
+    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
+#endif
+
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
+#ifndef GGML_METAL_HAS_TENSOR
+        // load data and store to threadgroup memory
+        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // no need for dequantization
+            for (short i = 0; i < 16; i++) {
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                //const short lx = i%8;
+                //const short ly = (tiitg/NL0)%8;
+                const short lx = (tiitg/NL0)%8;
+                const short ly = i%8;
+
+                const short ib = 8*sx + sy;
+
+                *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
+            }
+        } else {
+            S0_4x4 temp_a;
+            dequantize_func(x, il, temp_a);
+
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            FOR_UNROLL (short i = 0; i < 16; i++) {
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                //const short lx = i%8;
+                //const short ly = (tiitg/NL0)%8;
+                const short lx = (tiitg/NL0)%8;
+                const short ly = i%8;
+
+                const short ib = 8*sx + sy;
+
+                // NOTE: this is massively slower.. WTF?
+                //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
+
+                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
+            }
+        }
+
+        if (FC_mul_mm_bc_inp) {
+            for (short i = 0; i < 8; ++i) {
+                const short sx = (tiitg%NL1);
+                const short sy = (tiitg/NL1)/8;
+
+                const short lx = i;
+                const short ly = (tiitg/NL1)%8;
+                //const short lx = (tiitg/NL1)%8;
+                //const short ly = i;
+
+                const short ib = 4*sx + sy;
+
+                *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
+            }
+        } else {
+            const short sx = (tiitg%NL1);
+            const short sy = (tiitg/NL1)/8;
+
+            const short dx = sx;
+            const short dy = sy;
+
+            const short ly = (tiitg/NL1)%8;
+
+            const short ib = 4*sx + sy;
+
+            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
+        }
+#else
         // load data and store to threadgroup memory
         if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
             threadgroup_barrier(mem_flags::mem_threadgroup);
 
             // no need for dequantization
             for (short i = 0; i < 16; i++) {
-                *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
-                    + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
-                    + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                const short lx = i%8;
+                const short ly = (tiitg/NL0)%8;
+                //const short lx = (tiitg/NL0)%8;
+                //const short ly = i%8;
+
+                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
             }
         } else {
             S0_4x4 temp_a;
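An aside on the new constants before the hunk continues: with NK = 32, the loaders above give NL0 = NK/16 = 2 threads per staged row of src0 (each thread covers 16 K-elements) and NL1 = NK/8 = 4 threads per staged row of src1 (8 K-elements each). Assuming a 128-thread threadgroup, which the 0..63 and 0..31 ranges of lr0 and lr1 imply but the diff does not state, each K-step stages 128/2 = 64 = NR0 rows of A and 128/4 = 32 = NR1 rows of B, exactly one 64x32 output tile.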
@@ -8596,85 +8693,120 @@ kernel void kernel_mul_mm_id(
             threadgroup_barrier(mem_flags::mem_threadgroup);
 
             FOR_UNROLL (short i = 0; i < 16; i++) {
-                *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
-                    + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
-                    + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                const short lx = i%8;
+                const short ly = (tiitg/NL0)%8;
+                //const short lx = (tiitg/NL0)%8;
+                //const short ly = i%8;
+
+                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
             }
         }
 
         if (FC_mul_mm_bc_inp) {
             for (short i = 0; i < 8; ++i) {
-                sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
+                const short sx = (tiitg%NL1);
+                const short sy = (tiitg/NL1)/8;
+
+                const short lx = i;
+                const short ly = (tiitg/NL1)%8;
+                //const short lx = (tiitg/NL1)%8;
+                //const short ly = i;
+
+                *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
             }
         } else {
-            *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
+            const short sx = (tiitg%NL1);
+            const short sy = (tiitg/NL1)/8;
+
+            //const short lx = i;
+            const short ly = (tiitg/NL1)%8;
+            //const short lx = (tiitg/NL1)%8;
+            //const short ly = i;
+
+            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
        }
+#endif
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x = (il < 2) ? x + (2 + nl - 1)/nl : x;
-        y += BLOCK_SIZE_K;
+
+        y += NK;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
+#ifndef GGML_METAL_HAS_TENSOR
         // load matrices from threadgroup memory and conduct outer products
-        threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
-        threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
-
-        #pragma unroll(4)
-        for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
-            #pragma unroll(4)
-            for (short i = 0; i < 4; i++) {
-                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
+        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
+        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
+
+        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
+            simdgroup_barrier(mem_flags::mem_none);
+
+            FOR_UNROLL (short i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
             }
 
             simdgroup_barrier(mem_flags::mem_none);
 
-            #pragma unroll(2)
-            for (short i = 0; i < 2; i++) {
-                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
+            FOR_UNROLL (short i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
             }
 
-            #pragma unroll(8)
-            for (short i = 0; i < 8; i++){
+            simdgroup_barrier(mem_flags::mem_none);
+
+            FOR_UNROLL (short i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
             }
 
-            lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
-            lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
+            lsma += 8*64;
+            lsmb += 4*64;
         }
+#else
+        auto sA = tA.slice(0, 0);
+        auto sB = tB.slice(0, 0);
+
+        mm.run(sB, sA, cT);
+#endif
     }
 
+    // block is smaller than 64x32, we should avoid writing data outside of the matrix
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    threadgroup float * temp_str = ((threadgroup float *) shmem) \
-        + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
+#ifdef GGML_METAL_HAS_TENSOR
+    auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
+    cT.store(tC);
+#else
+    threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
 
-    #pragma unroll(8)
     for (short i = 0; i < 8; i++) {
-        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
     }
+#endif
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (short j = sgitg; j < n_cols; j += 4) {
-        const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
+    for (short j = sgitg; j < nr1; j += 4) {
+        const int id = ids_i32[im*args.ne21 + r1 + j];
 
         const short ide = id % args.ne20;
         const short idt = id / args.ne20;
 
-        device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
+        device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
         device float4 * D4 = (device float4 *) D;
 
-        threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
+        threadgroup float * C = (threadgroup float *) shmem + j*NR0;
         threadgroup float4 * C4 = (threadgroup float4 *) C;
 
         int i = tiisg;
-        for (; i < n_rows/4; i += 32) {
+        for (; i < nr0/4; i += 32) {
             *(D4 + i) = *(C4 + i);
         }
 
-        i = (4*(n_rows/4)) + tiisg;
-        for (; i < n_rows; i += 32) {
+        i = (4*(nr0/4)) + tiisg;
+        for (; i < nr0; i += 32) {
             *(D + i) = *(C + i);
         }
     }
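The non-tensor path's A-tile store scatters each thread's 16 values into 8x8 subtiles: ib = 8*sx + sy picks the subtile and 8*ly + lx the slot inside it, so the linear offset is 64*ib + 8*ly + lx. A small host-side sanity check of that mapping (plain C++, not Metal; the 128-thread threadgroup size is an inferred assumption, as above):

#include <cassert>
#include <vector>

int main() {
    // Mirrors the non-tensor A-tile store in kernel_mul_mm_id:
    //   sx = 2*il0 + i/8   : 8-wide slab along K      (0..3)
    //   sy = (tiitg/NL0)/8 : 8-high slab along rows   (0..7)
    //   lx = (tiitg/NL0)%8 : row within the slab
    //   ly = i%8           : K offset within the slab
    constexpr int NK = 32, NR0 = 64, NL0 = NK / 16;

    std::vector<int> hits(NR0 * NK, 0); // one slot per element of the 64x32 tile

    for (int tiitg = 0; tiitg < 128; ++tiitg) {   // assumed threadgroup size
        const int il0 = tiitg % NL0;
        for (int i = 0; i < 16; ++i) {            // 16 elements per thread
            const int sx = 2 * il0 + i / 8;
            const int sy = (tiitg / NL0) / 8;
            const int lx = (tiitg / NL0) % 8;
            const int ly = i % 8;
            const int ib = 8 * sx + sy;           // 8x8 subtile index, 0..31
            ++hits[64 * ib + 8 * ly + lx];
        }
    }

    for (int h : hits) assert(h == 1);            // every slot written exactly once
    return 0;
}

The commented-out lx/ly variants that the diff leaves in place correspond to storing each 8x8 subtile in the other orientation.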
