Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 111 additions & 19 deletions ggml/src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
Expand Down Expand Up @@ -272,6 +274,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,

GGML_METAL_KERNEL_TYPE_COUNT
};
Expand Down Expand Up @@ -685,6 +689,8 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
Expand Down Expand Up @@ -716,6 +722,8 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
}

[metal_library release];
Expand Down Expand Up @@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_1D:
case GGML_OP_POOL_2D:
return false;
case GGML_OP_POOL_2D:
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARANGE:
Expand Down Expand Up @@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
} break;
case GGML_OP_IM2COL:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
Expand Down Expand Up @@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;

id<MTLComputePipelineState> pipeline = nil;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;

const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to keep just the _ext variant of the kernel? Does the old kernel have a significant advantage?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary: To keep backward compatibility and performance in smaller computations, I suggest you keep the old kernel.

details

I’ve done some investigations after you gave me this comment.
The new kernel (a.k.a. _ext) shows a performance degradation for small values of N:

$ test-backend-ops perf -o IM2COL -b Metal
NOTE: 9 x 113 is 1017, 9 x 114 is 1026 which exceeds limits.

original + _ext, M
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):               229348 runs -     4.38 us/run -       99 kB/run -    0.77 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):               196584 runs -     5.24 us/run -      199 kB/run -    1.51 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):             147438 runs -     6.82 us/run -      351 kB/run -    2.73 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):             106483 runs -     9.67 us/run -      703 kB/run -    5.34 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              57337 runs -    18.79 us/run -      354 kB/run -    2.57 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              57337 runs -    18.74 us/run -      354 kB/run -    2.58 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              49146 runs -    21.43 us/run -      622 kB/run -    4.61 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              32764 runs -    35.12 us/run -     1244 kB/run -    8.45 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    16382 runs -    74.60 us/run -     3186 kB/run -   20.37 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    10532 runs -   144.71 us/run -     6372 kB/run -   21.00 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    10532 runs -   143.44 us/run -     6372 kB/run -   21.18 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     5266 runs -   283.55 us/run -    12744 kB/run -   21.43 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     8426 runs -   175.97 us/run -     7965 kB/run -   21.58 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     4214 runs -   352.93 us/run -    15930 kB/run -   21.52 GB/s

only _ext, M - worst case.
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                73719 runs -    14.48 us/run -       99 kB/run -    0.73 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                49146 runs -    24.31 us/run -      199 kB/run -    1.30 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              57337 runs -    17.90 us/run -      351 kB/run -    2.68 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              40955 runs -    28.89 us/run -      703 kB/run -    4.64 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              57337 runs -    18.54 us/run -      354 kB/run -    2.61 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              57337 runs -    18.06 us/run -      354 kB/run -    2.68 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              57337 runs -    20.35 us/run -      622 kB/run -    4.17 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              32764 runs -    34.79 us/run -     1244 kB/run -    8.53 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    16382 runs -    74.19 us/run -     3186 kB/run -   20.48 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    10532 runs -   144.72 us/run -     6372 kB/run -   21.00 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    10532 runs -   142.86 us/run -     6372 kB/run -   21.27 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     5266 runs -   283.70 us/run -    12744 kB/run -   21.42 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     8426 runs -   175.84 us/run -     7965 kB/run -   21.60 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     4214 runs -   350.67 us/run -    15930 kB/run -   21.66 GB/s

only_ext, MIN(N, M)
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):               131056 runs -     7.78 us/run -       99 kB/run -    0.76 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                90101 runs -    11.24 us/run -      199 kB/run -    1.54 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              90101 runs -    11.58 us/run -      351 kB/run -    2.63 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              57337 runs -    17.85 us/run -      703 kB/run -    5.37 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              90101 runs -    11.80 us/run -      354 kB/run -    2.61 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              90101 runs -    11.72 us/run -      354 kB/run -    2.62 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              65528 runs -    16.90 us/run -      622 kB/run -    4.39 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              40955 runs -    28.84 us/run -     1244 kB/run -    8.23 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    16382 runs -    74.75 us/run -     3186 kB/run -   20.32 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    10532 runs -   144.31 us/run -     6372 kB/run -   21.06 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    10532 runs -   143.02 us/run -     6372 kB/run -   21.25 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     5266 runs -   286.72 us/run -    12744 kB/run -   21.20 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     8426 runs -   176.62 us/run -     7965 kB/run -   21.51 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     4214 runs -   351.21 us/run -    15930 kB/run -   21.63 GB/s

original + _ext, MIN(N, M) - best case IMO.
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):               221157 runs -     4.53 us/run -       99 kB/run -    0.78 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):               196584 runs -     5.14 us/run -      199 kB/run -    1.54 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):             147438 runs -     6.80 us/run -      351 kB/run -    2.74 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):             114674 runs -     9.23 us/run -      703 kB/run -    5.19 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              81910 runs -    12.44 us/run -      354 kB/run -    2.72 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              90101 runs -    11.55 us/run -      354 kB/run -    2.66 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              57337 runs -    17.50 us/run -      622 kB/run -    4.85 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):              40955 runs -    28.67 us/run -     1244 kB/run -    8.28 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    16382 runs -    74.53 us/run -     3186 kB/run -   20.38 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    10532 runs -   144.56 us/run -     6372 kB/run -   21.02 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                    10532 runs -   143.35 us/run -     6372 kB/run -   21.20 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     5266 runs -   283.69 us/run -    12744 kB/run -   21.42 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     8426 runs -   176.14 us/run -     7965 kB/run -   21.56 GB/s
  IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1):                     4214 runs -   351.41 us/run -    15930 kB/run -   21.62 GB/s

