Skip to content

Commit 6694ab6

Browse files
committed
* ggml-vulkan: adds op CONV_TRANSPOSE_1D
* test-backend-ops: adds more sophisticated tests for CONV_TRANSPOSE_1D
1 parent cdf94a1 commit 6694ab6

File tree

4 files changed

+185
-2
lines changed

4 files changed

+185
-2
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ struct vk_device_struct {
398398
vk_pipeline pipeline_count_equal_i32;
399399
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
400400
vk_pipeline pipeline_timestep_embedding_f32;
401+
vk_pipeline pipeline_conv_transpose_1d_f32;
401402
vk_pipeline pipeline_pool2d_f32;
402403
vk_pipeline pipeline_rwkv_wkv6_f32;
403404
vk_pipeline pipeline_rwkv_wkv7_f32;
@@ -706,6 +707,21 @@ struct vk_op_timestep_embedding_push_constants {
706707
uint32_t max_period;
707708
};
708709

710+
// Push constants for the CONV_TRANSPOSE_1D shader.
// Field order and types must match the `layout (push_constant)` block in
// conv_transpose_1d.comp exactly; values are filled in by
// ggml_vk_conv_transpose_1d().
struct vk_op_conv_transpose_1d_push_constants {
    uint32_t Cout; // number of output channels (src0->ne[1])
    uint32_t Cin;  // number of input channels  (src0->ne[2])
    uint32_t K;    // kernel width              (src0->ne[0])
    uint32_t L;    // input length              (src1->ne[0])
    uint32_t KL;   // output length             (dst->ne[0])

    // Strides in elements (byte strides divided by the element size).
    uint32_t nb01; // src0: stride between output channels
    uint32_t nb02; // src0: stride between input channels
    uint32_t nb11; // src1: stride between input channels
    uint32_t nb1;  // dst:  stride between output channels

    int32_t s0;    // convolution stride (dst->op_params[0])
};
724+
709725
struct vk_op_pool2d_push_constants {
710726
uint32_t IW; uint32_t IH;
711727
uint32_t OW; uint32_t OH;
@@ -2727,6 +2743,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
27272743

27282744
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
27292745

2746+
ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
2747+
27302748
ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
27312749

27322750
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
@@ -6391,6 +6409,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
63916409
return ctx->device->pipeline_timestep_embedding_f32;
63926410
}
63936411
return nullptr;
6412+
case GGML_OP_CONV_TRANSPOSE_1D:
6413+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6414+
return ctx->device->pipeline_conv_transpose_1d_f32;
6415+
}
6416+
return nullptr;
63946417
case GGML_OP_POOL_2D:
63956418
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
63966419
return ctx->device->pipeline_pool2d_f32;
@@ -6725,6 +6748,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
67256748
uint32_t half_ceil = (dim + 1) / 2;
67266749
elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
67276750
} break;
6751+
case GGML_OP_CONV_TRANSPOSE_1D:
6752+
{
6753+
elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
6754+
} break;
67286755
case GGML_OP_POOL_2D:
67296756
{
67306757
const uint32_t N = dst->ne[3];
@@ -7528,6 +7555,37 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
75287555
}, dryrun);
75297556
}
75307557

