@@ -303,6 +303,72 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
303303    return  res;
304304}
305305
306+ ggml_metal_library_t  ggml_metal_library_init_from_source (ggml_metal_device_t  dev, const  char  * source, bool  verbose) {
307+     if  (source == NULL ) {
308+         GGML_LOG_ERROR (" %s : source is NULL\n " 
309+         return  NULL ;
310+     }
311+ 
312+     id <MTLDevice > device = ggml_metal_device_get_obj (dev);
313+     id <MTLLibrary > library = nil ;
314+     NSError  * error = nil ;
315+ 
316+     const  int64_t  t_start = ggml_time_us ();
317+ 
318+     NSString  * src = [[NSString  alloc ] initWithBytes: source
319+                                               length: strlen (source)
320+                                             encoding: NSUTF8StringEncoding];
321+     if  (!src) {
322+         GGML_LOG_ERROR (" %s : failed to create NSString from source\n " 
323+         return  NULL ;
324+     }
325+ 
326+     @autoreleasepool {
327+         NSMutableDictionary  * prep = [NSMutableDictionary  dictionary ];
328+ 
329+         MTLCompileOptions  * options = [MTLCompileOptions  new ];
330+         options.preprocessorMacros  = prep;
331+ 
332+         library = [device newLibraryWithSource: src options: options error: &error];
333+         if  (error) {
334+             if  (verbose) {
335+                 GGML_LOG_ERROR (" %s : error compiling source: %s \n " description ] UTF8String ]);
336+             } else  {
337+                 GGML_LOG_ERROR (" %s : error compiling source\n " 
338+             }
339+             library = nil ;
340+         }
341+ 
342+         [options release ];
343+     }
344+ 
345+     [src release ];
346+ 
347+     if  (!library) {
348+         if  (verbose) {
349+             GGML_LOG_ERROR (" %s : failed to create Metal library from source\n " 
350+         }
351+ 
352+         return  NULL ;
353+     }
354+ 
355+     if  (verbose) {
356+         GGML_LOG_INFO (" %s : compiled in %.3f  sec\n " ggml_time_us () - t_start) / 1e6 );
357+     }
358+ 
359+     ggml_metal_library_t  res = calloc (1 , sizeof (struct  ggml_metal_library));
360+     if  (!res) {
361+         GGML_LOG_ERROR (" %s : calloc failed\n " 
362+         return  NULL ;
363+     }
364+ 
365+     res->obj        = library;
366+     res->device     = device;
367+     res->pipelines  = ggml_metal_pipelines_init ();
368+ 
369+     return  res;
370+ }
371+ 
306372void  ggml_metal_library_free (ggml_metal_library_t  lib) {
307373    if  (!lib) {
308374        return ;
@@ -474,12 +540,56 @@ ggml_metal_device_t ggml_metal_device_init(void) {
474540
475541            dev->props .has_bfloat   = [dev->mtl_device supportsFamily: MTLGPUFamilyMetal3_GGML];
476542            dev->props .has_bfloat  |= [dev->mtl_device supportsFamily: MTLGPUFamilyApple6];
543+             if  (getenv (" GGML_METAL_BF16_DISABLE" NULL ) {
544+                 dev->props .has_bfloat  = false ;
545+             }
477546
478547            dev->props .has_tensor  = [dev->mtl_device supportsFamily: MTLGPUFamilyMetal4_GGML];
479548            if  (getenv (" GGML_METAL_TENSOR_DISABLE" NULL ) {
480549                dev->props .has_tensor  = false ;
481550            }
482551
552+             //  try to compile a dummy tensor kernel to determine if the tensor API is supported for bfloat
553+             if  (dev->props .has_tensor  && dev->props .has_bfloat ) {
554+                 const  char  * src_tensor_bf16 = " \n " 
555+                     " #include <metal_stdlib> \n " 
556+                     " #include <metal_tensor> \n " 
557+                     " #include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n " 
558+                     "  \n " 
559+                     " using namespace metal; \n " 
560+                     " using namespace mpp::tensor_ops; \n " 
561+                     "  \n " 
562+                     " kernel void bfloat_dummy_kernel( \n " 
563+                     "     tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n " 
564+                     "     tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n " 
565+                     "     uint2 tgid [[threadgroup_position_in_grid]]) \n " 
566+                     " { \n " 
567+                     "     // Create slices for this threadgroup (no real computation performed). \n " 
568+                     "     auto tA = A.slice(0, (int)tgid.y); \n " 
569+                     "     auto tB = B.slice((int)tgid.x, 0); \n " 
570+                     "  \n " 
571+                     "     // Minimal matmul descriptor: 8×8 tile with dynamic K dimension. \n " 
572+                     "     matmul2d< \n " 
573+                     "         matmul2d_descriptor(8, 8, dynamic_extent), \n " 
574+                     "         execution_thread> mm; \n " 
575+                     "  \n " 
576+                     "     // Obtain a cooperative destination tensor of bfloat type. \n " 
577+                     "     auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), bfloat>(); \n " 
578+                     "  \n " 
579+                     "     // Silence “unused variable” warnings. \n " 
580+                     "     (void)cT; \n " 
581+                     " }" 
582+ 
583+                 GGML_LOG_INFO (" %s : testing tensor API for bfloat support\n " 
584+                 ggml_metal_library_t  lib = ggml_metal_library_init_from_source (dev, src_tensor_bf16, false );
585+                 if  (lib == NULL ) {
586+                     GGML_LOG_WARN (" %s : - the tensor API does not support bfloat - disabling bfloat support\n " 
587+                     dev->props .has_bfloat  = false ;
588+                 } else  {
589+                     ggml_metal_library_free (lib);
590+                 }
591+             }
592+ 
483593            dev->props .use_residency_sets  = true ;
484594#if  defined(GGML_METAL_HAS_RESIDENCY_SETS)
485595            dev->props .use_residency_sets  = getenv (" GGML_METAL_NO_RESIDENCY" nil ;
0 commit comments