-
Notifications
You must be signed in to change notification settings - Fork 13.5k
ggml:metal Add POOL2D op and fix IM2COL in Metal backend for running MobileVLM_V2. #9943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
b4d3c16
adbec7f
1467a7a
e81462d
0084847
bd86c4c
bb9949b
746e79e
3c2b87d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -241,6 +241,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte | |||||||||
| GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, | ||||||||||
| GGML_METAL_KERNEL_TYPE_IM2COL_F16, | ||||||||||
| GGML_METAL_KERNEL_TYPE_IM2COL_F32, | ||||||||||
| GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, | ||||||||||
| GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, | ||||||||||
| GGML_METAL_KERNEL_TYPE_UPSCALE_F32, | ||||||||||
| GGML_METAL_KERNEL_TYPE_PAD_F32, | ||||||||||
| GGML_METAL_KERNEL_TYPE_ARANGE_F32, | ||||||||||
|
|
@@ -272,6 +274,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte | |||||||||
| GGML_METAL_KERNEL_TYPE_SIN, | ||||||||||
| GGML_METAL_KERNEL_TYPE_COS, | ||||||||||
| GGML_METAL_KERNEL_TYPE_SUM_ROWS, | ||||||||||
| GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, | ||||||||||
| GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, | ||||||||||
|
|
||||||||||
| GGML_METAL_KERNEL_TYPE_COUNT | ||||||||||
| }; | ||||||||||
|
|
@@ -685,6 +689,8 @@ @implementation GGMLMetalClass | |||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); | ||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); | ||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); | ||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); | ||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); | ||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); | ||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); | ||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); | ||||||||||
|
|
@@ -716,6 +722,8 @@ @implementation GGMLMetalClass | |||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); | ||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); | ||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); | ||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, avg_pool_2d_f32, true); | ||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, max_pool_2d_f32, true); | ||||||||||
|
||||||||||
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, avg_pool_2d_f32, true); | |
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, max_pool_2d_f32, true); | |
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); | |
| GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
applied in 008484799146274784a5df692fd7b7508805be83
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it make sense to keep just the _ext variant of the kernel? Does the old kernel have a significant advantage?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary: To keep backward compatibility and performance in smaller computations, I suggest you keep the old kernel.
details
I’ve done some investigations after you gave me this comment.
The new kernel(a.k.a _ext) has performance degradation in small size of N:
$ test-backend-ops perf -o IM2COL -b Metal
NOTE: 9 x 113 is 1017, 9 x 114 is 1026 which exceeds limits.
original + _ext, M
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 229348 runs - 4.38 us/run - 99 kB/run - 0.77 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 196584 runs - 5.24 us/run - 199 kB/run - 1.51 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 147438 runs - 6.82 us/run - 351 kB/run - 2.73 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 106483 runs - 9.67 us/run - 703 kB/run - 5.34 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 18.79 us/run - 354 kB/run - 2.57 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 18.74 us/run - 354 kB/run - 2.58 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 49146 runs - 21.43 us/run - 622 kB/run - 4.61 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 32764 runs - 35.12 us/run - 1244 kB/run - 8.45 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 16382 runs - 74.60 us/run - 3186 kB/run - 20.37 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 144.71 us/run - 6372 kB/run - 21.00 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 143.44 us/run - 6372 kB/run - 21.18 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 5266 runs - 283.55 us/run - 12744 kB/run - 21.43 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 8426 runs - 175.97 us/run - 7965 kB/run - 21.58 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 4214 runs - 352.93 us/run - 15930 kB/run - 21.52 GB/s
only _ext, M - worst case.
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 73719 runs - 14.48 us/run - 99 kB/run - 0.73 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 49146 runs - 24.31 us/run - 199 kB/run - 1.30 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 17.90 us/run - 351 kB/run - 2.68 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 40955 runs - 28.89 us/run - 703 kB/run - 4.64 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 18.54 us/run - 354 kB/run - 2.61 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 18.06 us/run - 354 kB/run - 2.68 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 20.35 us/run - 622 kB/run - 4.17 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 32764 runs - 34.79 us/run - 1244 kB/run - 8.53 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 16382 runs - 74.19 us/run - 3186 kB/run - 20.48 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 144.72 us/run - 6372 kB/run - 21.00 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 142.86 us/run - 6372 kB/run - 21.27 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 5266 runs - 283.70 us/run - 12744 kB/run - 21.42 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 8426 runs - 175.84 us/run - 7965 kB/run - 21.60 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 4214 runs - 350.67 us/run - 15930 kB/run - 21.66 GB/s
only_ext, MIN(N, M)
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 131056 runs - 7.78 us/run - 99 kB/run - 0.76 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 90101 runs - 11.24 us/run - 199 kB/run - 1.54 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 90101 runs - 11.58 us/run - 351 kB/run - 2.63 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 17.85 us/run - 703 kB/run - 5.37 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 90101 runs - 11.80 us/run - 354 kB/run - 2.61 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 90101 runs - 11.72 us/run - 354 kB/run - 2.62 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 65528 runs - 16.90 us/run - 622 kB/run - 4.39 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 40955 runs - 28.84 us/run - 1244 kB/run - 8.23 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 16382 runs - 74.75 us/run - 3186 kB/run - 20.32 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 144.31 us/run - 6372 kB/run - 21.06 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 143.02 us/run - 6372 kB/run - 21.25 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 5266 runs - 286.72 us/run - 12744 kB/run - 21.20 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 8426 runs - 176.62 us/run - 7965 kB/run - 21.51 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 4214 runs - 351.21 us/run - 15930 kB/run - 21.63 GB/s
original + _ext, MIN(N, M) - best case IMO.
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,32],ne_kernel=[3,3,1,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 221157 runs - 4.53 us/run - 99 kB/run - 0.78 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,32],ne_kernel=[3,3,2,32],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 196584 runs - 5.14 us/run - 199 kB/run - 1.54 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,113],ne_kernel=[3,3,1,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 147438 runs - 6.80 us/run - 351 kB/run - 2.74 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,113],ne_kernel=[3,3,2,113],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 114674 runs - 9.23 us/run - 703 kB/run - 5.19 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 81910 runs - 12.44 us/run - 354 kB/run - 2.72 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,114],ne_kernel=[3,3,1,114],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 90101 runs - 11.55 us/run - 354 kB/run - 2.66 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,200],ne_kernel=[3,3,1,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 57337 runs - 17.50 us/run - 622 kB/run - 4.85 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,200],ne_kernel=[3,3,2,200],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 40955 runs - 28.67 us/run - 1244 kB/run - 8.28 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,1024],ne_kernel=[3,3,1,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 16382 runs - 74.53 us/run - 3186 kB/run - 20.38 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,1024],ne_kernel=[3,3,2,1024],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 144.56 us/run - 6372 kB/run - 21.02 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2048],ne_kernel=[3,3,1,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 10532 runs - 143.35 us/run - 6372 kB/run - 21.20 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2048],ne_kernel=[3,3,2,2048],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 5266 runs - 283.69 us/run - 12744 kB/run - 21.42 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 8426 runs - 176.14 us/run - 7965 kB/run - 21.56 GB/s
IM2COL(type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1): 4214 runs - 351.41 us/run - 15930 kB/run - 21.62 GB/s
I guess this comes from wasting threads in _ext kernel(in case of tpitg_0 >= N) series when N is smaller than M(maxTotalThreadsPerThreadgroup) and it gets worse M - N gap is getting bigger.
The Apple official document told me that they have new API spreading threads fit into the given grid. However, it’s supporting from Metal3/Apple4(you can find this in the Metal feature set table with the keyword nonuniform) so I didn’t check it because it narrows runnable devices of llama.cpp.
I applied original + _ext, MIN(N, M) to this PR which is the best from those investigations. Thanks for your review. I could do more optimization from your comment. :-^)
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1933,6 +1933,85 @@ kernel void kernel_im2col( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| typedef void (im2col_ext_t)( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device const float * x, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device char * dst, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & ofs0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & ofs1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & IW, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & IH, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & CHW, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & s0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & s1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & p0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & p1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & d0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & d1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & N, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & KH, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & KW, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint3 tgpig[[threadgroup_position_in_grid]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint3 tgpg[[threadgroups_per_grid]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint3 tpitg[[thread_position_in_threadgroup]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint3 ntg[[threads_per_threadgroup]]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel void kernel_im2col_ext( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device const float * x, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device char * dst, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & ofs0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & ofs1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & IW, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & IH, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & CHW, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & s0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & s1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & p0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & p1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & d0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & d1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & N, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & KH, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t & KW, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint3 tgpig[[threadgroup_position_in_grid]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint3 tpitg[[thread_position_in_threadgroup]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t d = tgpig[0] / CHW; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t chw = tgpig[0] % CHW; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t HW = tgpig[0] % KHW; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (tpitg_0 >= N) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t tpitg_1 = HW / KW; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t tpitg_2 = HW % KW; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t offset_dst = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device T * pdst = (device T *) (dst); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pdst[offset_dst] = 0.0f; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pdst[offset_dst] = x[offset_src + iih * IW + iiw]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel void kernel_upscale_f32( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device const char * src0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device char * dst, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -6372,3 +6451,98 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel void kernel_max_pool_2d_f32( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device const float* src0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device float* dst, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t& k0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t& k1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t& s0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t& s1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t& p0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int32_t& p1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int64_t& IH, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int64_t& IW, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int64_t& OH, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int64_t& OW, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constant int64_t& parallel_elements, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint gid[[thread_position_in_grid]]) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device const float* src0, | |
| device float* dst, | |
| constant int32_t& k0, | |
| constant int32_t& k1, | |
| constant int32_t& s0, | |
| constant int32_t& s1, | |
| constant int32_t& p0, | |
| constant int32_t& p1, | |
| constant int64_t& IH, | |
| constant int64_t& IW, | |
| constant int64_t& OH, | |
| constant int64_t& OW, | |
| constant int64_t& parallel_elements, | |
| uint gid[[thread_position_in_grid]]) { | |
| device const float * src0, | |
| device float * dst, | |
| constant int32_t & k0, | |
| constant int32_t & k1, | |
| constant int32_t & s0, | |
| constant int32_t & s1, | |
| constant int32_t & p0, | |
| constant int32_t & p1, | |
| constant int64_t & IH, | |
| constant int64_t & IW, | |
| constant int64_t & OH, | |
| constant int64_t & OW, | |
| constant int64_t & parallel_elements, | |
| uint gid[[thread_position_in_grid]]) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
applied in 008484799146274784a5df692fd7b7508805be83
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
applied in 008484799146274784a5df692fd7b7508805be83