@@ -803,6 +803,57 @@ static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_ten
803803 p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
804804 p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
805805
806+ return p; // offsets are initialized later in ggml_vk_op
807+ }
808+
809+ struct vk_op_pad_push_constants {
810+ uint32_t ne;
811+ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
812+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
813+ uint32_t misalign_offsets;
814+
815+ uint32_t lp0; uint32_t rp0;
816+ uint32_t lp1; uint32_t rp1;
817+ uint32_t lp2; uint32_t rp2;
818+ uint32_t lp3; uint32_t rp3;
819+ };
820+
821+ static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) {
822+ int64_t ne = ggml_nelements(dst);
823+ GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
824+
825+ vk_op_pad_push_constants p{};
826+ p.ne = (uint32_t)ne;
827+
828+ size_t src0_tsize = ggml_type_size(src0->type);
829+ p.ne00 = (uint32_t)src0->ne[0];
830+ p.ne01 = (uint32_t)src0->ne[1];
831+ p.ne02 = (uint32_t)src0->ne[2];
832+ p.ne03 = (uint32_t)src0->ne[3];
833+ p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
834+ p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
835+ p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
836+ p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
837+
838+ size_t dst_tsize = ggml_type_size(dst->type);
839+ p.ne10 = (uint32_t)dst->ne[0];
840+ p.ne11 = (uint32_t)dst->ne[1];
841+ p.ne12 = (uint32_t)dst->ne[2];
842+ p.ne13 = (uint32_t)dst->ne[3];
843+ p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
844+ p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
845+ p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
846+ p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
847+
848+ p.lp0 = dst->op_params[0];
849+ p.rp0 = dst->op_params[1];
850+ p.lp1 = dst->op_params[2];
851+ p.rp1 = dst->op_params[3];
852+ p.lp2 = dst->op_params[4];
853+ p.rp2 = dst->op_params[5];
854+ p.lp3 = dst->op_params[6];
855+ p.rp3 = dst->op_params[7];
856+
806857 return p; // fastdiv values and offsets are initialized later in ggml_vk_op
807858}
808859
@@ -3250,7 +3301,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
32503301
32513302 ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
32523303
3253- ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants ), {512, 1, 1}, {}, 1);
3304+ ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants ), {512, 1, 1}, {}, 1);
32543305
32553306 ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
32563307
@@ -7829,6 +7880,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
78297880 GGML_UNUSED(src2);
78307881}
78317882
7883+ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
7884+ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
7885+ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
7886+
7887+ p.misalign_offsets = (a_offset << 16) | d_offset;
7888+
7889+ GGML_UNUSED(src1);
7890+ GGML_UNUSED(src2);
7891+ }
7892+
78327893template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
78337894 const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
78347895 const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
@@ -8771,7 +8832,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con
87718832}
87728833
87738834static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
8774- vk_op_unary_push_constants p = vk_op_unary_push_constants_init (src0, dst, ggml_nelements(dst) );
8835+ vk_op_pad_push_constants p = vk_op_pad_push_constants_init (src0, dst);
87758836 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
87768837}
87778838
@@ -12076,10 +12137,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1207612137 case GGML_OP_ACC:
1207712138 case GGML_OP_CONCAT:
1207812139 case GGML_OP_SCALE:
12079- return true;
1208012140 case GGML_OP_PAD:
12081- return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
12082- (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
1208312141 case GGML_OP_ROLL:
1208412142 case GGML_OP_DIAG_MASK_INF:
1208512143 case GGML_OP_SOFT_MAX:
@@ -12520,7 +12578,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1252012578 const float * params = (const float *)tensor->op_params;
1252112579 tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
1252212580 } else if (tensor->op == GGML_OP_PAD) {
12523- tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
12581+ tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3],
12582+ tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
1252412583 } else if (tensor->op == GGML_OP_REPEAT) {
1252512584 tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
1252612585 } else if (tensor->op == GGML_OP_REPEAT_BACK) {
0 commit comments