|
1 | 1 | #include "conv2d-implicit.cuh" |
2 | 2 | #include "convert.cuh" |
3 | 3 |
|
4 | | -struct conv_params { |
5 | | - const int64_t IW, IH; |
6 | | - const int64_t OW, OH; |
7 | | - const int64_t KW, KH; |
8 | | - const int64_t ST_X, ST_Y; |
9 | | - const int64_t PD_X, PD_Y; |
10 | | - const int64_t DL_X, DL_Y; |
11 | | - const int64_t IC, OC; |
12 | | - const int64_t B; |
13 | | - const int64_t TOTAL; |
14 | | -}; |
15 | | - |
16 | | -struct kernel_bounds { |
17 | | - int64_t y_min, y_max; |
18 | | - int64_t x_min, x_max; |
19 | | -}; |
20 | | - |
21 | | -__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { |
22 | | - return (a > b) ? a : b; |
23 | | -} |
24 | | - |
25 | | -__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { |
26 | | - return (a < b) ? a : b; |
27 | | -} |
28 | | - |
29 | | -__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { |
30 | | - kernel_bounds bounds; |
31 | | - bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); |
32 | | - bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); |
33 | | - bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); |
34 | | - bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); |
35 | | - return bounds; |
36 | | -} |
37 | | - |
38 | | -__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, |
39 | | - int64_t kern_coord, |
40 | | - int64_t stride, |
41 | | - int64_t dilation, |
42 | | - int64_t padding) { |
43 | | - return out_coord * stride + kern_coord * dilation - padding; |
44 | | -} |
| 4 | +typedef struct{ // conv2d parameters for the implicit-GEMM kernel
| 5 | + unsigned int n; //batch size
| 6 | + unsigned int c; //number of input channels
| 7 | + unsigned int h; //input height
| 8 | + unsigned int w; //input width
| 9 | + unsigned int k; //number of filters (output channels)
| 10 | + unsigned int r; //filter height
| 11 | + unsigned int s; //filter width
| 12 | + unsigned int u; //stride height
| 13 | + unsigned int v; //stride width
| 14 | + unsigned int p; //padding height
| 15 | + unsigned int q; //padding width
| 16 | + unsigned int d_h; //dilation height
| 17 | + unsigned int d_w; //dilation width
| 18 | + unsigned int Oh; //output height
| 19 | + unsigned int Ow; //output width
| 20 | +} param_t;
45 | 21 |
|
46 | | -struct whcn_layout { |
47 | | - __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { |
48 | | - return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; |
49 | | - } |
50 | | - |
51 | | - __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { |
52 | | - return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; |
53 | | - } |
54 | 22 |
|
55 | | - __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { |
56 | | - return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; |
57 | | - } |
58 | 23 |
|
59 | | - __device__ static void unpack_indices(int64_t global_idx, |
60 | | - const conv_params & P, |
61 | | - int64_t & n, |
62 | | - int64_t & c, |
63 | | - int64_t & out_y, |
64 | | - int64_t & out_x) { |
65 | | - out_x = global_idx % P.OW; |
66 | | - out_y = (global_idx / P.OW) % P.OH; |
67 | | - c = (global_idx / (P.OW * P.OH)) % P.OC; |
68 | | - n = global_idx / (P.OW * P.OH * P.OC); |
69 | | - } |
70 | | -}; |
71 | | - |
72 | | -template <typename T, typename Layout> |
| 24 | +template <typename T> |
73 | 25 | static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, |
74 | 26 | const T * __restrict__ kernel, |
75 | 27 | float * __restrict__ output, |
76 | | - const conv_params P) { |
| 28 | + const param_t ¶m) { |
77 | 29 |
|
78 | | - __shared__ __align__(16 * 1024) char smem[24 * 1024]; |
| 30 | + extern __shared__ __align__(16 * 1024) char smem[]; |
79 | 31 | T *smemweight = reinterpret_cast<T *>(smem); |
80 | 32 | float *smeminput = reinterpret_cast<float *>(smem + 16 * 1024); |
81 | 33 |
|
@@ -151,8 +103,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, |
151 | 103 | #pragma unroll |
152 | 104 | for (int i = 0; i < 4; ++i) |
153 | 105 | { |
154 | | - int curH = posh_ori[i] + curR; // input h |
155 | | - int curW = posw_ori[i] + curS; // input w |
| 106 | + int curH = posh_ori[i] + curR * param.d_h; // input h |
| 107 | + int curW = posw_ori[i] + curS * param.d_w; // input w |
156 | 108 | int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; |
157 | 109 | if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) |
158 | 110 | { |
@@ -210,8 +162,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, |
210 | 162 | #pragma unroll |
211 | 163 | for (int i = 0; i < 4; ++i) |
212 | 164 | { |
213 | | - int curH = posh_ori[i] + curR; // input h |
214 | | - int curW = posw_ori[i] + curS; // input w |
| 165 | + int curH = posh_ori[i] + curR * param.d_h; // input h |
| 166 | + int curW = posw_ori[i] + curS * param.d_w; // input w |
215 | 167 | int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; |
216 | 168 | if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) |
217 | 169 | { |
@@ -334,16 +286,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, |
334 | 286 | } |
335 | 287 |
|
336 | 288 | template <typename T> |
337 | | -static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { |
338 | | - const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; |
339 | | - conv2d_implicit_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P); |
| 289 | +static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t &P, cudaStream_t st) { |
| 290 | + // const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; |
| 291 | + int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number |
| 292 | + int blocky = (P.k + 127) / 128; // blocky number |
| 293 | + int blockz = P.n; // blockz number |
| 294 | + int threadx = CUDA_CONV2D_IMPLICT_BLOCK_SIZE; // threadx number per block |
| 295 | + int thready = 1; // thready number per block |
| 296 | + int threadz = 1; // threadz number per block |
| 297 | + dim3 thblock(threadx, thready, threadz); |
| 298 | + dim3 grid(blockx, blocky, blockz); |
| 299 | + int smem_size = 24 * 1024; |
| 300 | + conv2d_implicit_kernel<T><<<grid, thblock, smem_size, st>>>(X_D, K_D, Y_D, P); |
340 | 301 | } |
341 | 302 |
|
342 | | -static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { |
| 303 | +static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t &P, cudaStream_t st) { |
343 | 304 | conv2d_implicit_cuda<half>(X_D, K_D, Y_D, P, st); |
344 | 305 | } |
345 | 306 |
|
346 | | -static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { |
| 307 | +static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t &P, cudaStream_t st) { |
347 | 308 | conv2d_implicit_cuda<float>(X_D, K_D, Y_D, P, st); |
348 | 309 | } |
349 | 310 |
|
@@ -384,7 +345,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * |
384 | 345 | const int B = input->ne[3]; // n_batches |
385 | 346 |
|
386 | 347 | const int64_t total = B * OC * OH * OW; |
387 | | - conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; |
| 348 | + // param_t params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; |
| 349 | + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, OH, OW }; |
388 | 350 |
|
389 | 351 | if (kernel->type == GGML_TYPE_F16) { |
390 | 352 | conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); |
|
0 commit comments