I guess this comes from wasted threads in the _ext kernel series (in the case of tpitg_0 >= N) when N is smaller than M (maxTotalThreadsPerThreadgroup), and it gets worse as the M - N gap grows.
Apple's official documentation describes a newer API that spreads threads to fit the given grid exactly (non-uniform threadgroup sizes). However, it is only supported from Metal 3 / Apple 4 onward (see the Metal feature set table, keyword "nonuniform"), so I didn't pursue it because it would narrow the set of devices llama.cpp can run on.

I applied original + _ext, MIN(N, M) to this PR which is the best from those investigations. Thanks for your review. I could do more optimization from your comment. :-^)


switch (dst->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
case GGML_TYPE_F32: {
pipeline = (is_gt_mttpt ?
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
:
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
} break;
case GGML_TYPE_F16: {
pipeline = (is_gt_mttpt ?
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
:
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
} break;
default: GGML_ABORT("fatal error");
};

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];

[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];

if (is_gt_mttpt) {
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];

const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);

const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);

[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} else {
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
}
} break;
case GGML_OP_UPSCALE:
{
Expand Down Expand Up @@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_POOL_2D:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);

const int32_t * opts = dst->op_params;
enum ggml_op_pool op = opts[0];

id<MTLComputePipelineState> pipeline = nil;
switch (src0t) {
case GGML_TYPE_F32: {
switch(op) {
case GGML_OP_POOL_AVG:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
case GGML_OP_POOL_MAX:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
}
} break;
default: GGML_ASSERT(false && "not implemented");
}

const int32_t k0 = opts[1];
const int32_t k1 = opts[2];
const int32_t s0 = opts[3];
const int32_t s1 = opts[4];
const int32_t p0 = opts[5];
const int32_t p1 = opts[6];

const int64_t IH = src0->ne[1];
const int64_t IW = src0->ne[0];

const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];

const int64_t parallel_elements = N * OC * OH * OW;
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
[encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];

[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
Expand Down
178 changes: 178 additions & 0 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1933,6 +1933,85 @@ kernel void kernel_im2col(
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;

// Signature shared by the f16/f32 instantiations of kernel_im2col_ext below;
// referenced by the [[host_name("kernel_im2col_ext_*")]] template declarations.
// Compared to the plain im2col kernel it additionally receives N, KH and KW so
// each thread can recover its logical (n, kh, kw) position from a flat dispatch.
typedef void (im2col_ext_t)(
        device const float * x,
        device char * dst,
        constant int32_t & ofs0,
        constant int32_t & ofs1,
        constant int32_t & IW,
        constant int32_t & IH,
        constant int32_t & CHW,
        constant int32_t & s0,
        constant int32_t & s1,
        constant int32_t & p0,
        constant int32_t & p1,
        constant int32_t & d0,
        constant int32_t & d1,
        constant int32_t & N,
        constant int32_t & KH,
        constant int32_t & KW,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]);

// "Extended" im2col kernel, selected by the host when N * KH * KW exceeds the
// pipeline's maxTotalThreadsPerThreadgroup (see is_gt_mttpt on the host side).
// The host dispatches threadgroups of shape [M, 1, 1] over a grid of
// (ceil(N / M) * CHW, OH, OW), so each thread handles a single (n, kh, kw)
// element of one output position instead of relying on a 3D threadgroup of
// shape [N, KH, KW].
template <typename T>
kernel void kernel_im2col_ext(
        device const float * x,
        device char * dst,
        constant int32_t & ofs0,
        constant int32_t & ofs1,
        constant int32_t & IW,
        constant int32_t & IH,
        constant int32_t & CHW,
        constant int32_t & s0,
        constant int32_t & s1,
        constant int32_t & p0,
        constant int32_t & p1,
        constant int32_t & d0,
        constant int32_t & d1,
        constant int32_t & N,
        constant int32_t & KH,
        constant int32_t & KW,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
    const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]

    // tgpig[0] enumerates D x CHW threadgroups: 'd' selects which chunk of the
    // batch dimension this group covers, 'chw' the flattened (IC, KH, KW) index.
    const int32_t d = tgpig[0] / CHW;
    const int32_t chw = tgpig[0] % CHW;
    const int32_t tgpig_0 = chw / KHW; // channel index: 0 ~ (IC - 1)
    const int32_t HW = tgpig[0] % KHW; // flattened (KH, KW) index (CHW is a multiple of KHW, so chw % KHW gives the same value)

    // global batch index of this thread; trailing threads past N have no work
    // because N is not necessarily a multiple of ntg[0]
    const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
    if (tpitg_0 >= N) {
        return;
    }

    const int32_t tpitg_1 = HW / KW; // kernel row    (0 ~ KH-1)
    const int32_t tpitg_2 = HW % KW; // kernel column (0 ~ KW-1)

    // input coordinates for this (output position, kernel tap), accounting for
    // stride (s0/s1), dilation (d0/d1) and padding (p0/p1)
    const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
    const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;

    // dst is laid out as [N, OH, OW, CHW]; the first parenthesized term is the
    // flat (n, oh, ow) index, the second the (ic, kh, kw) offset within CHW
    const int32_t offset_dst =
        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
        (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);

    device T * pdst = (device T *) (dst);

    // taps that fall into the padding region contribute zero
    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
        pdst[offset_dst] = 0.0f;
    } else {
        // ofs0/ofs1 are the (batch, channel) strides of the source, in floats
        const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
    }
}

