@@ -2585,7 +2585,6 @@ static void ggml_metal_encode_node(
25852585 const int32_t ofs1 = src1->nb [is_2D ? 2 : 1 ] / 4 ;
25862586
25872587 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline ;
2588- const uint64_t M = pipeline.maxTotalThreadsPerThreadgroup ;
25892588
25902589 const bool is_gt_mttpt = ((size_t )(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup ;
25912590
@@ -2606,27 +2605,30 @@ static void ggml_metal_encode_node(
26062605 };
26072606
26082607 [encoder setComputePipelineState: pipeline];
2609- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
2610- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2611- [encoder setBytes: &ofs0 length: sizeof ( int32_t ) atIndex: 2 ];
2612- [encoder setBytes: &ofs1 length: sizeof ( int32_t ) atIndex: 3 ];
2613- [encoder setBytes: &IW length: sizeof ( int32_t ) atIndex: 4 ];
2614- [encoder setBytes: &IH length: sizeof ( int32_t ) atIndex: 5 ];
2615- [encoder setBytes: &CHW length: sizeof ( int32_t ) atIndex: 6 ];
2616- [encoder setBytes: &s0 length: sizeof ( int32_t ) atIndex: 7 ];
2617- [encoder setBytes: &s1 length: sizeof ( int32_t ) atIndex: 8 ];
2618- [encoder setBytes: &p0 length: sizeof ( int32_t ) atIndex: 9 ];
2619- [encoder setBytes: &p1 length: sizeof ( int32_t ) atIndex: 10 ];
2620- [encoder setBytes: &d0 length: sizeof ( int32_t ) atIndex: 11 ];
2621- [encoder setBytes: &d1 length: sizeof ( int32_t ) atIndex: 12 ];
2608+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
2609+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2610+ [encoder setBytes: &ofs0 length: sizeof (int32_t ) atIndex: 2 ];
2611+ [encoder setBytes: &ofs1 length: sizeof (int32_t ) atIndex: 3 ];
2612+ [encoder setBytes: &IW length: sizeof (int32_t ) atIndex: 4 ];
2613+ [encoder setBytes: &IH length: sizeof (int32_t ) atIndex: 5 ];
2614+ [encoder setBytes: &CHW length: sizeof (int32_t ) atIndex: 6 ];
2615+ [encoder setBytes: &s0 length: sizeof (int32_t ) atIndex: 7 ];
2616+ [encoder setBytes: &s1 length: sizeof (int32_t ) atIndex: 8 ];
2617+ [encoder setBytes: &p0 length: sizeof (int32_t ) atIndex: 9 ];
2618+ [encoder setBytes: &p1 length: sizeof (int32_t ) atIndex: 10 ];
2619+ [encoder setBytes: &d0 length: sizeof (int32_t ) atIndex: 11 ];
2620+ [encoder setBytes: &d1 length: sizeof (int32_t ) atIndex: 12 ];
26222621
26232622 if (is_gt_mttpt) {
2624- [encoder setBytes: &N length: sizeof (int32_t ) atIndex: 13 ];
2625- [encoder setBytes: &KH length: sizeof (int32_t ) atIndex: 14 ];
2626- [encoder setBytes: &KW length: sizeof (int32_t ) atIndex: 15 ];
2623+ [encoder setBytes: &N length: sizeof (int32_t ) atIndex: 13 ];
2624+ [encoder setBytes: &KH length: sizeof (int32_t ) atIndex: 14 ];
2625+ [encoder setBytes: &KW length: sizeof (int32_t ) atIndex: 15 ];
2626+
2627+ const uint64_t n_threads = MIN (pipeline.maxTotalThreadsPerThreadgroup , (uint64_t )N);
2628+
2629+ const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0 );
26272630
2628- const int64_t D = N / M + (N % M > 0 ? 1 : 0 );
2629- [encoder dispatchThreadgroups: MTLSizeMake (D * CHW, OH, OW) threadsPerThreadgroup: MTLSizeMake (MIN ((uint64_t )N, M), 1 , 1 )];
2631+ [encoder dispatchThreadgroups: MTLSizeMake (quotient * CHW, OH, OW) threadsPerThreadgroup: MTLSizeMake (n_threads, 1 , 1 )];
26302632 } else {
26312633 [encoder dispatchThreadgroups: MTLSizeMake (IC, OH, OW) threadsPerThreadgroup: MTLSizeMake (N, KH, KW)];
26322634 }
0 commit comments