
Commit 3877608

fix passing param as reference
1 parent 4d77287 commit 3877608
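
The change itself is small but important: the __global__ kernel took its param_t by const reference, so the device code ended up dereferencing what is effectively a host-side address and read garbage (hence the commented-out debug printfs of the param fields below). Passing the struct by value instead copies it into the kernel's argument space at launch, where every field is directly readable on the GPU. A minimal sketch of the pattern, using a hypothetical params_t and kernel rather than code from this commit:

    #include <cstdio>
    #include <cuda_runtime.h>

    struct params_t { int n, c, h, w; };

    // Passing by value: the whole struct is copied into kernel argument
    // memory at launch, so the device reads valid data.
    __global__ void print_params(const params_t p) {
        if (threadIdx.x == 0 && blockIdx.x == 0) {
            printf("n=%d c=%d h=%d w=%d\n", p.n, p.c, p.h, p.w);
        }
    }

    // By contrast, a `const params_t &p` parameter would bind to the host
    // object created in main(); dereferencing that host address from device
    // code is invalid.

    int main() {
        params_t p = { 1, 16, 64, 64 };   // lives in host memory
        print_params<<<1, 32>>>(p);       // struct copied by value to the device
        cudaDeviceSynchronize();
        return 0;
    }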

File tree (3 files changed, 118 insertions(+), 8 deletions(-)):
  ggml/src/ggml-cuda/conv2d-implicit.cu
  ggml/src/ggml.c
  tests/test-backend-ops.cpp


ggml/src/ggml-cuda/conv2d-implicit.cu

Lines changed: 17 additions & 8 deletions
@@ -25,16 +25,22 @@ template <typename T>
 static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
                                               const T * __restrict__ kernel,
                                               float * __restrict__ output,
-                                              const param_t &param) {
+                                              const param_t param) {

-    extern __shared__ __align__(16 * 1024) char smem[];
+    extern __shared__ __align__(16 * 1024) char smem[];
     T *smemweight = reinterpret_cast<T *>(smem);
     float *smeminput = reinterpret_cast<float *>(smem + 16 * 1024);

     int tx = threadIdx.x;
     int bx = blockIdx.x;
     int by = blockIdx.y;

+    // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
+    //     printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
+    //     // printf("param.n=%d\n",param.n);
+    // }
+    // __syncthreads();
+
     // Warp tile
     const int lane_id = threadIdx.x % 32;
     const int warp_id = threadIdx.x / 32;
@@ -85,6 +91,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
         }
     }
     // ldg
+    // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
+    //     printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
+    // }
+    // __syncthreads();
 #pragma unroll
     for (int i = 0; i < 4; ++i)
     {
@@ -282,11 +292,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
             }
         }
     }
-
 }

 template <typename T>
-static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t &P, cudaStream_t st) {
+static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
     // const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
     int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number
     int blocky = (P.k + 127) / 128;           // blocky number
@@ -300,11 +309,11 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
     conv2d_implicit_kernel<T><<<grid, thblock, smem_size, st>>>(X_D, K_D, Y_D, P);
 }

-static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t &P, cudaStream_t st) {
+static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) {
     conv2d_implicit_cuda<half>(X_D, K_D, Y_D, P, st);
 }

-static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t &P, cudaStream_t st) {
+static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) {
     conv2d_implicit_cuda<float>(X_D, K_D, Y_D, P, st);
 }

@@ -343,9 +352,9 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
     const int IC = input->ne[2];  // input_channels
     const int OC = kernel->ne[3]; // output_channels
     const int B  = input->ne[3];  // n_batches
-
+
     const int64_t total = B * OC * OH * OW;
-    // param_t params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+
     param_t params = { B, IC, IH, IW, OC, KH, KW, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, OH, OW };

     if (kernel->type == GGML_TYPE_F16) {
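
The definition of param_t is not part of this diff, but its presumed field order can be read off the aggregate initializer above together with the field names in the kernel's commented-out debug printf. A sketch of the assumed layout (names and order inferred; the real definition in the CUDA sources may differ):

    struct param_t {
        int n, c, h, w;   // initialized with B, IC, IH, IW
        int k, r, s;      // initialized with OC, KH, KW
        int u, v;         // initialized with ST_X, ST_Y
        int p, q;         // initialized with PD_X, PD_Y
        int d_h, d_w;     // initialized with DL_X, DL_Y
        int Oh, Ow;       // initialized with OH, OW
    };

Being a flat struct of ints, it is cheap to pass by value, which is exactly what the signature changes above rely on.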

ggml/src/ggml.c

Lines changed: 2 additions & 0 deletions
@@ -975,6 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "IM2COL",
     "IM2COL_BACK",
     "CONV_2D",
+    "CONV_2D_IMPLICIT",
     "CONV_3D",
     "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
@@ -1078,6 +1079,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "im2col(x)",
     "im2col_back(x)",
     "conv_2d(x)",
+    "conv_2d_implicit(x)",
     "conv_3d(x)",
     "conv_2d_dw(x)",
     "conv_transpose_2d(x)",

tests/test-backend-ops.cpp

Lines changed: 99 additions & 0 deletions
@@ -4116,6 +4116,94 @@ struct test_conv_2d : public test_case {
     }
 };

+// CONV_2D_IMPLICIT
+struct test_conv_2d_implicit : public test_case {
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+    const ggml_type type_kernel;
+    const int stride0;
+    const int stride1;
+    const int padding0;
+    const int padding1;
+    const int dilation0;
+    const int dilation1;
+    // Whether the inputs are contiguous in the channel dim or the width dim
+    const bool cwhn;
+
+    std::string vars() override {
+        return VARS_TO_STR10(ne_input, ne_kernel, type_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        // Just counting matmul costs:
+        // KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops
+
+        // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
+        auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+            return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+        };
+
+        int64_t W    = ne_input[0];
+        int64_t H    = ne_input[1];
+        int64_t KW   = ne_kernel[0];
+        int64_t KH   = ne_kernel[1];
+        int64_t Cin  = ne_kernel[2];
+        int64_t Cout = ne_kernel[3];
+        int64_t N    = ne_input[3];
+        int64_t OH   = calc_conv_output_size(H, KH, stride1, padding1, dilation1);
+        int64_t OW   = calc_conv_output_size(W, KW, stride0, padding0, dilation0);
+
+        int64_t K   = Cout;
+        int64_t CRS = Cin * KH * KW;
+        int64_t NPQ = N * OH * OW;
+
+        return K * NPQ * (2 * CRS - 1);
+    }
+
+    test_conv_2d_implicit(std::array<int64_t, 4> ne_input = { 64, 64, 16, 1 },
+        std::array<int64_t, 4> ne_kernel = { 3, 3, 1, 16 }, ggml_type type_kernel = GGML_TYPE_F32, int stride0 = 1,
+        int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :
+        ne_input(ne_input),
+        ne_kernel(ne_kernel),
+        type_kernel(type_kernel),
+        stride0(stride0),
+        stride1(stride1),
+        padding0(padding0),
+        padding1(padding1),
+        dilation0(dilation0),
+        dilation1(dilation1),
+        cwhn(cwhn) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
+        if (cwhn) {
+            // change memory layout to channel-most-contiguous (CWHN),
+            // then permute it back so NE matches the original input
+            input  = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
+            input  = ggml_permute(ctx, input, 2, 0, 1, 3);
+            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
+            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
+        }
+
+        ggml_tensor * out =
+            ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1);
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_CONV_2D_DW
 struct test_conv_2d_dw : public test_case {
     const std::array<int64_t, 4> ne_input;
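
The op_flops accounting above counts only the implicit GEMM: a K x CRS weight matrix multiplied by a CRS x NPQ activation matrix. As a worked example (illustrative numbers, not the test defaults): for a 64x64 input with 16 channels, batch 1, and a 3x3 kernel producing 16 output channels with stride 1, no padding and no dilation, OH = OW = (64 + 0 - 1*(3-1) - 1)/1 + 1 = 62, so CRS = 16*3*3 = 144, NPQ = 1*62*62 = 3844 and K = 16, giving 16 * 3844 * (2*144 - 1) = 17,651,648, roughly 17.7 MFLOP per convolution.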
@@ -6454,6 +6542,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }

+    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (auto act_case : cases) {
+            // Implicit-GEMM CONV_2D
+            test_cases.emplace_back(new test_conv_2d_implicit(
+                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
+                kernel_type, 1, 1, 0, 0, 1, 1, false));
+        }
+    }
+
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
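
The loop above registers the implicit-GEMM cases in make_test_cases_perf(), so they only appear in the benchmark suite with square shapes, stride 1 and no padding. A hypothetical extra registration (not part of this commit) shows how the same struct could also be used for a correctness case with a less trivial configuration:

    // Hypothetical: correctness-style case for the implicit conv,
    // F16 kernel, stride 2, padding 1.
    test_cases.emplace_back(new test_conv_2d_implicit(
        { 19, 19, 8, 1 },   // input  W, H, C, N
        { 3, 3, 8, 12 },    // kernel KW, KH, Cin, Cout
        GGML_TYPE_F16, 2, 2, 1, 1, 1, 1, false));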
