
Commit 4d77287

Add implicit convolution support for 2D tensors in CPU and CUDA implementations
1 parent 8a58931

File tree (4 files changed: +53 additions, -80 deletions)

  ggml/src/ggml-cpu/ggml-cpu.c
  ggml/src/ggml-cuda/conv2d-implicit.cu
  ggml/src/ggml-cuda/ggml-cuda.cu
  ggml/src/ggml.c

ggml/src/ggml-cpu/ggml-cpu.c
Lines changed: 6 additions & 0 deletions

@@ -1880,6 +1880,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_conv_2d(params, tensor);
             } break;
+        case GGML_OP_CONV_2D_IMPLICIT:
+            {
+                ggml_compute_forward_conv_2d(params, tensor);
+            } break;
         case GGML_OP_CONV_3D:
             {
                 ggml_compute_forward_conv_3d(params, tensor);
@@ -2256,6 +2260,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_IM2COL:
         case GGML_OP_IM2COL_BACK:
         case GGML_OP_CONV_2D:
+        case GGML_OP_CONV_2D_IMPLICIT:
         case GGML_OP_CONV_3D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_1D:
@@ -2778,6 +2783,7 @@ struct ggml_cplan ggml_graph_plan(
                     }
                 } break;
             case GGML_OP_CONV_2D:
+            case GGML_OP_CONV_2D_IMPLICIT:
             case GGML_OP_CONV_3D:
                 {
                     cur = GGML_IM2COL_WORK_SIZE;
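
On the CPU the new op reuses the existing direct conv2d compute path, so only the dispatch, task-count, and work-buffer entries above change. For context, a graph using the op would be built through a constructor on the ggml side; a minimal usage sketch, assuming a hypothetical ggml_conv_2d_implicit() that mirrors ggml_conv_2d's signature (the graph-builder side is not part of this diff):

    // Hypothetical sketch: ggml_conv_2d_implicit is assumed to mirror
    // ggml_conv_2d(ctx, kernel, input, s0, s1, p0, p1, d0, d1).
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, KW, KH, IC, OC);
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, IW, IH, IC, B);
    struct ggml_tensor * out    = ggml_conv_2d_implicit(ctx, kernel, input,
                                                        /*s0*/ 1, /*s1*/ 1,
                                                        /*p0*/ 0, /*p1*/ 0,
                                                        /*d0*/ 1, /*d1*/ 1);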

ggml/src/ggml-cuda/conv2d-implicit.cu
Lines changed: 40 additions & 78 deletions
@@ -1,81 +1,33 @@
 #include "conv2d-implicit.cuh"
 #include "convert.cuh"
 
-struct conv_params {
-    const int64_t IW, IH;
-    const int64_t OW, OH;
-    const int64_t KW, KH;
-    const int64_t ST_X, ST_Y;
-    const int64_t PD_X, PD_Y;
-    const int64_t DL_X, DL_Y;
-    const int64_t IC, OC;
-    const int64_t B;
-    const int64_t TOTAL;
-};
-
-struct kernel_bounds {
-    int64_t y_min, y_max;
-    int64_t x_min, x_max;
-};
-
-__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
-    return (a > b) ? a : b;
-}
-
-__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
-    return (a < b) ? a : b;
-}
-
-__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
-    kernel_bounds bounds;
-    bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
-    bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
-    bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
-    bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
-    return bounds;
-}
-
-__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
-                                                     int64_t kern_coord,
-                                                     int64_t stride,
-                                                     int64_t dilation,
-                                                     int64_t padding) {
-    return out_coord * stride + kern_coord * dilation - padding;
-}
+typedef struct {
+    unsigned int n;   // batch size
+    unsigned int c;   // channel number
+    unsigned int h;   // height
+    unsigned int w;   // width
+    unsigned int k;   // number of filters
+    unsigned int r;   // filter height
+    unsigned int s;   // filter width
+    unsigned int u;   // stride height
+    unsigned int v;   // stride width
+    unsigned int p;   // padding height
+    unsigned int q;   // padding width
+    unsigned int d_h; // dilation height
+    unsigned int d_w; // dilation width
+    unsigned int Oh;  // output height
+    unsigned int Ow;  // output width
+} param_t;
 
-struct whcn_layout {
-    __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
-        return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
-    }
-
-    __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
-        return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
-    }
 
-    __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
-        return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
-    }
 
-    __device__ static void unpack_indices(int64_t global_idx,
-                                          const conv_params & P,
-                                          int64_t & n,
-                                          int64_t & c,
-                                          int64_t & out_y,
-                                          int64_t & out_x) {
-        out_x = global_idx % P.OW;
-        out_y = (global_idx / P.OW) % P.OH;
-        c = (global_idx / (P.OW * P.OH)) % P.OC;
-        n = global_idx / (P.OW * P.OH * P.OC);
-    }
-};
-
-template <typename T, typename Layout>
+template <typename T>
 static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
                                               const T * __restrict__ kernel,
                                               float * __restrict__ output,
-                                              const conv_params P) {
+                                              const param_t &param) {
 
-    __shared__ __align__(16 * 1024) char smem[24 * 1024];
+    extern __shared__ __align__(16 * 1024) char smem[];
     T *smemweight = reinterpret_cast<T *>(smem);
     float *smeminput = reinterpret_cast<float *>(smem + 16 * 1024);
 
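For reference, the Oh/Ow fields of param_t follow the standard convolution output-shape formula. A small host-side sketch (the helper is an illustration, not part of the diff):

    // out = (in + 2*pad - dilation*(kernel-1) - 1) / stride + 1
    static unsigned int conv_out_size(unsigned int in, unsigned int pad,
                                      unsigned int dilation, unsigned int kernel,
                                      unsigned int stride) {
        return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
    }
    // e.g. param.Oh = conv_out_size(param.h, param.p, param.d_h, param.r, param.u);
    //      param.Ow = conv_out_size(param.w, param.q, param.d_w, param.s, param.v);
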
@@ -151,8 +103,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
 #pragma unroll
         for (int i = 0; i < 4; ++i)
         {
-            int curH = posh_ori[i] + curR; // input h
-            int curW = posw_ori[i] + curS; // input w
+            int curH = posh_ori[i] + curR * param.d_h; // input h
+            int curW = posw_ori[i] + curS * param.d_w; // input w
             int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
             if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h)
             {
@@ -210,8 +162,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
 #pragma unroll
         for (int i = 0; i < 4; ++i)
        {
-            int curH = posh_ori[i] + curR; // input h
-            int curW = posw_ori[i] + curS; // input w
+            int curH = posh_ori[i] + curR * param.d_h; // input h
+            int curW = posw_ori[i] + curS * param.d_w; // input w
             int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
             if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h)
             {
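
Both hunks apply the same fix: the kernel tap offset is scaled by the dilation before being added to the base input coordinate (posh_ori/posw_ori are assumed here to hold the strided, padded base positions). A minimal sketch of the mapping being restored, as a hypothetical helper:

    // Hypothetical helper: input row read by output row `oh` and kernel tap `kh`.
    // Columns are analogous, using v (stride), q (padding), and d_w (dilation).
    // Out-of-range results fall in the zero-padding region, which the kernel's
    // bounds check skips.
    __device__ int input_row(const param_t & param, int oh, int kh) {
        return oh * param.u - param.p + kh * param.d_h;
    }
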
@@ -334,16 +286,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
 }
 
 template <typename T>
-static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
-    const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
-    conv2d_implicit_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
+static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t &P, cudaStream_t st) {
+    // const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
+    int blockx  = (P.Oh * P.Ow + 127) / 128;      // blocks along x: output-position tiles
+    int blocky  = (P.k + 127) / 128;              // blocks along y: filter tiles
+    int blockz  = P.n;                            // blocks along z: one per batch image
+    int threadx = CUDA_CONV2D_IMPLICT_BLOCK_SIZE; // threads per block along x
+    int thready = 1;                              // threads per block along y
+    int threadz = 1;                              // threads per block along z
+    dim3 thblock(threadx, thready, threadz);
+    dim3 grid(blockx, blocky, blockz);
+    int smem_size = 24 * 1024;
+    conv2d_implicit_kernel<T><<<grid, thblock, smem_size, st>>>(X_D, K_D, Y_D, P);
 }
 
-static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t &P, cudaStream_t st) {
     conv2d_implicit_cuda<half>(X_D, K_D, Y_D, P, st);
 }
 
-static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t &P, cudaStream_t st) {
     conv2d_implicit_cuda<float>(X_D, K_D, Y_D, P, st);
 }
 
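With the tiled launch, each thread block covers a 128x128 output tile (128 output positions by 128 filters), the grid's z dimension walks the batch, and the 24 KiB of dynamic shared memory passed at launch matches the kernel's extern __shared__ declaration. A worked example for a hypothetical problem with n = 2, k = 256, and Oh = Ow = 64:

    // blockx = (64*64 + 127) / 128 = 32  (output-position tiles)
    // blocky = (256 + 127) / 128   = 2   (filter tiles)
    // blockz = 2                         (one z-slice per batch image)
    dim3 grid((64 * 64 + 127) / 128, (256 + 127) / 128, 2); // = {32, 2, 2}
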
@@ -384,7 +345,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
     const int B = input->ne[3]; // n_batches
 
     const int64_t total = B * OC * OH * OW;
-    conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+    // param_t params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+    param_t params = { B, IC, IH, IW, OC, KH, KW, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, OH, OW };
 
     if (kernel->type == GGML_TYPE_F16) {
         conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
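The positional initializer must track param_t's declaration order exactly; note the field reordering relative to the old conv_params initializer kept in the comment above. A sketch of the same initialization with C99 designated initializers, shown only to make the field mapping auditable (it mirrors the positional form verbatim):

    param_t params = {
        .n = B,      .c = IC,     .h = IH,   .w = IW,
        .k = OC,     .r = KH,     .s = KW,
        .u = ST_X,   .v = ST_Y,   .p = PD_X, .q = PD_Y,
        .d_h = DL_X, .d_w = DL_Y, .Oh = OH,  .Ow = OW,
    };

Written this way, any mismatch between a field's documented meaning (u is commented as stride height) and the value assigned to it (ST_X) is easy to spot in review.
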
ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 5 additions & 0 deletions

@@ -13,6 +13,7 @@
 #include "ggml-cuda/concat.cuh"
 #include "ggml-cuda/conv-transpose-1d.cuh"
 #include "ggml-cuda/conv2d.cuh"
+#include "ggml-cuda/conv2d-implicit.cuh"
 #include "ggml-cuda/conv2d-dw.cuh"
 #include "ggml-cuda/conv2d-transpose.cuh"
 #include "ggml-cuda/convert.cuh"
@@ -2455,6 +2456,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CONV_2D:
             ggml_cuda_op_conv2d(ctx, dst);
             break;
+        case GGML_OP_CONV_2D_IMPLICIT:
+            ggml_cuda_op_conv2d_implicit(ctx, dst);
+            break;
         case GGML_OP_CONV_2D_DW:
             ggml_cuda_op_conv2d_dw(ctx, dst);
             break;
@@ -3560,6 +3564,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         }
         case GGML_OP_IM2COL:
         case GGML_OP_CONV_2D:
+        case GGML_OP_CONV_2D_IMPLICIT:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:

ggml/src/ggml.c
Lines changed: 2 additions & 2 deletions

@@ -1018,7 +1018,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1121,7 +1121,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
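
Bumping both literal asserts from 89 to 90 accounts for the new GGML_OP_CONV_2D_IMPLICIT entry in enum ggml_op (declared in ggml.h, outside this diff). Because GGML_OP_NAME and GGML_OP_SYMBOL are declared with GGML_OP_COUNT elements, the asserts force both string tables to be revisited whenever the enum grows. A minimal sketch of the invariant, with hypothetical names:

    // Hypothetical illustration of the enum/table pattern the asserts guard.
    enum my_op { MY_OP_NONE, MY_OP_CONV_2D, MY_OP_CONV_2D_IMPLICIT, MY_OP_COUNT };

    static const char * MY_OP_NAME[MY_OP_COUNT] = {
        "NONE",
        "CONV_2D",
        "CONV_2D_IMPLICIT", // every new op needs a matching table entry
    };

    static_assert(MY_OP_COUNT == 3, "MY_OP_COUNT != 3"); // bumped in lockstep with the enum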
