ggml:metal Add POOL2D op and fix IM2COL in Metal backend for running MobileVLM_V2. #9943
Changes from 6 commits. Commits in this PR: b4d3c16, adbec7f, 1467a7a, e81462d, 0084847, bd86c4c, bb9949b, 746e79e, 3c2b87d.
`ggml/src/ggml-metal.m`:

```diff
@@ -241,6 +241,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
```
```diff
@@ -272,6 +274,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,

     GGML_METAL_KERNEL_TYPE_COUNT
 };
```
```diff
@@ -685,6 +689,8 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
```
```diff
@@ -716,6 +722,8 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, avg_pool_2d_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, max_pool_2d_f32, true);
     }

     [metal_library release];
```
```diff
@@ -844,8 +852,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context
         case GGML_OP_IM2COL:
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
-        case GGML_OP_POOL_2D:
             return false;
+        case GGML_OP_POOL_2D:
+            return true;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
```
```diff
@@ -2574,11 +2583,24 @@ static void ggml_metal_encode_node(
         const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
         const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;

-        id<MTLComputePipelineState> pipeline = nil;
+        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
+        const uint64_t M = pipeline.maxTotalThreadsPerThreadgroup;
+
+        const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
```
**Review thread** (inline comment on the `is_gt_mttpt` line; both comments are partially truncated in this capture):

> **Reviewer:** Would it make sense to keep just the … [comment truncated; judging from the reply below, it asked whether only the new `_ext` kernels could be kept]

> **Author:** Summary: to keep backward compatibility and performance in smaller computations, I suggest keeping the old kernel.
>
> Details: I have done some investigations after you gave me this comment. I guess this comes from wasting threads in the `_ext` kernel (in case of … [truncated]). I applied … [rest of the comment truncated]
The hunk continues below the review thread:

```diff
         switch (dst->type) {
-            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
-            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+            case GGML_TYPE_F32: {
+                pipeline = (is_gt_mttpt ?
+                            ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
+                            :
+                            ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
+            } break;
+            case GGML_TYPE_F16: {
+                pipeline = (is_gt_mttpt ?
+                            ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
+                            :
+                            ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
+            } break;
             default: GGML_ABORT("fatal error");
         };
```
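The switch above is just a capacity check: the plain im2col kernel launches one threadgroup of `N * KH * KW` threads per `(IC, OH, OW)` grid point, so it stops working once that product exceeds the pipeline's `maxTotalThreadsPerThreadgroup`, at which point the `_ext` variant takes over. A minimal C sketch of the predicate; the 1024 cap and the example dimensions are illustrative assumptions, not values from the PR:

```c
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

// Stand-in for pipeline.maxTotalThreadsPerThreadgroup, which the Metal
// backend queries at runtime; 1024 is only an assumed typical value.
static bool needs_ext_kernel(int64_t N, int64_t KH, int64_t KW, uint64_t max_threads) {
    // The regular kernel needs one thread per (n, kh, kw) tuple inside a
    // single threadgroup, so the tuple count must fit under the cap.
    return (uint64_t)(N * KH * KW) > max_threads;
}

int main(void) {
    printf("%d\n", needs_ext_kernel( 32, 3, 3, 1024)); // 288   <= 1024 -> 0: regular kernel
    printf("%d\n", needs_ext_kernel(512, 5, 5, 1024)); // 12800 >  1024 -> 1: _ext kernel
    return 0;
}
```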
```diff
@@ -2597,7 +2619,16 @@ static void ggml_metal_encode_node(
         [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
         [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];

-        [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+        if (is_gt_mttpt) {
+            [encoder setBytes:&N  length:sizeof(int32_t) atIndex:13];
+            [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
+            [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
+
+            const int64_t D = N / M + (N % M > 0 ? 1 : 0);
+            [encoder dispatchThreadgroups:MTLSizeMake(D * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(MIN((uint64_t)N, M), 1, 1)];
+        } else {
+            [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+        }
     } break;
 case GGML_OP_UPSCALE:
     {
```
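On the `_ext` path the grid geometry changes: `N` is split across `D = ceil(N / M)` threadgroups of at most `M` threads, and the first grid dimension is scaled by `CHW` instead of launching one `N x KH x KW` group per input channel. A worked ceil-division example in plain C; every size below is made up for illustration:

```c
#include <stdint.h>
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int64_t N   = 2500;       // assumed batch-like dimension of the im2col op
    const int64_t M   = 1024;       // assumed maxTotalThreadsPerThreadgroup
    const int64_t CHW = 3 * 3 * 16; // assumed IC * KH * KW product from the host code

    // Same ceil-division as the encode step: D groups cover N threads.
    const int64_t D = N / M + (N % M > 0 ? 1 : 0); // ceil(2500 / 1024) = 3

    printf("threadgroups along x: D * CHW = %lld\n", (long long)(D * CHW));
    printf("threads per group:    %lld\n", (long long)MIN(N, M)); // 1024
    // The last of the 3 groups covers only 2500 - 2*1024 = 452 valid N indices,
    // which is why N is passed to the kernel (atIndex:13) for bounds-checking.
    return 0;
}
```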
```diff
@@ -3001,6 +3032,63 @@ static void ggml_metal_encode_node(
             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
         } break;
+    case GGML_OP_POOL_2D:
+        {
+            GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
+
+            const int32_t * opts = dst->op_params;
+            enum ggml_op_pool op = opts[0];
+
+            id<MTLComputePipelineState> pipeline = nil;
+            switch (src0t) {
+                case GGML_TYPE_F32: {
+                    switch(op) {
+                        case GGML_OP_POOL_AVG:
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
+                        case GGML_OP_POOL_MAX:
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
+                        default: GGML_ASSERT(false && "not implemented");
+                    }
+                } break;
+                default: GGML_ASSERT(false && "not implemented");
+            }
+
+            const int32_t k0 = opts[1];
+            const int32_t k1 = opts[2];
+            const int32_t s0 = opts[3];
+            const int32_t s1 = opts[4];
+            const int32_t p0 = opts[5];
+            const int32_t p1 = opts[6];
+
+            const int64_t IH = src0->ne[1];
+            const int64_t IW = src0->ne[0];
+
+            const int64_t N  = dst->ne[3];
+            const int64_t OC = dst->ne[2];
+            const int64_t OH = dst->ne[1];
+            const int64_t OW = dst->ne[0];
+
+            const int64_t parallel_elements = N * OC * OH * OW;
+            const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
+            const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
+
+            [encoder setComputePipelineState:pipeline];
+            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+            [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
+            [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
+            [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
+            [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
+            [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
+            [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
+            [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
+            [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
+            [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
+            [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
+            [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
+
+            [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
+        } break;
     default:
         {
             GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
```
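The Metal kernels themselves are not part of this hunk, but the host setup pins down their contract: a flat 1-D dispatch with one thread per destination element, `parallel_elements = N * OC * OH * OW` threads in total, each thread recovering its output coordinates from its global id. As a reading aid, here is a plain-C sketch of what the average-pooling thread body is expected to compute per element; it reuses the parameter names from the encode code above, assumes contiguous tensors, and its divide-by-full-window-area behaviour is an assumption chosen to match ggml's CPU `pool_2d`, not code copied from this PR:

```c
#include <stdint.h>

// One loop iteration corresponds to one GPU thread (gid = global thread id).
static void pool_2d_avg_f32_ref(
        const float * src, float * dst,
        int32_t k0, int32_t k1,   // kernel width / height
        int32_t s0, int32_t s1,   // stride width / height
        int32_t p0, int32_t p1,   // padding width / height
        int64_t IH, int64_t IW, int64_t OH, int64_t OW,
        int64_t parallel_elements) {
    for (int64_t gid = 0; gid < parallel_elements; ++gid) {
        const int64_t cur_ow = gid % OW;
        const int64_t cur_oh = (gid / OW) % OH;
        const int64_t nc     = gid / (OW * OH);     // flattened batch*channel index

        const float * plane   = src + nc * IH * IW; // contiguous layout assumed
        const int64_t start_h = cur_oh * s1 - p1;
        const int64_t start_w = cur_ow * s0 - p0;

        float sum = 0.0f;
        for (int32_t kh = 0; kh < k1; ++kh) {
            const int64_t ih = start_h + kh;
            if (ih < 0 || ih >= IH) continue;       // zero padding adds nothing
            for (int32_t kw = 0; kw < k0; ++kw) {
                const int64_t iw = start_w + kw;
                if (iw < 0 || iw >= IW) continue;
                sum += plane[ih * IW + iw];
            }
        }
        // Assumed to divide by the full window area, as ggml's CPU path does.
        dst[gid] = sum / (float)(k0 * k1);
    }
}
```

The max variant would be identical except that it tracks a running maximum (seeded with `-FLT_MAX`) instead of a sum.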
**Review thread:**

> **Reviewer:** The kernel names also need to be updated to follow the max-prefix naming convention: … [suggested names truncated]

> **Author:** My bad. Applied in bb9949b.