Skip to content

Commit f2927f4

Browse files
committed
metal : fix check for bfloat tensor support
1 parent 7984d57 commit f2927f4

File tree

2 files changed

+113
-1
lines changed

2 files changed

+113
-1
lines changed

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder);
9595

9696
typedef struct ggml_metal_library * ggml_metal_library_t;
9797

98-
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev);
98+
ggml_metal_library_t ggml_metal_library_init (ggml_metal_device_t dev);
99+
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose);
100+
99101
void ggml_metal_library_free(ggml_metal_library_t lib);
100102

101103
ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,72 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
303303
return res;
304304
}
305305

306+
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
307+
if (source == NULL) {
308+
GGML_LOG_ERROR("%s: source is NULL\n", __func__);
309+
return NULL;
310+
}
311+
312+
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
313+
id<MTLLibrary> library = nil;
314+
NSError * error = nil;
315+
316+
const int64_t t_start = ggml_time_us();
317+
318+
NSString * src = [[NSString alloc] initWithBytes:source
319+
length:strlen(source)
320+
encoding:NSUTF8StringEncoding];
321+
if (!src) {
322+
GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__);
323+
return NULL;
324+
}
325+
326+
@autoreleasepool {
327+
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
328+
329+
MTLCompileOptions * options = [MTLCompileOptions new];
330+
options.preprocessorMacros = prep;
331+
332+
library = [device newLibraryWithSource:src options:options error:&error];
333+
if (error) {
334+
if (verbose) {
335+
GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]);
336+
} else {
337+
GGML_LOG_ERROR("%s: error compiling source\n", __func__);
338+
}
339+
library = nil;
340+
}
341+
342+
[options release];
343+
}
344+
345+
[src release];
346+
347+
if (!library) {
348+
if (verbose) {
349+
GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__);
350+
}
351+
352+
return NULL;
353+
}
354+
355+
if (verbose) {
356+
GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
357+
}
358+
359+
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
360+
if (!res) {
361+
GGML_LOG_ERROR("%s: calloc failed\n", __func__);
362+
return NULL;
363+
}
364+
365+
res->obj = library;
366+
res->device = device;
367+
res->pipelines = ggml_metal_pipelines_init();
368+
369+
return res;
370+
}
371+
306372
void ggml_metal_library_free(ggml_metal_library_t lib) {
307373
if (!lib) {
308374
return;
@@ -474,12 +540,56 @@ ggml_metal_device_t ggml_metal_device_init(void) {
474540

475541
dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
476542
dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
543+
if (getenv("GGML_METAL_BF16_DISABLE") != NULL) {
544+
dev->props.has_bfloat = false;
545+
}
477546

478547
dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
479548
if (getenv("GGML_METAL_TENSOR_DISABLE") != NULL) {
480549
dev->props.has_tensor = false;
481550
}
482551

552+
// try to compile a dummy tensor kernel to determine if the tensor API is supported for bfloat
553+
if (dev->props.has_tensor && dev->props.has_bfloat) {
554+
const char * src_tensor_bf16 = "\n"
555+
"#include <metal_stdlib> \n"
556+
"#include <metal_tensor> \n"
557+
"#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
558+
" \n"
559+
"using namespace metal; \n"
560+
"using namespace mpp::tensor_ops; \n"
561+
" \n"
562+
"kernel void bfloat_dummy_kernel( \n"
563+
" tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
564+
" tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
565+
" uint2 tgid [[threadgroup_position_in_grid]]) \n"
566+
"{ \n"
567+
" // Create slices for this threadgroup (no real computation performed). \n"
568+
" auto tA = A.slice(0, (int)tgid.y); \n"
569+
" auto tB = B.slice((int)tgid.x, 0); \n"
570+
" \n"
571+
" // Minimal matmul descriptor: 8×8 tile with dynamic K dimension. \n"
572+
" matmul2d< \n"
573+
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
574+
" execution_thread> mm; \n"
575+
" \n"
576+
" // Obtain a cooperative destination tensor of bfloat type. \n"
577+
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), bfloat>(); \n"
578+
" \n"
579+
" // Silence “unused variable” warnings. \n"
580+
" (void)cT; \n"
581+
"}";
582+
583+
GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__);
584+
ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_bf16, false);
585+
if (lib == NULL) {
586+
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
587+
dev->props.has_bfloat = false;
588+
} else {
589+
ggml_metal_library_free(lib);
590+
}
591+
}
592+
483593
dev->props.use_residency_sets = true;
484594
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
485595
dev->props.use_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;

0 commit comments

Comments
 (0)