Skip to content

Commit 5e09948

Browse files
committed
metal : initial Metal4 support
1 parent 83a7499 commit 5e09948

File tree

1 file changed

+110
-10
lines changed

1 file changed

+110
-10
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 110 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,18 @@ __embed_ggml-common.h__
99

1010
#include <metal_stdlib>
1111

12+
#define GGML_METAL_USE_METAL4
13+
14+
#ifdef GGML_METAL_USE_METAL4
15+
#include <metal_stdlib>
16+
#include <metal_tensor>
17+
18+
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
19+
20+
using namespace metal;
21+
using namespace mpp::tensor_ops;
22+
#endif
23+
1224
using namespace metal;
1325

1426
#define MAX(x, y) ((x) > (y) ? (x) : (y))
@@ -8145,6 +8157,8 @@ kernel void kernel_mul_mm(
81458157
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
81468158
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
81478159

8160+
threadgroup float * sc = (threadgroup float *)(shmem);
8161+
81488162
constexpr int NR0 = 64;
81498163
constexpr int NR1 = 32;
81508164

@@ -8164,15 +8178,6 @@ kernel void kernel_mul_mm(
81648178
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
81658179
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
81668180

8167-
S0_8x8 ma[4];
8168-
S1_8x8 mb[2];
8169-
8170-
simdgroup_float8x8 mc[8];
8171-
8172-
for (short i = 0; i < 8; i++){
8173-
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
8174-
}
8175-
81768181
const short il0 = (tiitg % NL0);
81778182

81788183
short il = il0;
@@ -8193,7 +8198,28 @@ kernel void kernel_mul_mm(
81938198
+ args.nb11*(r1 + lr1)
81948199
+ args.nb10*iy);
81958200

8201+
#ifndef GGML_METAL_USE_METAL4
8202+
S0_8x8 ma[4];
8203+
S1_8x8 mb[2];
8204+
8205+
simdgroup_float8x8 mc[8];
8206+
8207+
for (short i = 0; i < 8; i++){
8208+
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
8209+
}
8210+
#else
8211+
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
8212+
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
8213+
8214+
constexpr auto desc = matmul2d_descriptor(NR1, NR0, NK, false, true, false, matmul2d_descriptor::mode::multiply_accumulate);
8215+
8216+
matmul2d<desc, execution_simdgroups<4>> mm;
8217+
8218+
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
8219+
#endif
8220+
81968221
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
8222+
#ifndef GGML_METAL_USE_METAL4
81978223
// load data and store to threadgroup memory
81988224
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
81998225
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -8297,26 +8323,100 @@ kernel void kernel_mul_mm(
82978323
lsma += 8*64;
82988324
lsmb += 4*64;
82998325
}
8326+
#else
8327+
// load data and store to threadgroup memory
8328+
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8329+
threadgroup_barrier(mem_flags::mem_threadgroup);
8330+
8331+
// no need for dequantization
8332+
for (short i = 0; i < 16; i++) {
8333+
const short sx = 2*il0 + i/8;
8334+
const short sy = (tiitg/NL0)/8;
8335+
8336+
const short lx = i%8;
8337+
const short ly = (tiitg/NL0)%8;
8338+
//const short lx = (tiitg/NL0)%8;
8339+
//const short ly = i%8;
8340+
8341+
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
8342+
}
8343+
} else {
8344+
S0_4x4 temp_a;
8345+
dequantize_func(x, il, temp_a);
8346+
8347+
threadgroup_barrier(mem_flags::mem_threadgroup);
8348+
8349+
FOR_UNROLL (short i = 0; i < 16; i++) {
8350+
const short sx = 2*il0 + i/8;
8351+
const short sy = (tiitg/NL0)/8;
8352+
8353+
const short lx = i%8;
8354+
const short ly = (tiitg/NL0)%8;
8355+
//const short lx = (tiitg/NL0)%8;
8356+
//const short ly = i%8;
8357+
8358+
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
8359+
}
8360+
}
8361+
8362+
for (short i = 0; i < 8; ++i) {
8363+
const short sx = (tiitg%NL1);
8364+
const short sy = (tiitg/NL1)/8;
8365+
8366+
const short lx = i;
8367+
const short ly = (tiitg/NL1)%8;
8368+
//const short lx = (tiitg/NL1)%8;
8369+
//const short ly = i;
8370+
8371+
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
8372+
}
8373+
8374+
il = (il + 2 < nl) ? il + 2 : il % 2;
8375+
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
8376+
8377+
y += NK;
8378+
8379+
threadgroup_barrier(mem_flags::mem_threadgroup);
8380+
8381+
auto sA = tA.slice(0, 0);
8382+
auto sB = tB.slice(0, 0);
8383+
8384+
mm.run(sB, sA, cT);
8385+
#endif
83008386
}
83018387

83028388
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
83038389
// if no bounds checks on the output are needed, we can directly write to device memory
8390+
#ifdef GGML_METAL_USE_METAL4
8391+
device float * C = (device float *) dst +
8392+
r0 + \
8393+
r1 * args.ne0 + im*args.ne1*args.ne0;
8394+
8395+
auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
8396+
cT.store(tC);
8397+
#else
83048398
device float * C = (device float *) dst +
83058399
(r0 + 32*(sgitg & 1)) + \
83068400
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
83078401

83088402
for (short i = 0; i < 8; i++) {
8309-
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0, 0, false);
8403+
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
83108404
}
8405+
#endif
83118406
} else {
83128407
// block is smaller than 64x32, we should avoid writing data outside of the matrix
83138408
threadgroup_barrier(mem_flags::mem_threadgroup);
83148409

83158410
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
83168411

8412+
#ifdef GGML_METAL_USE_METAL4
8413+
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
8414+
cT.store(tC);
8415+
#else
83178416
for (short i = 0; i < 8; i++) {
83188417
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
83198418
}
8419+
#endif
83208420

83218421
threadgroup_barrier(mem_flags::mem_threadgroup);
83228422

0 commit comments

Comments
 (0)