Skip to content

Commit ac4d564

Browse files
committed
metal : detect tensor support
1 parent 4d1783a commit ac4d564

File tree

4 files changed

+16
-12
lines changed

4 files changed

+16
-12
lines changed

ggml/src/ggml-metal/ggml-metal-context.m

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@
3535
// additional, inference-time compiled pipelines
3636
ggml_metal_pipelines_t pipelines_ext;
3737

38-
bool use_bfloat;
3938
bool use_fusion;
4039
bool use_concurrency;
4140
bool use_graph_optimize;
@@ -121,11 +120,10 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
121120
}
122121
}
123122

124-
const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
123+
//const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
125124

126125
res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
127126

128-
res->use_bfloat = props_dev->has_bfloat;
129127
res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
130128
res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
131129

@@ -147,7 +145,6 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
147145

148146
memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));
149147

150-
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, res->use_bfloat ? "true" : "false");
151148
GGML_LOG_INFO("%s: use fusion = %s\n", __func__, res->use_fusion ? "true" : "false");
152149
GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false");
153150
GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,7 @@ struct ggml_metal_device_props {
193193
bool has_simdgroup_mm;
194194
bool has_unified_memory;
195195
bool has_bfloat;
196+
bool has_tensor;
196197
bool use_residency_sets;
197198
bool use_shared_buffers;
198199

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,9 @@
2121
#define GGML_METAL_HAS_RESIDENCY_SETS 1
2222
#endif
2323

24-
// overload of MTLGPUFamilyMetal3 (not available in some environments)
24+
// overload of MTLGPUFamilyMetalX (not available in some environments)
2525
static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
26+
static const NSInteger MTLGPUFamilyMetal4_GGML = 5002;
2627

2728
// virtual address for GPU memory allocations
2829
static atomic_uintptr_t g_addr_device = 0x000000400ULL;
@@ -261,6 +262,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
261262
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
262263
}
263264

265+
if (ggml_metal_device_get_props(dev)->has_tensor) {
266+
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
267+
}
268+
264269
#if GGML_METAL_EMBED_LIBRARY
265270
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
266271
#endif
@@ -470,6 +475,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
470475
dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
471476
dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
472477

478+
dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
479+
473480
dev->props.use_residency_sets = true;
474481
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
475482
dev->props.use_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
@@ -529,6 +536,7 @@ ggml_metal_device_t ggml_metal_device_init(void) {
529536
GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false");
530537
GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false");
531538
GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false");
539+
GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false");
532540
GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false");
533541
GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false");
534542

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,7 @@ __embed_ggml-common.h__
99

1010
#include <metal_stdlib>
1111

12-
#define GGML_METAL_USE_METAL4
13-
14-
#ifdef GGML_METAL_USE_METAL4
12+
#ifdef GGML_METAL_HAS_TENSOR
1513
#include <metal_tensor>
1614

1715
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
@@ -8196,7 +8194,7 @@ kernel void kernel_mul_mm(
81968194
+ args.nb11*(r1 + lr1)
81978195
+ args.nb10*iy);
81988196

8199-
#ifndef GGML_METAL_USE_METAL4
8197+
#ifndef GGML_METAL_HAS_TENSOR
82008198
S0_8x8 ma[4];
82018199
S1_8x8 mb[2];
82028200

@@ -8217,7 +8215,7 @@ kernel void kernel_mul_mm(
82178215
#endif
82188216

82198217
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
8220-
#ifndef GGML_METAL_USE_METAL4
8218+
#ifndef GGML_METAL_HAS_TENSOR
82218219
// load data and store to threadgroup memory
82228220
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
82238221
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -8397,7 +8395,7 @@ kernel void kernel_mul_mm(
83978395

83988396
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
83998397
// if no bounds checks on the output are needed, we can directly write to device memory
8400-
#ifdef GGML_METAL_USE_METAL4
8398+
#ifdef GGML_METAL_HAS_TENSOR
84018399
device float * C = (device float *) dst +
84028400
r0 + \
84038401
r1 * args.ne0 + im*args.ne1*args.ne0;
@@ -8419,7 +8417,7 @@ kernel void kernel_mul_mm(
84198417

84208418
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
84218419

8422-
#ifdef GGML_METAL_USE_METAL4
8420+
#ifdef GGML_METAL_HAS_TENSOR
84238421
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
84248422
cT.store(tC);
84258423
#else

0 commit comments

Comments (0)