Skip to content

Commit f98cb66

Browse files
Acly authored and ggerganov committed
vulkan : implement ggml_roll (ggml/1290)
* vulkan : implement ggml_roll * vulkan : refactor vk_op_unary_push_constants initialization
1 parent 5ea5c58 commit f98cb66

File tree

3 files changed

+127
-95
lines changed

3 files changed

+127
-95
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 79 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ struct vk_device_struct {
417417
vk_pipeline pipeline_cos_f32;
418418
vk_pipeline pipeline_clamp_f32;
419419
vk_pipeline pipeline_pad_f32;
420+
vk_pipeline pipeline_roll_f32;
420421
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
421422
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
422423
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
@@ -687,6 +688,37 @@ struct vk_op_unary_push_constants {
687688
};
688689
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
689690

691+
// Builds the push-constant block for a unary-style op from a source/destination tensor pair.
//
//   src0, dst : tensors whose extents and strides are copied into the block
//   ne        : element count to store in p.ne; 0 (the default) means "use
//               ggml_nelements(dst)", which is only accepted when src0 and dst
//               contain the same number of elements
//
// Strides are stored in units of the respective tensor's type size rather than in bytes.
// All remaining members keep their value-initialized state.
static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
    GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
    if (ne == 0) {
        ne = ggml_nelements(dst);
    }
    // the shader-side element count is a 32-bit unsigned integer
    GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());

    const size_t src_elem_size = ggml_type_size(src0->type);
    const size_t dst_elem_size = ggml_type_size(dst->type);

    vk_op_unary_push_constants p{};

    p.ne = (uint32_t)ne;

    // source extents and strides
    p.ne00 = (uint32_t)src0->ne[0];
    p.ne01 = (uint32_t)src0->ne[1];
    p.ne02 = (uint32_t)src0->ne[2];
    p.ne03 = (uint32_t)src0->ne[3];
    p.nb00 = (uint32_t)(src0->nb[0] / src_elem_size);
    p.nb01 = (uint32_t)(src0->nb[1] / src_elem_size);
    p.nb02 = (uint32_t)(src0->nb[2] / src_elem_size);
    p.nb03 = (uint32_t)(src0->nb[3] / src_elem_size);

    // destination extents and strides
    p.ne10 = (uint32_t)dst->ne[0];
    p.ne11 = (uint32_t)dst->ne[1];
    p.ne12 = (uint32_t)dst->ne[2];
    p.ne13 = (uint32_t)dst->ne[3];
    p.nb10 = (uint32_t)(dst->nb[0] / dst_elem_size);
    p.nb11 = (uint32_t)(dst->nb[1] / dst_elem_size);
    p.nb12 = (uint32_t)(dst->nb[2] / dst_elem_size);
    p.nb13 = (uint32_t)(dst->nb[3] / dst_elem_size);

    // fastdiv values and offsets are initialized later in ggml_vk_op
    return p;
}
721+
690722
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
691723
// Precompute mp (m' in the paper) and L such that division
692724
// can be computed using a multiply (high 32b of 64b result)
@@ -2753,6 +2785,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
27532785

27542786
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
27552787

2788+
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2789+
27562790
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
27572791
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
27582792

@@ -6425,6 +6459,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
64256459
return ctx->device->pipeline_pad_f32;
64266460
}
64276461
return nullptr;
6462+
case GGML_OP_ROLL:
6463+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6464+
return ctx->device->pipeline_roll_f32;
6465+
}
6466+
return nullptr;
64286467
case GGML_OP_REPEAT:
64296468
if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
64306469
return ctx->device->pipeline_repeat_f32;
@@ -6965,6 +7004,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
69657004
case GGML_OP_COS:
69667005
case GGML_OP_CLAMP:
69677006
case GGML_OP_PAD:
7007+
case GGML_OP_ROLL:
69687008
case GGML_OP_REPEAT:
69697009
case GGML_OP_REPEAT_BACK:
69707010
case GGML_OP_CPY:
@@ -7416,117 +7456,60 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
74167456
}
74177457

