Skip to content

Commit 7b6f662

Browse files
committed
add more tests, add optimization to threading
1 parent 2f1ed3c commit 7b6f662

File tree

3 files changed

+39
-21
lines changed

3 files changed

+39
-21
lines changed

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3094,6 +3094,8 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
30943094
const int32_t KH = op->src[0]->ne[1];
30953095
const int32_t KW = op->src[0]->ne[0];
30963096

3097+
GGML_ASSERT(KW * KH <= 64 && "conv_transpose_2d kernel size exceeds threadgroup memory limit");
3098+
30973099
const int32_t OW = op->ne[0];
30983100
const int32_t OH = op->ne[1];
30993101
const int32_t OC = op->ne[2];
@@ -3119,7 +3121,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
31193121
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
31203122
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
31213123

3122-
ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, 1, 1, 1);
3124+
ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
31233125

31243126
return 1;
31253127
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 34 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -4147,45 +4147,57 @@ kernel void kernel_conv_transpose_2d(
41474147
device const float * src1,
41484148
device char * dst,
41494149
uint3 tgpig[[threadgroup_position_in_grid]],
4150-
uint3 tgpg[[threadgroups_per_grid]]) {
4150+
uint3 tpitg[[thread_position_in_threadgroup]],
4151+
uint3 ntg[[threads_per_threadgroup]]) {
41514152

41524153
const int64_t out_x = tgpig[0];
41534154
const int64_t out_y = tgpig[1];
41544155
const int64_t out_c = tgpig[2];
41554156

4157+
const int64_t kw = tpitg[0];
4158+
const int64_t kh = tpitg[1];
4159+
41564160
float v = 0.0f;
41574161

41584162
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
4159-
for (int64_t kh = 0; kh < args.KH; kh++) {
4163+
int64_t in_y = out_y - kh;
4164+
4165+
if (in_y < 0 || in_y % args.s0) continue;
41604166

4161-
int64_t in_y = out_y - kh;
4167+
in_y /= args.s0;
41624168

4163-
if (in_y < 0 || in_y % args.s0) continue;
4169+
if (in_y >= args.IH) continue;
41644170

4165-
in_y /= args.s0;
4171+
int64_t in_x = out_x - kw;
41664172

4167-
if (in_y >= args.IH) continue;
4173+
if (in_x < 0 || in_x % args.s0) continue;
41684174

4169-
for (int64_t kw = 0; kw < args.KW; kw++) {
4170-
int64_t in_x = out_x - kw;
4175+
in_x /= args.s0;
41714176

4172-
if (in_x < 0 || in_x % args.s0) continue;
4177+
if (in_x >= args.IW) continue;
41734178

4174-
in_x /= args.s0;
4179+
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
4180+
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
41754181

4176-
if (in_x >= args.IW) continue;
4182+
v += (float)src0[kernel_idx] * src1[input_idx];
4183+
}
41774184

4178-
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
4179-
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
4185+
threadgroup float shared_sum[64];
4186+
const uint tid = tpitg.y * ntg.x + tpitg.x;
4187+
shared_sum[tid] = v;
41804188

4181-
v += (float)src0[kernel_idx] * src1[input_idx];
4189+
threadgroup_barrier(mem_flags::mem_threadgroup);
41824190

4183-
}
4191+
if (tid == 0) {
4192+
float total = 0.0f;
4193+
const uint num_threads = ntg.x * ntg.y;
4194+
for (uint i = 0; i < num_threads; i++) {
4195+
total += shared_sum[i];
41844196
}
4185-
}
4186-
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
41874197

4188-
dst_ptr[0] = v;
4198+
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
4199+
dst_ptr[0] = total;
4200+
}
41894201
}
41904202

41914203
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
@@ -4195,7 +4207,8 @@ kernel void kernel_conv_transpose_2d<float>(
41954207
device const float * src1,
41964208
device char * dst,
41974209
uint3 tgpig[[threadgroup_position_in_grid]],
4198-
uint3 tgpg[[threadgroups_per_grid]]);
4210+
uint3 tpitg[[thread_position_in_threadgroup]],
4211+
uint3 ntg[[threads_per_threadgroup]]);
41994212

42004213
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
42014214
kernel void kernel_conv_transpose_2d<half>(
@@ -4204,7 +4217,8 @@ kernel void kernel_conv_transpose_2d<half>(
42044217
device const float * src1,
42054218
device char * dst,
42064219
uint3 tgpig[[threadgroup_position_in_grid]],
4207-
uint3 tgpg[[threadgroups_per_grid]]);
4220+
uint3 tpitg[[thread_position_in_threadgroup]],
4221+
uint3 ntg[[threads_per_threadgroup]]);
42084222

42094223
kernel void kernel_upscale_f32(
42104224
constant ggml_metal_kargs_upscale & args,

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6952,6 +6952,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
69526952
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
69536953

69546954
test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
6955+
test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
6956+
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
69556957

69566958
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
69576959

0 commit comments

Comments
 (0)