Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1361,7 +1361,6 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
Expand Down Expand Up @@ -1521,6 +1520,9 @@ @implementation GGMLMetalClass
NSString * key = [NSString stringWithUTF8String:name];
[ctx->kernels_ext setObject:obj forKey:key];

[metal_function release];
[obj release];

GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) kernel.pipeline,
(int) kernel.pipeline.maxTotalThreadsPerThreadgroup,
(int) kernel.pipeline.threadExecutionWidth);
Expand All @@ -1542,8 +1544,6 @@ @implementation GGMLMetalClass
char name[256];

@autoreleasepool {
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];

const int32_t dk = (int32_t) op->src[1]->ne[0];
const int32_t dv = (int32_t) op->src[2]->ne[0];

Expand Down Expand Up @@ -1575,7 +1575,7 @@ @implementation GGMLMetalClass
return res;
}

cv = [[MTLFunctionConstantValues alloc] init];
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];

[cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 0];
[cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 1];
Expand All @@ -1586,7 +1586,11 @@ @implementation GGMLMetalClass
[cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 21];
[cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 22];

return ggml_metal_compile_kernel(backend, base, name, cv);
res = ggml_metal_compile_kernel(backend, base, name, cv);

[cv release];

return res;
}
}

Expand All @@ -1604,8 +1608,6 @@ @implementation GGMLMetalClass
char name[256];

@autoreleasepool {
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];

const int32_t dk = (int32_t) op->src[1]->ne[0];
const int32_t dv = (int32_t) op->src[2]->ne[0];

Expand Down Expand Up @@ -1637,7 +1639,7 @@ @implementation GGMLMetalClass
return res;
}

cv = [[MTLFunctionConstantValues alloc] init];
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];

[cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 0];
[cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 1];
Expand All @@ -1649,7 +1651,11 @@ @implementation GGMLMetalClass
[cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 22];
[cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 23];

return ggml_metal_compile_kernel(backend, base, name, cv);
res = ggml_metal_compile_kernel(backend, base, name, cv);

[cv release];

return res;
}
}

Expand All @@ -1663,8 +1669,6 @@ @implementation GGMLMetalClass
char name[256];

@autoreleasepool {
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];

snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);

Expand All @@ -1674,12 +1678,16 @@ @implementation GGMLMetalClass
return res;
}

cv = [[MTLFunctionConstantValues alloc] init];
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];

[cv setConstantValue:&dv type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 0];
[cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 1];

return ggml_metal_compile_kernel(backend, base, name, cv);
res = ggml_metal_compile_kernel(backend, base, name, cv);

[cv release];

return res;
}

GGML_UNUSED(op);
Expand Down Expand Up @@ -5770,6 +5778,9 @@ static enum ggml_status ggml_metal_graph_compute(
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
[cmd_buf retain];

if (ctx->cmd_bufs[n_cb].obj) {
[ctx->cmd_bufs[n_cb].obj release];
}
ctx->cmd_bufs[n_cb].obj = cmd_buf;

[cmd_buf enqueue];
Expand Down
Loading