Skip to content

Commit 3c2b87d

Browse files
committed
fix more formatting and enhance readability
Signed-off-by: Junhee Yoo <[email protected]>
1 parent 746e79e commit 3c2b87d

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

ggml/src/ggml-metal.m

Lines changed: 21 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -2585,7 +2585,6 @@ static void ggml_metal_encode_node(
25852585
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
25862586

25872587
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
2588-
const uint64_t M = pipeline.maxTotalThreadsPerThreadgroup;
25892588

25902589
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
25912590

@@ -2606,27 +2605,30 @@ static void ggml_metal_encode_node(
26062605
};
26072606

26082607
[encoder setComputePipelineState:pipeline];
2609-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2610-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2611-
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
2612-
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
2613-
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
2614-
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
2615-
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
2616-
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
2617-
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
2618-
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
2619-
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
2620-
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
2621-
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
2608+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2609+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2610+
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
2611+
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
2612+
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
2613+
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
2614+
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
2615+
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
2616+
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
2617+
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
2618+
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
2619+
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
2620+
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
26222621

26232622
if (is_gt_mttpt) {
2624-
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
2625-
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
2626-
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
2623+
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
2624+
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
2625+
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
2626+
2627+
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
2628+
2629+
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
26272630

2628-
const int64_t D = N / M + (N % M > 0 ? 1 : 0);
2629-
[encoder dispatchThreadgroups:MTLSizeMake(D * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(MIN((uint64_t)N, M), 1, 1)];
2631+
[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
26302632
} else {
26312633
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
26322634
}

0 commit comments

Comments
 (0)