template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;

kernel void kernel_upscale_f32(
device const char * src0,
device char * dst,
Expand Down Expand Up @@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;

// 2D max pooling over an f32 tensor: one thread per output element, dispatched
// as a flat 1D grid of parallel_elements threads (padded up to whole
// threadgroups on the host side).
kernel void kernel_pool_2d_max_f32(
        device const float * src0,
        device float * dst,
        constant int32_t & k0,
        constant int32_t & k1,
        constant int32_t & s0,
        constant int32_t & s1,
        constant int32_t & p0,
        constant int32_t & p1,
        constant int64_t & IH,
        constant int64_t & IW,
        constant int64_t & OH,
        constant int64_t & OW,
        constant int64_t & parallel_elements,
        uint gid[[thread_position_in_grid]]) {
    // surplus threads from threadgroup rounding exit immediately
    if (gid >= parallel_elements) {
        return;
    }

    const int in_plane  = IH * IW; // elements per input  (n, c) plane
    const int out_plane = OH * OW; // elements per output (n, c) plane

    // decompose the flat thread id into (plane, output row, output column)
    const int plane = gid / out_plane;
    const int oy    = gid % out_plane / OW;
    const int ox    = gid % out_plane % OW;

    device const float * src_plane = src0 + plane * in_plane;
    device float       * dst_plane = dst  + plane * out_plane;

    // pooling window in input coordinates, clamped to the input bounds
    const int wy = oy * s1 - p1;
    const int y_beg = MAX(0,  wy);
    const int y_end = MIN(IH, wy + k1);
    const int wx = ox * s0 - p0;
    const int x_beg = MAX(0,  wx);
    const int x_end = MIN(IW, wx + k0);

    float best = -INFINITY;

    for (int iy = y_beg; iy < y_end; iy += 1) {
        for (int ix = x_beg; ix < x_end; ix += 1) {
            best = MAX(best, src_plane[iy * IW + ix]);
        }
    }

    dst_plane[oy * OW + ox] = best;
}

// 2D average pooling over an f32 tensor: one thread per output element,
// dispatched as a flat 1D grid of parallel_elements threads (padded up to
// whole threadgroups on the host side).
kernel void kernel_pool_2d_avg_f32(
        device const float * src0,
        device float * dst,
        constant int32_t & k0,
        constant int32_t & k1,
        constant int32_t & s0,
        constant int32_t & s1,
        constant int32_t & p0,
        constant int32_t & p1,
        constant int64_t & IH,
        constant int64_t & IW,
        constant int64_t & OH,
        constant int64_t & OW,
        constant int64_t & parallel_elements,
        uint gid[[thread_position_in_grid]]) {
    // surplus threads from threadgroup rounding exit immediately
    if (gid >= parallel_elements) {
        return;
    }

    const int in_plane  = IH * IW; // elements per input  (n, c) plane
    const int out_plane = OH * OW; // elements per output (n, c) plane

    // decompose the flat thread id into (plane, output row, output column)
    const int plane = gid / out_plane;
    const int oy    = gid % out_plane / OW;
    const int ox    = gid % out_plane % OW;

    device const float * src_plane = src0 + plane * in_plane;
    device float       * dst_plane = dst  + plane * out_plane;

    // pooling window in input coordinates, clamped to the input bounds
    const int wy = oy * s1 - p1;
    const int y_beg = MAX(0,  wy);
    const int y_end = MIN(IH, wy + k1);
    const int wx = ox * s0 - p0;
    const int x_beg = MAX(0,  wx);
    const int x_end = MIN(IW, wx + k0);

    // divides by the full k0*k1 window, i.e. padded/out-of-bounds positions
    // still count toward the average (count_include_pad semantics)
    const float inv_area = 1. / (k0 * k1);

    float acc = 0;

    for (int iy = y_beg; iy < y_end; iy += 1) {
        for (int ix = x_beg; ix < x_end; ix += 1) {
            float v = src_plane[iy * IW + ix];
            acc += v * inv_area;
        }
    }

    dst_plane[oy * OW + ox] = acc;
}
10 changes: 10 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3316,6 +3316,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));

// test cases for 2D im2col
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));

// sycl backend will limit task global_range < MAX_INT
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
// however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)
Expand Down
Loading