Skip to content

Commit 67bdc9d

Browse files
am17an authored and Minh141120 committed
CUDA: add conv_2d_transpose (ggml-org#14287)
* CUDA: add conv_2d_transpose
* remove direct include of cuda_fp16
* Review: add brackets for readability, remove ggml_set_param and add asserts
1 parent 7eb4e7f commit 67bdc9d

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
#include "ggml-cuda/concat.cuh"
1313
#include "ggml-cuda/conv-transpose-1d.cuh"
1414
#include "ggml-cuda/conv2d-dw.cuh"
15+
#include "ggml-cuda/conv2d-transpose.cuh"
1516
#include "ggml-cuda/convert.cuh"
1617
#include "ggml-cuda/count-equal.cuh"
1718
#include "ggml-cuda/cpy.cuh"
@@ -2381,6 +2382,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
23812382
case GGML_OP_CONV_2D_DW:
23822383
ggml_cuda_op_conv2d_dw(ctx, dst);
23832384
break;
2385+
case GGML_OP_CONV_TRANSPOSE_2D:
2386+
ggml_cuda_conv_2d_transpose_p0(ctx, dst);
2387+
break;
23842388
case GGML_OP_CONV_TRANSPOSE_1D:
23852389
ggml_cuda_op_conv_transpose_1d(ctx,dst);
23862390
break;
@@ -3307,6 +3311,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
33073311
}
33083312
case GGML_OP_IM2COL:
33093313
case GGML_OP_CONV_2D_DW:
3314+
case GGML_OP_CONV_TRANSPOSE_2D:
33103315
case GGML_OP_POOL_2D:
33113316
case GGML_OP_SUM:
33123317
case GGML_OP_SUM_ROWS:

tests/test-backend-ops.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2984,6 +2984,35 @@ struct test_conv_transpose_2d : public test_case {
29842984
}
29852985
};
29862986

2987+
// GGML_OP_CONV_TRANSPOSE_2D
2988+
struct test_conv_transpose_2d : public test_case {
2989+
const std::array<int64_t, 4> ne_input;
2990+
const std::array<int64_t, 4> ne_kernel;
2991+
const int stride;
2992+
2993+
std::string vars() override {
2994+
return VARS_TO_STR3(ne_input, ne_kernel, stride);
2995+
}
2996+
2997+
test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
2998+
std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
2999+
int stride = 1)
3000+
: ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}
3001+
3002+
ggml_tensor * build_graph(ggml_context * ctx) override {
3003+
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
3004+
ggml_set_name(input, "input");
3005+
3006+
ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
3007+
ggml_set_name(kernel, "kernel");
3008+
3009+
ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
3010+
ggml_set_name(out, "out");
3011+
3012+
return out;
3013+
}
3014+
};
3015+
29873016
// GGML_OP_IM2COL
29883017
struct test_im2col : public test_case {
29893018
const ggml_type type_input;
@@ -4938,8 +4967,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
49384967

49394968
test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
49404969

4941-
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
4942-
49434970
return test_cases;
49444971
}
49454972

0 commit comments

Comments (0)