| 
1 | 1 | #include "im2col.cuh"  | 
2 | 2 | 
 
  | 
 | 3 | +#define MIN(a, b) (a) < (b) ? (a) : (b)  | 
 | 4 | + | 
 | 5 | +#define MAX_GRIDDIM_Z 65535  | 
 | 6 | + | 
3 | 7 | template <typename T>  | 
4 | 8 | static  __global__ void im2col_kernel(  | 
5 |  | -        const float * x, T * dst, int64_t batch_offset,  | 
6 |  | -        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,  | 
 | 9 | +        const float * x, T * dst,  | 
 | 10 | +        int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,  | 
 | 11 | +        int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,  | 
7 | 12 |         int s0, int s1, int p0, int p1, int d0, int d1) {  | 
8 | 13 |     const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;  | 
9 |  | -    if (i >= pelements) {  | 
 | 14 | +    if (i >= IC_KH_KW) {  | 
10 | 15 |         return;  | 
11 | 16 |     }  | 
12 | 17 | 
 
  | 
13 |  | -    const int64_t  ksize = OW * KH;  | 
14 |  | -    const int64_t  kx = i / ksize;  | 
15 |  | -    const int64_t  kd = kx * ksize;  | 
16 |  | -    const int64_t  ky = (i - kd) / OW;  | 
17 |  | -    const int64_t  ix = i % OW;  | 
 | 18 | +    const int64_t iic = i / (KH_KW);  | 
 | 19 | +    const int64_t rem = i - iic * KH_KW;  | 
 | 20 | +    const int64_t ikh = rem / KW;  | 
 | 21 | +    const int64_t ikw = rem - ikh * KW;  | 
18 | 22 | 
 
  | 
19 |  | -    const int64_t  oh = blockIdx.y;  | 
20 |  | -    const int64_t  batch = blockIdx.z / IC;  | 
21 |  | -    const int64_t  ic = blockIdx.z % IC;  | 
 | 23 | +    const int64_t  iow = blockIdx.y;  | 
 | 24 | +    for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {  | 
 | 25 | +        const int64_t  in = iz / OH;  | 
 | 26 | +        const int64_t  ioh = iz - in * OH;  | 
22 | 27 | 
 
  | 
23 |  | -    const int64_t iiw = ix * s0 + kx * d0 - p0;  | 
24 |  | -    const int64_t iih = oh * s1 + ky * d1 - p1;  | 
 | 28 | +        const int64_t iiw = iow * s0 + ikw * d0 - p0;  | 
 | 29 | +        const int64_t iih = ioh * s1 + ikh * d1 - p1;  | 
25 | 30 | 
 
  | 
26 |  | -    const int64_t offset_dst =  | 
27 |  | -        ((batch * OH + oh) * OW + ix) * CHW +  | 
28 |  | -        (ic * (KW * KH) + ky * KW + kx);  | 
 | 31 | +        const int64_t offset_dst =  | 
 | 32 | +            ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;  | 
29 | 33 | 
 
  | 
30 |  | -    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {  | 
31 |  | -        dst[offset_dst] = 0.0f;  | 
32 |  | -    } else {  | 
33 |  | -        const int64_t offset_src = ic * offset_delta + batch * batch_offset;  | 
34 |  | -        dst[offset_dst] = x[offset_src + iih * IW + iiw];  | 
 | 34 | +        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {  | 
 | 35 | +            dst[offset_dst] = 0.0f;  | 
 | 36 | +        } else {  | 
 | 37 | +            const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;  | 
 | 38 | +            dst[offset_dst] = x[offset_src + iih * IW + iiw];  | 
 | 39 | +        }  | 
35 | 40 |     }  | 
36 | 41 | }  | 
37 | 42 | 
 
  | 
 | 43 | +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]  | 
38 | 44 | template <typename T>  | 
39 | 45 | static void im2col_cuda(const float * x, T* dst,  | 
40 | 46 |     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,  | 
41 |  | -    int64_t batch, int64_t batch_offset, int64_t offset_delta,  | 
 | 47 | +    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,  | 
42 | 48 |     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {  | 
43 |  | -    const int parallel_elements = OW * KW * KH;  | 
44 |  | -    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;  | 
45 |  | -    dim3 block_nums(num_blocks, OH, batch * IC);  | 
46 |  | -    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);  | 
 | 49 | +    const int64_t IC_KH_KW = IC * KH * KW;  | 
 | 50 | +    const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;  | 
 | 51 | +    const int64_t N_OH = N * OH;  | 
 | 52 | +    const int64_t KH_KW = KW*KH;  | 
 | 53 | +    dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));  | 
 | 54 | +    im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,  | 
 | 55 | +                                                                                     IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,  | 
 | 56 | +                                                                                     s0, s1, p0, p1, d0, d1);  | 
47 | 57 | }  | 
48 | 58 | 
 
  | 
49 | 59 | static void im2col_cuda_f16(const float * x, half * dst,  | 
50 | 60 |     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,  | 
51 |  | -    int64_t batch, int64_t batch_offset, int64_t offset_delta,  | 
 | 61 | +    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,  | 
52 | 62 |     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {  | 
53 | 63 | 
 
  | 
54 |  | -    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);  | 
 | 64 | +    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);  | 
55 | 65 | }  | 
56 | 66 | 
 
  | 
57 | 67 | static void im2col_cuda_f32(const float * x, float * dst,  | 
58 | 68 |     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,  | 
59 |  | -    int64_t batch, int64_t batch_offset, int64_t offset_delta,  | 
 | 69 | +    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,  | 
60 | 70 |     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {  | 
61 | 71 | 
 
  | 
62 |  | -    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);  | 
 | 72 | +    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);  | 
63 | 73 | }  | 
64 | 74 | 
 
  | 
65 | 75 | void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {  | 
@@ -91,13 +101,13 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {  | 
91 | 101 |     const int64_t OH = is_2D ? dst->ne[2] : 1;  | 
92 | 102 |     const int64_t OW =         dst->ne[1];  | 
93 | 103 | 
 
  | 
94 |  | -    const size_t  delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32  | 
95 |  | -    const int64_t batch        = src1->ne[is_2D ? 3 : 2];  | 
96 |  | -    const size_t  batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32  | 
 | 104 | +    const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32  | 
 | 105 | +    const int64_t N        = src1->ne[is_2D ? 3 : 2];  | 
 | 106 | +    const int64_t IH_IW    = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32  | 
97 | 107 | 
 
  | 
98 | 108 |     if(dst->type == GGML_TYPE_F16) {  | 
99 |  | -        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);  | 
 | 109 | +        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);  | 
100 | 110 |     } else {  | 
101 |  | -        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);  | 
 | 111 | +        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);  | 
102 | 112 |     }  | 
103 | 113 | }  | 
0 commit comments