7558+
// Record a 1D transposed convolution on the Vulkan backend.
//
// src0: (K, Cout, Cin, 1) -- kernel
// src1: (L, Cin, 1, 1)    -- input
// dst:  (*, Cout, 1, 1)   -- output
//
// Packs the tensor geometry into push constants (strides converted from
// bytes to elements) and dispatches the conv_transpose_1d pipeline via
// ggml_vk_op_f32. When `dryrun` is set, only pipeline warm-up bookkeeping
// is performed.
static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    // Innermost dimension of both operands must be contiguous floats.
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));

    const int32_t stride = dst->op_params[0];

    vk_op_conv_transpose_1d_push_constants pc{};
    pc.Cout = static_cast<uint32_t>(src0->ne[1]);
    pc.Cin  = static_cast<uint32_t>(src0->ne[2]);
    pc.K    = static_cast<uint32_t>(src0->ne[0]);
    pc.L    = static_cast<uint32_t>(src1->ne[0]);
    pc.KL   = static_cast<uint32_t>(dst->ne[0]);
    // The shader indexes in elements, so divide the byte strides by the
    // element size (nb[0]).
    pc.nb01 = static_cast<uint32_t>(src0->nb[1] / src0->nb[0]);
    pc.nb02 = static_cast<uint32_t>(src0->nb[2] / src0->nb[0]);
    pc.nb11 = static_cast<uint32_t>(src1->nb[1] / src1->nb[0]);
    pc.nb1  = static_cast<uint32_t>(dst->nb[1]  / dst->nb[0]);
    pc.s0   = stride;

    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(pc), dryrun);
}
7588+
75317589
static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
75327590
uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
75337591
const int32_t k1 = dst->op_params[1];
@@ -8599,6 +8657,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
85998657
case GGML_OP_COUNT_EQUAL:
86008658
case GGML_OP_IM2COL:
86018659
case GGML_OP_TIMESTEP_EMBEDDING:
8660+
case GGML_OP_CONV_TRANSPOSE_1D:
86028661
case GGML_OP_POOL_2D:
86038662
case GGML_OP_CONV_2D_DW:
86048663
case GGML_OP_RWKV_WKV6:
@@ -8663,6 +8722,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
86638722
case GGML_OP_COUNT_EQUAL:
86648723
case GGML_OP_IM2COL:
86658724
case GGML_OP_TIMESTEP_EMBEDDING:
8725+
case GGML_OP_CONV_TRANSPOSE_1D:
86668726
case GGML_OP_POOL_2D:
86678727
case GGML_OP_CONV_2D_DW:
86688728
case GGML_OP_LEAKY_RELU:
@@ -8834,6 +8894,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
88348894
case GGML_OP_TIMESTEP_EMBEDDING:
88358895
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
88368896

