Skip to content

Commit fc958ce

Browse files
Aclyggerganov
authored andcommitted
vulkan : implement ggml_roll (ggml/1290)
ggml-ci
1 parent f5e96b3 commit fc958ce

File tree

4 files changed

+154
-95
lines changed

4 files changed

+154
-95
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 79 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ struct vk_device_struct {
432432
vk_pipeline pipeline_cos_f32;
433433
vk_pipeline pipeline_clamp_f32;
434434
vk_pipeline pipeline_pad_f32;
435+
vk_pipeline pipeline_roll_f32;
435436
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
436437
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
437438
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
@@ -693,6 +694,37 @@ struct vk_op_unary_push_constants {
693694
};
694695
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
695696

697+
static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
698+
GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
699+
ne = ne != 0 ? ne : ggml_nelements(dst);
700+
GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
701+
702+
vk_op_unary_push_constants p{};
703+
p.ne = (uint32_t)ne;
704+
705+
size_t src0_tsize = ggml_type_size(src0->type);
706+
p.ne00 = (uint32_t)src0->ne[0];
707+
p.ne01 = (uint32_t)src0->ne[1];
708+
p.ne02 = (uint32_t)src0->ne[2];
709+
p.ne03 = (uint32_t)src0->ne[3];
710+
p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
711+
p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
712+
p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
713+
p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
714+
715+
size_t dst_tsize = ggml_type_size(dst->type);
716+
p.ne10 = (uint32_t)dst->ne[0];
717+
p.ne11 = (uint32_t)dst->ne[1];
718+
p.ne12 = (uint32_t)dst->ne[2];
719+
p.ne13 = (uint32_t)dst->ne[3];
720+
p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
721+
p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
722+
p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
723+
p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
724+
725+
return p; // fastdiv values and offsets are initialized later in ggml_vk_op
726+
}
727+
696728
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
697729
// Precompute mp (m' in the paper) and L such that division
698730
// can be computed using a multiply (high 32b of 64b result)
@@ -2802,6 +2834,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
28022834

28032835
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
28042836

2837+
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2838+
28052839
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
28062840
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
28072841

@@ -6502,6 +6536,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
65026536
return ctx->device->pipeline_pad_f32;
65036537
}
65046538
return nullptr;
6539+
case GGML_OP_ROLL:
6540+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6541+
return ctx->device->pipeline_roll_f32;
6542+
}
6543+
return nullptr;
65056544
case GGML_OP_REPEAT:
65066545
if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
65076546
return ctx->device->pipeline_repeat_f32;
@@ -7048,6 +7087,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
70487087
case GGML_OP_COS:
70497088
case GGML_OP_CLAMP:
70507089
case GGML_OP_PAD:
7090+
case GGML_OP_ROLL:
70517091
case GGML_OP_REPEAT:
70527092
case GGML_OP_REPEAT_BACK:
70537093
case GGML_OP_CPY:
@@ -7499,117 +7539,61 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
74997539
}
75007540

75017541
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7502-
float * op_params = (float *)dst->op_params;
7503-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7504-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7542+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7543+
p.param1 = ggml_get_op_params_f32(dst, 0);
7544+
p.param2 = ggml_get_op_params_f32(dst, 1);
75057545

7506-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
7507-
(uint32_t)ggml_nelements(src0),
7508-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7509-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7510-
0,
7511-
op_params[0], op_params[1],
7512-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7513-
}, dryrun);
7546+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
75147547
}
75157548

75167549
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7517-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7518-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7519-
7520-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
7521-
(uint32_t)ggml_nelements(src0),
7522-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7523-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7524-
0,
7525-
0.0f, 0.0f,
7526-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7527-
}, dryrun);
7550+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
75287551
}
75297552

75307553
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7531-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7532-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7533-
7534-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
7535-
(uint32_t)ggml_nelements(src0),
7536-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7537-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7538-
0,
7539-
0.0f, 0.0f,
7540-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7541-
}, dryrun);
7554+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
75427555
}
75437556

75447557
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7545-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7546-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7547-
7548-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
7549-
(uint32_t)ggml_nelements(src0),
7550-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7551-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7552-
0,
7553-
0.0f, 0.0f,
7554-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7555-
}, dryrun);
7558+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
75567559
}
75577560