74187458
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7419-
float * op_params = (float *)dst->op_params;
7420-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7421-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7459+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7460+
p.param1 = ggml_get_op_params_f32(dst, 0);
74227461

7423-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
7424-
(uint32_t)ggml_nelements(src0),
7425-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7426-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7427-
0,
7428-
op_params[0], 0.0f,
7429-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7430-
}, dryrun);
7462+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
74317463
}
74327464

74337465
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7434-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7435-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7436-
7437-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
7438-
(uint32_t)ggml_nelements(src0),
7439-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7440-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7441-
0,
7442-
0.0f, 0.0f,
7443-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7444-
}, dryrun);
7466+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
74457467
}
74467468

74477469
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7448-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7449-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7450-
7451-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
7452-
(uint32_t)ggml_nelements(src0),
7453-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7454-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7455-
0,
7456-
0.0f, 0.0f,
7457-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7458-
}, dryrun);
7470+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
74597471
}
74607472

74617473
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7462-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7463-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7464-
7465-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
7466-
(uint32_t)ggml_nelements(src0),
7467-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7468-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7469-
0,
7470-
0.0f, 0.0f,
7471-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7472-
}, dryrun);
7474+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
74737475
}
74747476

74757477
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7476-
float * op_params = (float *)dst->op_params;
7477-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7478-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7478+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7479+
p.param1 = ggml_get_op_params_f32(dst, 0);
7480+
p.param2 = ggml_get_op_params_f32(dst, 1);
74797481

7480-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
7481-
(uint32_t)ggml_nelements(src0),
7482-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7483-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7484-
0,
7485-
op_params[0], op_params[1],
7486-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7487-
}, dryrun);
7482+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
74887483
}
74897484

74907485
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7491-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7492-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7486+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7487+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
7488+
}
74937489

7494-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
7495-
(uint32_t)ggml_nelements(dst),
7496-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7497-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7498-
0,
7499-
0.0f, 0.0f,
7500-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7501-
}, dryrun);
7490+
// Dispatches GGML_OP_ROLL for src0 -> dst.
//
// The four per-dimension shifts are read from dst's op params (indices 0..3). The unary
// push-constant block has no integer slots for them, so they are smuggled through the two
// float members: every shift is biased by 0x8000 into an unsigned 16-bit value, two such
// values are packed into one 32-bit word (dims 0/1 -> param1, dims 2/3 -> param2), and the
// raw bits are copied into the floats. The roll shader undoes this with floatBitsToUint.
//
// dryrun is forwarded unchanged to ggml_vk_op_f32.
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    static_assert(sizeof(float) == sizeof(uint32_t), "shift words are bit-copied into float push constants");

    uint32_t biased[4];
    for (int i = 0; i < 4; ++i) {
        const int32_t shift = ggml_get_op_params_i32(dst, i);
        // A shift outside this range does not fit its 16-bit field: the packed word would
        // overflow into the neighbouring shift and the shader would decode wrong offsets.
        GGML_ASSERT(shift >= -0x8000 && shift <= 0x7FFF);
        biased[i] = (uint32_t)(shift + 0x8000); // now in [0, 0xFFFF]
    }

    // pack in unsigned arithmetic so the high half-word never touches a sign bit
    const uint32_t s01_packed = (biased[0] << 16) | biased[1];
    const uint32_t s23_packed = (biased[2] << 16) | biased[3];

    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
    memcpy(&p.param1, &s01_packed, sizeof(float));
    memcpy(&p.param2, &s23_packed, sizeof(float));

    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
}
75037504

75047505
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7505-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7506-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7507-
7508-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
7509-
(uint32_t)ggml_nelements(dst),
7510-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7511-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7512-
0,
7513-
0.0f, 0.0f,
7514-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7515-
}, dryrun);
7506+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7507+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
75167508
}
75177509

75187510
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7519-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7520-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7521-
7522-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
7523-
(uint32_t)ggml_nelements(dst),
7524-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7525-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7526-
0,
7527-
0.0f, 0.0f,
7528-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7529-
}, dryrun);
7511+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7512+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
75307513
}
75317514

