|
1 | 1 | #include "im2col.cuh"
|
2 | 2 |
|
// Smaller of a and b. The whole expansion is parenthesized so the macro can be
// embedded in a larger expression (e.g. `2 * MIN(x, y)`) without the
// surrounding operators binding to only part of the conditional.
// Note: each argument may be evaluated twice; do not pass expressions with side effects.
#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Cap applied by the host wrapper to the z dimension of the im2col launch grid.
#define MAX_GRIDDIM_Z 65535
| 6 | + |
// Gather kernel for im2col: reads a float input laid out as [N, IC, IH, IW] and writes
// a T output laid out as [N, OH, OW, IC*KH*KW].
//
// Launch layout, as consumed by the index math below:
//   x: one thread per (iic, ikh, ikw) element of a patch; threads with i >= IC_KH_KW exit.
//   y: blockIdx.y is used directly as the output column iow.
//   z: blocks walk the flattened (in, ioh) range [0, N_OH) with a stride of gridDim.z,
//      so any z grid size >= 1 covers every output row.
//
// Parameters:
//   x, dst              source (float) and destination (T) buffers.
//   IW, IH              input width / height; used only for the out-of-bounds (padding) check.
//   OH, OW              output height / width.
//   KW                  kernel width.
//   KH_KW, IC_KH_KW     the products KH*KW and IC*KH*KW.
//   N_OH                the product N*OH, i.e. the number of (batch, output-row) pairs.
//   IC_IH_IW            element stride multiplied by the channel index iic when reading x.
//   IH_IW               element stride multiplied by the batch index in when reading x.
//                       (Note that, despite their names, these two act as the per-channel
//                       and per-batch strides respectively.)
//   s0, d0, p0          along width:  multiplier on iow, multiplier on ikw, offset subtracted.
//   s1, d1, p1          along height: multiplier on ioh, multiplier on ikh, offset subtracted.
//   IC, KH              accepted for signature compatibility; not read by the kernel body.
template <typename T>
static __global__ void im2col_kernel(
        const float * x, T * dst,
        int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
        int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,
        int s0, int s1, int p0, int p1, int d0, int d1) {
    // Widen before multiplying so the flat index cannot wrap in 32-bit arithmetic.
    const int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= IC_KH_KW) {
        return;
    }

    // Decompose the flat patch index into (input channel, kernel row, kernel column).
    const int64_t iic = i / KH_KW;
    const int64_t rem = i - iic * KH_KW;
    const int64_t ikh = rem / KW;
    const int64_t ikw = rem - ikh * KW;

    const int64_t iow = blockIdx.y;

    // The input column depends only on iow and ikw, so it is computed once per thread.
    const int64_t iiw = iow * s0 + ikw * d0 - p0;

    // Stride by the actual z extent of the launch grid rather than a compile-time cap,
    // so the loop stays correct for whatever gridDim.z the caller chose.
    const int64_t z_stride = gridDim.z;

    for (int64_t iz = blockIdx.z; iz < N_OH; iz += z_stride) {
        // Split the flattened z index into (batch, output row).
        const int64_t in  = iz / OH;
        const int64_t ioh = iz - in * OH;

        const int64_t iih = ioh * s1 + ikh * d1 - p1;

        const int64_t offset_dst =
            ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;

        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
            // Sample falls outside the input image: emit zero.
            dst[offset_dst] = 0.0f;
        } else {
            const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
            dst[offset_dst] = x[offset_src + iih * IW + iiw];
        }
    }
}
|
37 | 42 |
|
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
//
// Host-side launcher: enqueues im2col_kernel<T> on `stream` and returns without
// synchronizing.
//
// Grid layout:
//   x: ceil(IC*KH*KW / CUDA_IM2COL_BLOCK_SIZE) blocks of min(IC*KH*KW, CUDA_IM2COL_BLOCK_SIZE) threads.
//   y: OW blocks, one per output column. OW is passed through unchanged, so it must fit
//      the device's limit for gridDim.y.
//   z: min(N*OH, MAX_GRIDDIM_Z) blocks; the kernel loops to cover the remainder.
//
// IC_IH_IW and IH_IW are forwarded to the kernel, which uses them as the element strides
// applied to the channel index and the batch index of x respectively.
//
// If any of IC*KH*KW, N*OH or OW is zero there is nothing to write, and no kernel is
// launched (a zero-sized grid or block would be an invalid launch configuration).
template <typename T>
static void im2col_cuda(const float * x, T* dst,
        int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
        int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
        int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
    const int64_t IC_KH_KW = IC * KH * KW;
    const int64_t N_OH     = N * OH;
    const int64_t KH_KW    = KW * KH;

    // Empty problem: skip the launch instead of submitting a zero-sized configuration.
    if (IC_KH_KW <= 0 || N_OH <= 0 || OW <= 0) {
        return;
    }

    const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
    const int64_t grid_z     = MIN(N_OH, MAX_GRIDDIM_Z);
    const int64_t block_size = MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE);

    // dim3 stores unsigned int components; make the narrowing explicit.
    const dim3 block_nums((unsigned int) num_blocks, (unsigned int) OW, (unsigned int) grid_z);
    const dim3 block_dims((unsigned int) block_size, 1, 1);

    im2col_kernel<<<block_nums, block_dims, 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
                                                         IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
                                                         s0, s1, p0, p1, d0, d1);
}
|
48 | 58 |
|
// Half-precision entry point: forwards every argument unchanged to im2col_cuda<half>,
// which enqueues the kernel on `stream`.
static void im2col_cuda_f16(const float * x, half * dst,
        int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
        int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
        int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
    im2col_cuda<half>(
        x, dst,
        IW, IH, OW, OH, KW, KH, IC,
        N, IC_IH_IW, IH_IW,
        s0, s1, p0, p1, d0, d1,
        stream);
}
|
56 | 66 |
|
// Single-precision entry point: forwards every argument unchanged to im2col_cuda<float>,
// which enqueues the kernel on `stream`.
static void im2col_cuda_f32(const float * x, float * dst,
        int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
        int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
        int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
    im2col_cuda<float>(
        x, dst,
        IW, IH, OW, OH, KW, KH, IC,
        N, IC_IH_IW, IH_IW,
        s0, s1, p0, p1, d0, d1,
        stream);
}
|
64 | 74 |
|
65 | 75 | void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
@@ -91,13 +101,13 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
91 | 101 | const int64_t OH = is_2D ? dst->ne[2] : 1;
|
92 | 102 | const int64_t OW = dst->ne[1];
|
93 | 103 |
|
94 |
| - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 |
95 |
| - const int64_t batch = src1->ne[is_2D ? 3 : 2]; |
96 |
| - const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 |
| 104 | + const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 |
| 105 | + const int64_t N = src1->ne[is_2D ? 3 : 2]; |
| 106 | + const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 |
97 | 107 |
|
98 | 108 | if(dst->type == GGML_TYPE_F16) {
|
99 |
| - im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); |
| 109 | + im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); |
100 | 110 | } else {
|
101 |
| - im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); |
| 111 | + im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); |
102 | 112 | }
|
103 | 113 | }
|
0 commit comments