75587561
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7559-
float * op_params = (float *)dst->op_params;
7560-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7561-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7562+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7563+
p.param1 = ggml_get_op_params_f32(dst, 0);
7564+
p.param2 = ggml_get_op_params_f32(dst, 1);
75627565

7563-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
7564-
(uint32_t)ggml_nelements(src0),
7565-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7566-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7567-
0,
7568-
op_params[0], op_params[1],
7569-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7570-
}, dryrun);
7566+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
75717567
}
75727568

75737569
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7574-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7575-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7570+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7571+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
7572+
}
75767573

7577-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
7578-
(uint32_t)ggml_nelements(dst),
7579-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7580-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7581-
0,
7582-
0.0f, 0.0f,
7583-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7584-
}, dryrun);
7574+
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7575+
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
7576+
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
7577+
const int32_t s2 = ggml_get_op_params_i32(dst, 2);
7578+
const int32_t s3 = ggml_get_op_params_i32(dst, 3);
7579+
const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
7580+
const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
7581+
7582+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7583+
memcpy(&p.param1, &s01_packed, sizeof(float));
7584+
memcpy(&p.param2, &s23_packed, sizeof(float));
7585+
7586+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
75857587
}
75867588

75877589
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7588-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7589-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7590-
7591-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
7592-
(uint32_t)ggml_nelements(dst),
7593-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7594-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7595-
0,
7596-
0.0f, 0.0f,
7597-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7598-
}, dryrun);
7590+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7591+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
75997592
}
76007593

76017594
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7602-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7603-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7604-
7605-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
7606-
(uint32_t)ggml_nelements(dst),
7607-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7608-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7609-
0,
7610-
0.0f, 0.0f,
7611-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7612-
}, dryrun);
7595+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7596+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
76137597
}
76147598

76157599
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -7627,14 +7611,8 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
76277611
}
76287612
}
76297613

7630-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
7631-
ne,
7632-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7633-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7634-
0,
7635-
0.0f, 0.0f,
7636-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7637-
}, dryrun);
7614+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
7615+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
76387616
}
76397617

76407618
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -8956,6 +8934,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
89568934
case GGML_OP_COS:
89578935
case GGML_OP_CLAMP:
89588936
case GGML_OP_PAD:
8937+
case GGML_OP_ROLL:
89598938
case GGML_OP_CPY:
89608939
case GGML_OP_CONT:
89618940
case GGML_OP_DUP:
@@ -9125,6 +9104,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
91259104
case GGML_OP_PAD:
91269105
ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
91279106

9107+
break;
9108+
case GGML_OP_ROLL:
9109+
ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
9110+
91289111
break;
91299112
case GGML_OP_CPY:
91309113
case GGML_OP_CONT:
@@ -9345,6 +9328,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
93459328
case GGML_OP_COS:
93469329
case GGML_OP_CLAMP:
93479330
case GGML_OP_PAD:
9331+
case GGML_OP_ROLL:
93489332
case GGML_OP_CPY:
93499333
case GGML_OP_CONT:
93509334
case GGML_OP_DUP:
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#version 450
2+
3+
#include "types.comp"
4+
#include "generic_unary_head.comp"
5+
6+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7+
8+
uint wrap_idx(int i, uint ne) {
9+
if (i < 0) {
10+
return i + ne;
11+
} else if (i >= ne) {
12+
return i - ne;
13+
}
14+
return i;
15+
}
16+
17+
void main() {
18+
const uint idx = get_idx();
19+
if (idx >= p.ne) {
20+
return;
21+
}
22+
23+
const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
24+
const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
25+
const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);
26+
const uint i2_offset = i2*p.ne11*p.ne10;
27+
const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);
28+
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
29+
30+
const uint p1 = floatBitsToUint(p.param1);
31+
const uint p2 = floatBitsToUint(p.param2);
32+
const int s0 = int(p1 >> 16) - 0x8000;
33+
const int s1 = int(p1 & 0xFFFF) - 0x8000;
34+
const int s2 = int(p2 >> 16) - 0x8000;
35+
const int s3 = int(p2 & 0xFFFF) - 0x8000;
36+
37+
const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
38+
const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
39+
const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
40+
const uint i03 = wrap_idx(int(i3) - s3, p.ne13);
41+
42+
const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
43+
const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;
44+
45+
data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
46+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,8 @@ void process_shaders() {
648648
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
649649
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
650650

651+
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
652+
651653
for (auto &c : compiles) {
652654
c.wait();
653655
}

0 commit comments

Comments
 (0)