@@ -303,6 +303,72 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
303303 return res;
304304}
305305
306+ ggml_metal_library_t ggml_metal_library_init_from_source (ggml_metal_device_t dev, const char * source, bool verbose) {
307+ if (source == NULL ) {
308+ GGML_LOG_ERROR (" %s : source is NULL\n " , __func__);
309+ return NULL ;
310+ }
311+
312+ id <MTLDevice > device = ggml_metal_device_get_obj (dev);
313+ id <MTLLibrary > library = nil ;
314+ NSError * error = nil ;
315+
316+ const int64_t t_start = ggml_time_us ();
317+
318+ NSString * src = [[NSString alloc ] initWithBytes: source
319+ length: strlen (source)
320+ encoding: NSUTF8StringEncoding];
321+ if (!src) {
322+ GGML_LOG_ERROR (" %s : failed to create NSString from source\n " , __func__);
323+ return NULL ;
324+ }
325+
326+ @autoreleasepool {
327+ NSMutableDictionary * prep = [NSMutableDictionary dictionary ];
328+
329+ MTLCompileOptions * options = [MTLCompileOptions new ];
330+ options.preprocessorMacros = prep;
331+
332+ library = [device newLibraryWithSource: src options: options error: &error];
333+ if (error) {
334+ if (verbose) {
335+ GGML_LOG_ERROR (" %s : error compiling source: %s \n " , __func__, [[error description ] UTF8String ]);
336+ } else {
337+ GGML_LOG_ERROR (" %s : error compiling source\n " , __func__);
338+ }
339+ library = nil ;
340+ }
341+
342+ [options release ];
343+ }
344+
345+ [src release ];
346+
347+ if (!library) {
348+ if (verbose) {
349+ GGML_LOG_ERROR (" %s : failed to create Metal library from source\n " , __func__);
350+ }
351+
352+ return NULL ;
353+ }
354+
355+ if (verbose) {
356+ GGML_LOG_INFO (" %s : compiled in %.3f sec\n " , __func__, (ggml_time_us () - t_start) / 1e6 );
357+ }
358+
359+ ggml_metal_library_t res = calloc (1 , sizeof (struct ggml_metal_library));
360+ if (!res) {
361+ GGML_LOG_ERROR (" %s : calloc failed\n " , __func__);
362+ return NULL ;
363+ }
364+
365+ res->obj = library;
366+ res->device = device;
367+ res->pipelines = ggml_metal_pipelines_init ();
368+
369+ return res;
370+ }
371+
306372void ggml_metal_library_free (ggml_metal_library_t lib) {
307373 if (!lib) {
308374 return ;
@@ -474,12 +540,56 @@ ggml_metal_device_t ggml_metal_device_init(void) {
474540
475541 dev->props .has_bfloat = [dev->mtl_device supportsFamily: MTLGPUFamilyMetal3_GGML];
476542 dev->props .has_bfloat |= [dev->mtl_device supportsFamily: MTLGPUFamilyApple6];
543+ if (getenv (" GGML_METAL_BF16_DISABLE" ) != NULL ) {
544+ dev->props .has_bfloat = false ;
545+ }
477546
478547 dev->props .has_tensor = [dev->mtl_device supportsFamily: MTLGPUFamilyMetal4_GGML];
479548 if (getenv (" GGML_METAL_TENSOR_DISABLE" ) != NULL ) {
480549 dev->props .has_tensor = false ;
481550 }
482551
552+ // try to compile a dummy tensor kernel to determine if the tensor API is supported for bfloat
553+ if (dev->props .has_tensor && dev->props .has_bfloat ) {
554+ const char * src_tensor_bf16 = " \n "
555+ " #include <metal_stdlib> \n "
556+ " #include <metal_tensor> \n "
557+ " #include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n "
558+ " \n "
559+ " using namespace metal; \n "
560+ " using namespace mpp::tensor_ops; \n "
561+ " \n "
562+ " kernel void bfloat_dummy_kernel( \n "
563+ " tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n "
564+ " tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n "
565+ " uint2 tgid [[threadgroup_position_in_grid]]) \n "
566+ " { \n "
567+ " // Create slices for this threadgroup (no real computation performed). \n "
568+ " auto tA = A.slice(0, (int)tgid.y); \n "
569+ " auto tB = B.slice((int)tgid.x, 0); \n "
570+ " \n "
571+ " // Minimal matmul descriptor: 8×8 tile with dynamic K dimension. \n "
572+ " matmul2d< \n "
573+ " matmul2d_descriptor(8, 8, dynamic_extent), \n "
574+ " execution_thread> mm; \n "
575+ " \n "
576+ " // Obtain a cooperative destination tensor of bfloat type. \n "
577+ " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), bfloat>(); \n "
578+ " \n "
579+ " // Silence “unused variable” warnings. \n "
580+ " (void)cT; \n "
581+ " }" ;
582+
583+ GGML_LOG_INFO (" %s : testing tensor API for bfloat support\n " , __func__);
584+ ggml_metal_library_t lib = ggml_metal_library_init_from_source (dev, src_tensor_bf16, false );
585+ if (lib == NULL ) {
586+ GGML_LOG_WARN (" %s : - the tensor API does not support bfloat - disabling bfloat support\n " , __func__);
587+ dev->props .has_bfloat = false ;
588+ } else {
589+ ggml_metal_library_free (lib);
590+ }
591+ }
592+
483593 dev->props .use_residency_sets = true ;
484594#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
485595 dev->props .use_residency_sets = getenv (" GGML_METAL_NO_RESIDENCY" ) == nil ;
0 commit comments