@@ -416,23 +416,31 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
416416 if (!mtl_function) {
417417 ggml_critical_section_end ();
418418
419- GGML_LOG_ERROR (" %s : error: failed to compile pipeline: base = '%s ', name = '%s '\n " , __func__, base, name);
419+ GGML_LOG_ERROR (" %s : failed to compile pipeline: base = '%s ', name = '%s '\n " , __func__, base, name);
420420 if (error) {
421- GGML_LOG_ERROR (" %s : error: %s \n " , __func__, [[error description ] UTF8String ]);
421+ GGML_LOG_ERROR (" %s : %s \n " , __func__, [[error description ] UTF8String ]);
422422 }
423423
424424 return nil ;
425425 }
426426
427427 res->obj = [lib->device newComputePipelineStateWithFunction: mtl_function error: &error];
428428
429- ggml_metal_pipelines_add (lib->pipelines , name, res);
430-
431429 [mtl_function release ];
432430
433431 GGML_LOG_DEBUG (" %s : loaded %-40s %16p | th_max = %4d | th_width = %4d \n " , __func__, name, (void *) res->obj ,
434432 (int ) res->obj .maxTotalThreadsPerThreadgroup ,
435433 (int ) res->obj .threadExecutionWidth );
434+
435+ if (res->obj .maxTotalThreadsPerThreadgroup == 0 || res->obj .threadExecutionWidth == 0 ) {
436+ ggml_critical_section_end ();
437+
438+ GGML_LOG_ERROR (" %s : incompatible pipeline %s \n " , __func__, name);
439+
440+ return nil ;
441+ }
442+
443+ ggml_metal_pipelines_add (lib->pipelines , name, res);
436444 }
437445
438446 ggml_critical_section_end ();
@@ -560,20 +568,27 @@ ggml_metal_device_t ggml_metal_device_init(void) {
560568 " using namespace mpp::tensor_ops; \n "
561569 " \n "
562570 " kernel void dummy_kernel( \n "
563- " tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n "
564- " tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n "
571+ " tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n "
572+ " tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n "
573+ " device float * C [[buffer(2)]], \n "
565574 " uint2 tgid [[threadgroup_position_in_grid]]) \n "
566575 " { \n "
567576 " auto tA = A.slice(0, (int)tgid.y); \n "
568577 " auto tB = B.slice((int)tgid.x, 0); \n "
569578 " \n "
570579 " matmul2d< \n "
571580 " matmul2d_descriptor(8, 8, dynamic_extent), \n "
572- " execution_thread> mm; \n "
581+ " execution_simdgroups<4>> mm; \n "
582+ " \n "
583+ " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n "
584+ " \n "
585+ " auto sA = tA.slice(0, 0); \n "
586+ " auto sB = tB.slice(0, 0); \n "
587+ " mm.run(sB, sA, cT); \n "
573588 " \n "
574- " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), half>( ); \n "
589+ " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4) ); \n "
575590 " \n "
576- " (void)cT ; \n "
591+ " cT.store(tC) ; \n "
577592 " }" ;
578593
579594 GGML_LOG_INFO (" %s : testing tensor API for f16 support\n " , __func__);
@@ -582,6 +597,12 @@ ggml_metal_device_t ggml_metal_device_init(void) {
582597 GGML_LOG_WARN (" %s : - the tensor API is not supported in this environment - disabling\n " , __func__);
583598 dev->props .has_tensor = false ;
584599 } else {
600+ ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline (lib, " dummy_kernel" , " dummy_kernel" , nil );
601+ if (!ppl) {
602+ GGML_LOG_WARN (" %s : - the tensor API is not supported in this environment - disabling\n " , __func__);
603+ dev->props .has_tensor = false ;
604+ }
605+
585606 ggml_metal_library_free (lib);
586607 }
587608 }
@@ -599,18 +620,25 @@ ggml_metal_device_t ggml_metal_device_init(void) {
599620 " kernel void dummy_kernel( \n "
600621 " tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n "
601622 " tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n "
623+ " device float * C [[buffer(2)]], \n "
602624 " uint2 tgid [[threadgroup_position_in_grid]]) \n "
603625 " { \n "
604626 " auto tA = A.slice(0, (int)tgid.y); \n "
605627 " auto tB = B.slice((int)tgid.x, 0); \n "
606628 " \n "
607629 " matmul2d< \n "
608630 " matmul2d_descriptor(8, 8, dynamic_extent), \n "
609- " execution_thread > mm; \n "
631+ " execution_simdgroups<4> > mm; \n "
610632 " \n "
611- " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), bfloat >(); \n "
633+ " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float >(); \n "
612634 " \n "
613- " (void)cT; \n "
635+ " auto sA = tA.slice(0, 0); \n "
636+ " auto sB = tB.slice(0, 0); \n "
637+ " mm.run(sB, sA, cT); \n "
638+ " \n "
639+ " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n "
640+ " \n "
641+ " cT.store(tC); \n "
614642 " }" ;
615643
616644 GGML_LOG_INFO (" %s : testing tensor API for bfloat support\n " , __func__);
@@ -619,6 +647,12 @@ ggml_metal_device_t ggml_metal_device_init(void) {
619647 GGML_LOG_WARN (" %s : - the tensor API does not support bfloat - disabling bfloat support\n " , __func__);
620648 dev->props .has_bfloat = false ;
621649 } else {
650+ ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline (lib, " dummy_kernel" , " dummy_kernel" , nil );
651+ if (!ppl) {
652+ GGML_LOG_WARN (" %s : - the tensor API does not support bfloat - disabling bfloat support\n " , __func__);
653+ dev->props .has_bfloat = false ;
654+ }
655+
622656 ggml_metal_library_free (lib);
623657 }
624658 }
0 commit comments