8897+
break;
8898+
case GGML_OP_CONV_TRANSPOSE_1D:
8899+
ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun);
8900+
88378901
break;
88388902
case GGML_OP_POOL_2D:
88398903
ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
@@ -8962,6 +9026,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
89629026
case GGML_OP_COUNT_EQUAL:
89639027
case GGML_OP_IM2COL:
89649028
case GGML_OP_TIMESTEP_EMBEDDING:
9029+
case GGML_OP_CONV_TRANSPOSE_1D:
89659030
case GGML_OP_POOL_2D:
89669031
case GGML_OP_CONV_2D_DW:
89679032
case GGML_OP_RWKV_WKV6:
@@ -9964,6 +10029,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
996410029
case GGML_OP_COUNT_EQUAL:
996510030
case GGML_OP_IM2COL:
996610031
case GGML_OP_TIMESTEP_EMBEDDING:
10032+
case GGML_OP_CONV_TRANSPOSE_1D:
10033+
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
996710034
case GGML_OP_CONV_2D_DW:
996810035
case GGML_OP_POOL_2D:
996910036
case GGML_OP_RWKV_WKV6:
@@ -10462,6 +10529,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
1046210529
const int32_t dim = tensor->op_params[0];
1046310530
const int32_t max_period = tensor->op_params[1];
1046410531
tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
10532+
} else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
10533+
const int32_t s0 = tensor->op_params[0];
10534+
const int32_t p0 = tensor->op_params[1];
10535+
const int32_t d0 = tensor->op_params[2];
10536+
tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
1046510537
} else if (tensor->op == GGML_OP_POOL_2D) {
1046610538
enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
1046710539
const int32_t k0 = tensor->op_params[1];
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#version 450

#include "types.comp"

// 1D transposed convolution.
// One workgroup computes one output channel; threads within the workgroup
// process the input in blocks of bs positions, accumulating partial results
// in a shared-memory sliding window before flushing them to global memory.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]

layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;

layout (push_constant) uniform parameter {
    uint32_t Cout; // number of output channels
    uint32_t Cin;  // number of input channels
    uint32_t K;    // kernel width
    uint32_t L;    // input length
    uint32_t KL;   // output length

    // strides in elements
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb11;
    uint32_t nb1;

    int32_t s0;    // convolution stride
} p;

uint32_t Cout_idx = gl_WorkGroupID.x;       // output channel handled by this workgroup
const uint32_t bs = gl_WorkGroupSize.x;     // block size = threads per workgroup
uint32_t tid = gl_LocalInvocationID.x;
// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
uint32_t tmp_len = bs*p.s0+p.K;
// Sliding accumulation window shared by the workgroup.
// NOTE(review): assumes bs*s0 + K <= 1024 (holds for the dispatch config of
// bs=128 with the strides/kernels exercised by the tests) — TODO confirm no
// caller can exceed this.
shared D_TYPE tmp[1024];

// Number of bs-sized chunks needed to cover workSize items (ceiling divide).
uint splitWork(uint workSize){
    return (bs + workSize -1) / bs;
}

void main(){
    // Zero the shared accumulation window cooperatively.
    for(uint32_t i = 0; i < splitWork(tmp_len); i++){
        uint32_t idx = i*bs+tid;
        if(idx < tmp_len){
            tmp[idx] = 0.0;
        }
    }

    // Walk the input in blocks of bs positions.
    uint32_t L_blocks = splitWork(p.L);
    for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
        if(L_block_id > 0){
            // Shift values in tmp to the current processing window:
            // the first K elements carry over from the previous block,
            // the rest of the window is reset to zero.
            // NOTE(review): there is no barrier between the previous
            // iteration's global-memory flush (which reads tmp) and this
            // shift (which writes tmp) — looks like a potential hazard;
            // verify against the merged upstream shader.
            for(int i = 0; i < splitWork(tmp_len); i++){
                uint32_t idx = i*bs+tid;
                if(idx >= bs*p.s0 && idx < tmp_len){
                    tmp[idx-bs*p.s0] = tmp[idx];
                    tmp[idx] = 0.0;
                }else if(idx >= p.K && idx < bs*p.s0){
                    tmp[idx] = 0.0;
                }
            }
        }
        barrier();

        // Save contributions of the block to tmp.
        // For each kernel tap, every thread accumulates the dot product over
        // input channels at its input position; the barrier after each tap
        // keeps concurrent writes to overlapping window positions ordered.
        uint32_t L_idx = L_block_id*bs + tid;
        for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
            D_TYPE dp = 0.0;
            for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
                A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
                if(L_idx < p.L){
                    B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
                    dp = fma(elemKrn, elemInp, dp);
                }
            }
            tmp[tid*p.s0 + K_idx] += dp;
            barrier();
        }

        // Save the computed values except the last block that can have different size
        // (the last window, including the kernel tail, is flushed after the loop).
        uint32_t KLb_idx = L_block_id*bs*p.s0;
        if(L_block_id < L_blocks-1){
            for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
                uint32_t sh_idx = p.s0*tid+s0_idx;
                uint32_t KL_idx = KLb_idx+sh_idx;
                if(KL_idx < p.KL){
                    data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
                }
            }
        }
    }

    // Flush the final window (last block plus kernel overhang) to global memory.
    for(uint32_t i = 0; i < splitWork(tmp_len); i++){
        uint32_t idx = i*bs+tid;
        uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
        if(KL_idx < p.KL){
            data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
        }
    }
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,8 @@ void process_shaders() {
622622

623623
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
624624

625+
string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
626+
625627
string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
626628

627629
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

tests/test-backend-ops.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2706,8 +2706,8 @@ struct test_conv_transpose_1d : public test_case {
27062706
return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
27072707
}
27082708

2709-
test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1]
2710-
std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1]
2709+
test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_channels, 1 /* assert in cpu kernel*/, 1 (should be batch)]
2710+
std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, output_channels, input_channels, 1 (should be batch)]
27112711
int s0 = 1, int p0 = 0, int d0 = 1)
27122712
: ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {}
27132713

@@ -4029,6 +4029,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
40294029
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
40304030
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
40314031

4032+
for(uint32_t Cout : {1, 9}){
4033+
for(uint32_t Cin : {1, 7}){
4034+
for(uint32_t K : {1, 2, 3, 5, 6, 8, 9, 28}){
4035+
for(uint32_t L : {1, 2, 3, 5, 15, 16, 60, 100, 111, 127, 128, 157, 255, 376, 1024, 1173}){
4036+
for(uint32_t s0: {1, 2, 3}){
4037+
test_cases.emplace_back(new test_conv_transpose_1d({L,Cin,1,1}, {K,Cout,Cin,1}, s0, 0, 1));
4038+
}
4039+
}
4040+
}
4041+
}
4042+
}
4043+
40324044
test_cases.emplace_back(new test_conv_transpose_1d());
40334045
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
40344046
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));

0 commit comments

Comments
 (0)