75327515
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -7544,14 +7527,8 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
75447527
}
75457528
}
75467529

7547-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
7548-
ne,
7549-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7550-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7551-
0,
7552-
0.0f, 0.0f,
7553-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7554-
}, dryrun);
7530+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
7531+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
75557532
}
75567533

75577534
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -8862,6 +8839,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
88628839
case GGML_OP_COS:
88638840
case GGML_OP_CLAMP:
88648841
case GGML_OP_PAD:
8842+
case GGML_OP_ROLL:
88658843
case GGML_OP_CPY:
88668844
case GGML_OP_CONT:
88678845
case GGML_OP_DUP:
@@ -9031,6 +9009,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
90319009
case GGML_OP_PAD:
90329010
ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
90339011

9012+
break;
9013+
case GGML_OP_ROLL:
9014+
ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
9015+
90349016
break;
90359017
case GGML_OP_CPY:
90369018
case GGML_OP_CONT:
@@ -9247,6 +9229,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
92479229
case GGML_OP_COS:
92489230
case GGML_OP_CLAMP:
92499231
case GGML_OP_PAD:
9232+
case GGML_OP_ROLL:
92509233
case GGML_OP_CPY:
92519234
case GGML_OP_CONT:
92529235
case GGML_OP_DUP:
@@ -10368,6 +10351,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1036810351
case GGML_OP_CONCAT:
1036910352
case GGML_OP_SCALE:
1037010353
case GGML_OP_PAD:
10354+
case GGML_OP_ROLL:
1037110355
case GGML_OP_DIAG_MASK_INF:
1037210356
case GGML_OP_SOFT_MAX:
1037310357
case GGML_OP_SOFT_MAX_BACK:
ggml/src/ggml-vulkan/vulkan-shaders/roll.comp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#version 450
2+
3+
#include "types.comp"
4+
#include "generic_unary_head.comp"
5+
6+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7+
8+
// Folds an index that lies at most one period outside [0, ne) back into that range:
// negative values get ne added, values >= ne get ne subtracted, anything else is
// returned unchanged.
uint wrap_idx(int i, uint ne) {
    if (i < 0) {
        return uint(i) + ne;
    }
    const uint u = uint(i);
    if (u >= ne) {
        return u - ne;
    }
    return u;
}
16+
17+
// One invocation per destination element: split the flat index into 4D destination
// coordinates, subtract the per-dimension shift (with wrap-around) to find the source
// coordinates, and copy that source element to the destination.
void main() {
    const uint idx = get_idx();
    if (idx >= p.ne) {
        return;
    }

    // Peel off the destination coordinates from the outermost dimension inwards,
    // keeping the not-yet-consumed part of the flat index in `rem`.
    uint rem = idx;
    const uint i3 = fastdiv(rem, p.ne1_012mp, p.ne1_012L);
    rem -= i3 * p.ne12*p.ne11*p.ne10;
    const uint i2 = fastdiv(rem, p.ne1_01mp, p.ne1_01L);
    rem -= i2*p.ne11*p.ne10;
    const uint i1 = fastdiv(rem, p.ne1_0mp, p.ne1_0L);
    const uint i0 = rem - i1*p.ne10;

    // The shifts arrive as two 32-bit words hidden in the float params; each half-word
    // holds one shift biased by 0x8000.
    const uint w01 = floatBitsToUint(p.param1);
    const uint w23 = floatBitsToUint(p.param2);
    const int s0 = int(w01 >> 16) - 0x8000;
    const int s1 = int(w01 & 0xFFFFu) - 0x8000;
    const int s2 = int(w23 >> 16) - 0x8000;
    const int s3 = int(w23 & 0xFFFFu) - 0x8000;

    // source coordinates: destination coordinate minus shift, wrapped into range
    const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
    const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
    const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
    const uint i03 = wrap_idx(int(i3) - s3, p.ne13);

    const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
    const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;

    data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,8 @@ void process_shaders() {
642642
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
643643
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
644644

645+
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
646+
645647
for (auto &c : compiles) {
646648
c.wait();
647649
}

0 commit comments

Comments
 (0)