Skip to content

Commit afebf27

Browse files
committed
cont : handle even more incompatibilities
1 parent 9af8394 commit afebf27

File tree

1 file changed

+46
-12
lines changed

1 file changed

+46
-12
lines changed

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -416,23 +416,31 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
416416
if (!mtl_function) {
417417
ggml_critical_section_end();
418418

419-
GGML_LOG_ERROR("%s: error: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
419+
GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
420420
if (error) {
421-
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
421+
GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
422422
}
423423

424424
return nil;
425425
}
426426

427427
res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
428428

429-
ggml_metal_pipelines_add(lib->pipelines, name, res);
430-
431429
[mtl_function release];
432430

433431
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj,
434432
(int) res->obj.maxTotalThreadsPerThreadgroup,
435433
(int) res->obj.threadExecutionWidth);
434+
435+
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
436+
ggml_critical_section_end();
437+
438+
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
439+
440+
return nil;
441+
}
442+
443+
ggml_metal_pipelines_add(lib->pipelines, name, res);
436444
}
437445

438446
ggml_critical_section_end();
@@ -560,20 +568,27 @@ ggml_metal_device_t ggml_metal_device_init(void) {
560568
"using namespace mpp::tensor_ops; \n"
561569
" \n"
562570
"kernel void dummy_kernel( \n"
563-
" tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n"
564-
" tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n"
571+
" tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n"
572+
" tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n"
573+
" device float * C [[buffer(2)]], \n"
565574
" uint2 tgid [[threadgroup_position_in_grid]]) \n"
566575
"{ \n"
567576
" auto tA = A.slice(0, (int)tgid.y); \n"
568577
" auto tB = B.slice((int)tgid.x, 0); \n"
569578
" \n"
570579
" matmul2d< \n"
571580
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
572-
" execution_thread> mm; \n"
581+
" execution_simdgroups<4>> mm; \n"
582+
" \n"
583+
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
584+
" \n"
585+
" auto sA = tA.slice(0, 0); \n"
586+
" auto sB = tB.slice(0, 0); \n"
587+
" mm.run(sB, sA, cT); \n"
573588
" \n"
574-
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), half>(); \n"
589+
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
575590
" \n"
576-
" (void)cT; \n"
591+
" cT.store(tC); \n"
577592
"}";
578593

579594
GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__);
@@ -582,6 +597,12 @@ ggml_metal_device_t ggml_metal_device_init(void) {
582597
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
583598
dev->props.has_tensor = false;
584599
} else {
600+
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
601+
if (!ppl) {
602+
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
603+
dev->props.has_tensor = false;
604+
}
605+
585606
ggml_metal_library_free(lib);
586607
}
587608
}
@@ -599,18 +620,25 @@ ggml_metal_device_t ggml_metal_device_init(void) {
599620
"kernel void dummy_kernel( \n"
600621
" tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
601622
" tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
623+
" device float * C [[buffer(2)]], \n"
602624
" uint2 tgid [[threadgroup_position_in_grid]]) \n"
603625
"{ \n"
604626
" auto tA = A.slice(0, (int)tgid.y); \n"
605627
" auto tB = B.slice((int)tgid.x, 0); \n"
606628
" \n"
607629
" matmul2d< \n"
608630
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
609-
" execution_thread> mm; \n"
631+
" execution_simdgroups<4>> mm; \n"
610632
" \n"
611-
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), bfloat>(); \n"
633+
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
612634
" \n"
613-
" (void)cT; \n"
635+
" auto sA = tA.slice(0, 0); \n"
636+
" auto sB = tB.slice(0, 0); \n"
637+
" mm.run(sB, sA, cT); \n"
638+
" \n"
639+
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
640+
" \n"
641+
" cT.store(tC); \n"
614642
"}";
615643

616644
GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__);
@@ -619,6 +647,12 @@ ggml_metal_device_t ggml_metal_device_init(void) {
619647
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
620648
dev->props.has_bfloat = false;
621649
} else {
650+
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
651+
if (!ppl) {
652+
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
653+
dev->props.has_bfloat = false;
654+
}
655+
622656
ggml_metal_library_free(lib);
623657
}
624658
}

0 commit comments

Comments (0)