Skip to content

Commit 9f3e11c

Browse files
committed
add dynamic memory allocation in metal
1 parent 7b6f662 commit 9f3e11c

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3094,8 +3094,6 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
30943094
const int32_t KH = op->src[0]->ne[1];
30953095
const int32_t KW = op->src[0]->ne[0];
30963096

3097-
GGML_ASSERT(KW * KH <= 64 && "conv_transpose_2d kernel size exceeds threadgroup memory limit");
3098-
30993097
const int32_t OW = op->ne[0];
31003098
const int32_t OH = op->ne[1];
31013099
const int32_t OC = op->ne[2];
@@ -3121,6 +3119,10 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
31213119
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
31223120
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
31233121

3122+
// Metal requires buffer size to be multiple of 16 bytes
3123+
const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
3124+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3125+
31243126
ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
31253127

31263128
return 1;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4146,6 +4146,7 @@ kernel void kernel_conv_transpose_2d(
41464146
device const T * src0,
41474147
device const float * src1,
41484148
device char * dst,
4149+
threadgroup float * shared_sum [[threadgroup(0)]],
41494150
uint3 tgpig[[threadgroup_position_in_grid]],
41504151
uint3 tpitg[[thread_position_in_threadgroup]],
41514152
uint3 ntg[[threads_per_threadgroup]]) {
@@ -4182,7 +4183,6 @@ kernel void kernel_conv_transpose_2d(
41824183
v += (float)src0[kernel_idx] * src1[input_idx];
41834184
}
41844185

4185-
threadgroup float shared_sum[64];
41864186
const uint tid = tpitg.y * ntg.x + tpitg.x;
41874187
shared_sum[tid] = v;
41884188

@@ -4206,6 +4206,7 @@ kernel void kernel_conv_transpose_2d<float>(
42064206
device const float * src0,
42074207
device const float * src1,
42084208
device char * dst,
4209+
threadgroup float * shared_sum [[threadgroup(0)]],
42094210
uint3 tgpig[[threadgroup_position_in_grid]],
42104211
uint3 tpitg[[thread_position_in_threadgroup]],
42114212
uint3 ntg[[threads_per_threadgroup]]);
@@ -4216,6 +4217,7 @@ kernel void kernel_conv_transpose_2d<half>(
42164217
device const half * src0,
42174218
device const float * src1,
42184219
device char * dst,
4220+
threadgroup float * shared_sum [[threadgroup(0)]],
42194221
uint3 tgpig[[threadgroup_position_in_grid]],
42204222
uint3 tpitg[[thread_position_in_threadgroup]],
42214223
uint3 ntg[[threads_per_threadgroup]]);

0 commit comments

Comments
 (0)