Skip to content

Commit adbec7f

Browse files
committed
fix im2col and add unittest for N>=1024
Signed-off-by: Junhee Yoo <[email protected]>
1 parent b4d3c16 commit adbec7f

File tree

3 files changed

+115
-6
lines changed

3 files changed

+115
-6
lines changed

ggml/src/ggml-metal.m

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
241241
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
242242
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
243243
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
244+
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
245+
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
244246
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
245247
GGML_METAL_KERNEL_TYPE_PAD_F32,
246248
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -687,6 +689,8 @@ @implementation GGMLMetalClass
687689
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
688690
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
689691
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
692+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
693+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
690694
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
691695
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
692696
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
@@ -2579,11 +2583,24 @@ static void ggml_metal_encode_node(
25792583
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
25802584
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
25812585

2582-
id<MTLComputePipelineState> pipeline = nil;
2586+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
2587+
const uint64_t M = pipeline.maxTotalThreadsPerThreadgroup;
2588+
2589+
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
25832590

25842591
switch (dst->type) {
2585-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
2586-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
2592+
case GGML_TYPE_F32: {
2593+
pipeline = (is_gt_mttpt ?
2594+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
2595+
:
2596+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
2597+
} break;
2598+
case GGML_TYPE_F16: {
2599+
pipeline = (is_gt_mttpt ?
2600+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
2601+
:
2602+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
2603+
} break;
25872604
default: GGML_ABORT("fatal error");
25882605
};
25892606

@@ -2602,7 +2619,16 @@ static void ggml_metal_encode_node(
26022619
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
26032620
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
26042621

2605-
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2622+
if (is_gt_mttpt) {
2623+
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
2624+
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
2625+
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
2626+
2627+
const int64_t D = N / M + (N % M > 0 ? 1 : 0);
2628+
[encoder dispatchThreadgroups:MTLSizeMake(D * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(M, 1, 1)];
2629+
} else {
2630+
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2631+
}
26062632
} break;
26072633
case GGML_OP_UPSCALE:
26082634
{

ggml/src/ggml-metal.metal

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,6 +1933,81 @@ kernel void kernel_im2col(
19331933
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
19341934
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
19351935

1936+
// Function type of the "extended" im2col kernel. kernel_im2col_ext<float> and
// kernel_im2col_ext<half> are explicitly instantiated with this signature.
// Compared with the base im2col kernel it takes N, KH and KW as explicit
// arguments; the host binds them at argument indices 13, 14 and 15 and only
// selects this kernel when N * KH * KW exceeds the pipeline's
// maxTotalThreadsPerThreadgroup.
typedef void (im2col_ext_t)(
1937+
device const float * x,
1938+
device char * dst,
1939+
constant int32_t & ofs0,
1940+
constant int32_t & ofs1,
1941+
constant int32_t & IW,
1942+
constant int32_t & IH,
1943+
constant int32_t & CHW,
1944+
constant int32_t & s0,
1945+
constant int32_t & s1,
1946+
constant int32_t & p0,
1947+
constant int32_t & p1,
1948+
constant int32_t & d0,
1949+
constant int32_t & d1,
1950+
constant int32_t & N,
1951+
constant int32_t & KH,
1952+
constant int32_t & KW,
1953+
uint3 tgpig[[threadgroup_position_in_grid]],
1954+
uint3 tgpg[[threadgroups_per_grid]],
1955+
uint3 tpitg[[thread_position_in_threadgroup]],
1956+
uint3 ntg[[threads_per_threadgroup]]);
1957+
1958+
// Extended im2col kernel, used by the host when N * KH * KW does not fit into a
// single threadgroup (i.e. exceeds maxTotalThreadsPerThreadgroup).
//
// Dispatch layout set up by the host:
//   threadgroups per grid   = [D * CHW, OH, OW]  with D = ceil(N / M), CHW = IC * KH * KW
//   threads per threadgroup = [M, 1, 1]
//
// Each thread converts exactly one source element (float) to T and stores it in dst:
//   - tgpig[0] encodes (d, input channel, kernel row, kernel column)
//   - d * M + tpitg[0] is the batch index n
//   - tgpig[1] / tgpig[2] are the output row / column (oh, ow)
//
// x    : source image data, read as float
// dst  : destination buffer, reinterpreted as an array of T
// ofs0 : element stride between consecutive batch entries of x
// ofs1 : element stride between consecutive input channels of x
// s*, p*, d* : stride, padding and dilation along width (0) and height (1)
template <typename T>
kernel void kernel_im2col_ext(
        device const float * x,
        device        char * dst,
        constant   int32_t & ofs0,
        constant   int32_t & ofs1,
        constant   int32_t & IW,
        constant   int32_t & IH,
        constant   int32_t & CHW,
        constant   int32_t & s0,
        constant   int32_t & s1,
        constant   int32_t & p0,
        constant   int32_t & p1,
        constant   int32_t & d0,
        constant   int32_t & d1,
        constant   int32_t & N,
        constant   int32_t & KH,
        constant   int32_t & KW,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
    // Kernel window size. Unlike the base im2col kernel, KH and KW cannot be
    // derived from ntg here (ntg is [M, 1, 1]), so they are passed explicitly.
    const int32_t KHW = KH * KW;

    // Decompose the flattened x-coordinate of the threadgroup.
    const int32_t d       = tgpig[0] / CHW;  // which slice of M batch entries
    const int32_t chw     = tgpig[0] % CHW;  // position inside IC x KH x KW
    const int32_t tgpig_0 = chw / KHW;       // input channel, 0 ~ (IC - 1)
    const int32_t HW      = chw % KHW;       // position inside KH x KW

    // Batch index handled by this thread.
    const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];

    // The host rounds the number of batch slices up (D = ceil(N / M)), so the
    // last slice may contain threads past the end of the batch. They must not
    // touch memory: both the source and destination offsets would be out of range.
    if (tpitg_0 >= N) {
        return;
    }

    const int32_t tpitg_1 = HW / KW;  // kernel row
    const int32_t tpitg_2 = HW % KW;  // kernel column

    // Source coordinates; may fall into the padding region.
    const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
    const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;

    // dst is laid out as [N, OH, OW, IC * KH * KW]. The offset is computed in
    // 64 bits so that large outputs do not wrap around.
    const int64_t offset_dst =
        ((int64_t) tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
        (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);

    device T * pdst = (device T *) (dst);

    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
        // Outside the input image: zero padding.
        pdst[offset_dst] = 0.0f;
    } else {
        const int64_t offset_src = (int64_t) tpitg_0 * ofs0 + (int64_t) tgpig_0 * ofs1;
        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
    }
}
2007+
2008+
// Explicit instantiations of kernel_im2col_ext for float and half destination
// element types, exported under the host-visible names kernel_im2col_ext_f32
// and kernel_im2col_ext_f16.
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
2009+
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
2010+
19362011
kernel void kernel_upscale_f32(
19372012
device const char * src0,
19382013
device char * dst,

tests/test-backend-ops.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3316,12 +3316,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
33163316
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
33173317
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
33183318

3319+
// test cases for 2D im2col
3320+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
3321+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
3322+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
3323+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
3324+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));
3325+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
3326+
33193327
// sycl backend will limit task global_range < MAX_INT
33203328
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
33213329
// however, these cases need to allocate more memory, which may fail on some devices (Intel Arc770, etc.)
33223330
// these cases have been verified to pass on Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend)
3323-
// test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
3324-
// test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
3331+
// test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
3332+
// test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
33253333

33263334
test_cases.emplace_back(new test_conv_transpose_1d());
33273335
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));

0 commit comments

Comments
 (0)