Skip to content

Commit 55758b0

Browse files
authored
metal : refactor kernel loading (#15964)
* metal : refactor bin kernels loading ggml-ci
* metal : refactor rms kernel loading ggml-ci
* ci : try to add memory leaks check ggml-ci
* ci : try to enable memory leak detection for Mac
* cont : seems to be working
1 parent f161463 commit 55758b0

File tree

4 files changed

+98
-127
lines changed

4 files changed

+98
-127
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ jobs:
8888
-DGGML_METAL_SHADER_DEBUG=ON \
8989
-DGGML_RPC=ON
9090
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
91+
leaks -atExit -- ./build/bin/test-thread-safety -hf ggml-org/gemma-3-270m-qat-GGUF -ngl 99 -p "$(printf 'hello %.0s' {1..128})" -n 16 -c 512 -ub 32 -np 2 -t 2 -lv 1
9192
9293
- name: Test
9394
id: cmake_test

ci/run.sh

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ function gg_run_ctest_with_model_debug {
270270
local model; model=$(gg_get_model)
271271
cd build-ci-debug
272272
set -e
273+
273274
(LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log
275+
274276
set +e
275277
cd ..
276278
}
@@ -281,7 +283,15 @@ function gg_run_ctest_with_model_release {
281283
local model; model=$(gg_get_model)
282284
cd build-ci-release
283285
set -e
286+
284287
(LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log
288+
289+
# test memory leaks
290+
#if [[ ! -z ${GG_BUILD_METAL} ]]; then
291+
# # TODO: this hangs for some reason ...
292+
# (time leaks -quiet -atExit -- ./bin/test-thread-safety -m $model --parallel 2 -t 2 -p "hello") 2>&1 | tee -a $OUT/${ci}-leaks.log
293+
#fi
294+
285295
set +e
286296
cd ..
287297
}
@@ -860,20 +870,15 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
860870
fi
861871

862872
ret=0
863-
if [ -z ${GG_BUILD_SYCL} ]; then
864-
# SYCL build breaks with debug build flags
865-
test $ret -eq 0 && gg_run ctest_debug
866-
fi
873+
test $ret -eq 0 && gg_run ctest_debug
867874
test $ret -eq 0 && gg_run ctest_release
868875

869876
if [ -z ${GG_BUILD_LOW_PERF} ]; then
870877
test $ret -eq 0 && gg_run embd_bge_small
871878
test $ret -eq 0 && gg_run rerank_tiny
872879

873880
if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
874-
if [ -z ${GG_BUILD_SYCL} ]; then
875-
test $ret -eq 0 && gg_run test_scripts_debug
876-
fi
881+
test $ret -eq 0 && gg_run test_scripts_debug
877882
test $ret -eq 0 && gg_run test_scripts_release
878883
fi
879884

@@ -884,9 +889,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
884889
test $ret -eq 0 && gg_run pythia_2_8b
885890
#test $ret -eq 0 && gg_run open_llama_7b_v2
886891
fi
887-
if [ -z ${GG_BUILD_SYCL} ]; then
888-
test $ret -eq 0 && gg_run ctest_with_model_debug
889-
fi
892+
test $ret -eq 0 && gg_run ctest_with_model_debug
890893
test $ret -eq 0 && gg_run ctest_with_model_release
891894
fi
892895
fi

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 75 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -232,28 +232,6 @@ - (void) dealloc {
232232
@end
233233

234234
enum ggml_metal_kernel_type {
235-
GGML_METAL_KERNEL_TYPE_ADD,
236-
GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
237-
GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
238-
GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
239-
GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
240-
GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
241-
GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
242-
GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
243-
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
244-
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
245-
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
246-
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
247-
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
248-
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
249-
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
250-
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
251-
GGML_METAL_KERNEL_TYPE_SUB,
252-
GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
253-
GGML_METAL_KERNEL_TYPE_MUL,
254-
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
255-
GGML_METAL_KERNEL_TYPE_DIV,
256-
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
257235
GGML_METAL_KERNEL_TYPE_ADD_ID,
258236
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
259237
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
@@ -319,9 +297,6 @@ - (void) dealloc {
319297
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
320298
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
321299
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
322-
GGML_METAL_KERNEL_TYPE_RMS_NORM,
323-
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
324-
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
325300
GGML_METAL_KERNEL_TYPE_L2_NORM,
326301
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
327302
GGML_METAL_KERNEL_TYPE_NORM,
@@ -1177,28 +1152,6 @@ @implementation GGMLMetalClass
11771152

11781153
// simd_sum and simd_max requires MTLGPUFamilyApple7
11791154

1180-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
1181-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
1182-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
1183-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
1184-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
1185-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
1186-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
1187-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
1188-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
1189-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
1190-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
1191-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
1192-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
1193-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
1194-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
1195-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
1196-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
1197-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
1198-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
1199-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
1200-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
1201-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
12021155
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
12031156
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
12041157
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
@@ -1264,9 +1217,6 @@ @implementation GGMLMetalClass
12641217
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
12651218
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
12661219
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
1267-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1268-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
1269-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
12701220
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
12711221
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
12721222
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
@@ -1722,6 +1672,73 @@ @implementation GGMLMetalClass
17221672
GGML_UNUSED(op);
17231673
}
17241674

1675+
static id<MTLComputePipelineState> ggml_metal_get_pipeline_bin(
1676+
ggml_backend_t backend, enum ggml_op op,
1677+
int32_t n_fuse,
1678+
bool row) {
1679+
struct ggml_backend_metal_context * ctx = backend->context;
1680+
1681+
char base[256];
1682+
char name[256];
1683+
1684+
@autoreleasepool {
1685+
const char * op_str = "undefined";
1686+
switch (op) {
1687+
case GGML_OP_ADD: op_str = "add"; break;
1688+
case GGML_OP_SUB: op_str = "sub"; break;
1689+
case GGML_OP_MUL: op_str = "mul"; break;
1690+
case GGML_OP_DIV: op_str = "div"; break;
1691+
default: GGML_ABORT("fatal error");
1692+
};
1693+
1694+
if (row) {
1695+
snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
1696+
} else {
1697+
snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
1698+
}
1699+
1700+
snprintf(name, 256, "%s", base);
1701+
1702+
id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
1703+
if (res) {
1704+
// kernel found
1705+
return res;
1706+
}
1707+
1708+
return ggml_metal_compile_kernel(backend, base, name, nil);
1709+
}
1710+
}
1711+
1712+
static id<MTLComputePipelineState> ggml_metal_get_pipeline_rms_norm(
1713+
ggml_backend_t backend, struct ggml_tensor * op,
1714+
int32_t n_fuse) {
1715+
struct ggml_backend_metal_context * ctx = backend->context;
1716+
1717+
char base[256];
1718+
char name[256];
1719+
1720+
@autoreleasepool {
1721+
switch (n_fuse) {
1722+
case 1: snprintf(base, 256, "kernel_rms_norm"); break;
1723+
case 2: snprintf(base, 256, "kernel_rms_norm_mul"); break;
1724+
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add"); break;
1725+
default: GGML_ABORT("fatal error");
1726+
}
1727+
1728+
snprintf(name, 256, "%s", base);
1729+
1730+
id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
1731+
if (res) {
1732+
// kernel found
1733+
return res;
1734+
}
1735+
1736+
return ggml_metal_compile_kernel(backend, base, name, nil);
1737+
}
1738+
1739+
GGML_UNUSED(op);
1740+
}
1741+
17251742
static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
17261743
GGML_LOG_INFO("%s: deallocating\n", __func__);
17271744

@@ -2359,8 +2376,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
23592376

23602377
bool bcast_row = false;
23612378

2362-
id<MTLComputePipelineState> pipeline = nil;
2363-
23642379
ggml_metal_kargs_bin args = {
23652380
/*.ne00 =*/ ne00,
23662381
/*.ne01 =*/ ne01,
@@ -2441,55 +2456,19 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
24412456
}
24422457
}
24432458

2459+
id<MTLComputePipelineState> pipeline = nil;
2460+
24442461
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
24452462
GGML_ASSERT(ggml_is_contiguous(src0));
24462463

24472464
// src1 is a row
24482465
GGML_ASSERT(ne11 == 1);
24492466

2450-
switch (dst->op) {
2451-
case GGML_OP_ADD:
2452-
{
2453-
switch (n_fuse) {
2454-
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
2455-
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
2456-
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
2457-
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
2458-
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
2459-
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
2460-
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
2461-
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
2462-
default: GGML_ABORT("fatal error");
2463-
}
2464-
} break;
2465-
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
2466-
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
2467-
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
2468-
default: GGML_ABORT("fatal error");
2469-
}
2467+
pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, true);
24702468

24712469
bcast_row = true;
24722470
} else {
2473-
switch (dst->op) {
2474-
case GGML_OP_ADD:
2475-
{
2476-
switch (n_fuse) {
2477-
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break;
2478-
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
2479-
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
2480-
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
2481-
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
2482-
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
2483-
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
2484-
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
2485-
default: GGML_ABORT("fatal error");
2486-
}
2487-
} break;
2488-
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
2489-
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
2490-
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
2491-
default: GGML_ABORT("fatal error");
2492-
}
2471+
pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, false);
24932472
}
24942473

24952474
if (n_fuse > 1) {
@@ -2650,8 +2629,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
26502629
ggml_metal_encode_concurrency_reset(ctx_enc);
26512630
}
26522631

2653-
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
2654-
26552632
ggml_metal_kargs_bin args = {
26562633
/*.ne00 =*/ ne00,
26572634
/*.ne01 =*/ ne01,
@@ -2681,6 +2658,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
26812658
/*.o1 =*/ { offs_src1},
26822659
};
26832660

2661+
//const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
2662+
const id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_bin(backend, GGML_OP_ADD, 1, false);
2663+
26842664
[encoder setComputePipelineState:pipeline];
26852665
[encoder setBytes:&args length:sizeof(args) atIndex:0];
26862666
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
@@ -4659,14 +4639,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
46594639
}
46604640
}
46614641

4662-
id<MTLComputePipelineState> pipeline;
4663-
4664-
switch (n_fuse) {
4665-
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
4666-
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
4667-
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
4668-
default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
4669-
}
4642+
const id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_rms_norm(backend, node, n_fuse);
46704643

46714644
int nth = 32; // SIMD width
46724645

0 commit comments

Comments
 (0)