ggml/src/ggml-vulkan/ggml-vulkan.cpp (65 additions, 6 deletions)

@@ -803,6 +803,57 @@ static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_ten
    p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
    p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);

    return p; // offsets are initialized later in ggml_vk_op
}

+struct vk_op_pad_push_constants {
+    uint32_t ne;
+    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
+    uint32_t misalign_offsets;
+
+    uint32_t lp0; uint32_t rp0;
+    uint32_t lp1; uint32_t rp1;
+    uint32_t lp2; uint32_t rp2;
+    uint32_t lp3; uint32_t rp3;
+};
+
+static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) {
+    int64_t ne = ggml_nelements(dst);
+    GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
+
+    vk_op_pad_push_constants p{};
+    p.ne = (uint32_t)ne;
+
+    size_t src0_tsize = ggml_type_size(src0->type);
+    p.ne00 = (uint32_t)src0->ne[0];
+    p.ne01 = (uint32_t)src0->ne[1];
+    p.ne02 = (uint32_t)src0->ne[2];
+    p.ne03 = (uint32_t)src0->ne[3];
+    p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
+    p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
+    p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
+    p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
+
+    size_t dst_tsize = ggml_type_size(dst->type);
+    p.ne10 = (uint32_t)dst->ne[0];
+    p.ne11 = (uint32_t)dst->ne[1];
+    p.ne12 = (uint32_t)dst->ne[2];
+    p.ne13 = (uint32_t)dst->ne[3];
+    p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
+    p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
+    p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
+    p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
+
+    p.lp0 = dst->op_params[0];
+    p.rp0 = dst->op_params[1];
+    p.lp1 = dst->op_params[2];
+    p.rp1 = dst->op_params[3];
+    p.lp2 = dst->op_params[4];
+    p.rp2 = dst->op_params[5];
+    p.lp3 = dst->op_params[6];
+    p.rp3 = dst->op_params[7];
+
+    return p; // fastdiv values and offsets are initialized later in ggml_vk_op
+}
+
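Note: the `lp*`/`rp*` fields are read straight from `dst->op_params`, where `ggml_pad_ext` stores the left/right pad for each of the four dimensions, so `dst->ne[d] == lp_d + src0->ne[d] + rp_d` holds per dimension. A hypothetical sanity check (not part of the PR) spelling out that invariant:

```cpp
static void check_pad_shape(const ggml_tensor * src0, const ggml_tensor * dst) {
    // ggml_pad_ext stores the pads as op_params[0..7] = lp0, rp0, lp1, rp1, ...
    for (int d = 0; d < 4; ++d) {
        const int32_t lp = dst->op_params[2*d + 0];
        const int32_t rp = dst->op_params[2*d + 1];
        GGML_ASSERT(dst->ne[d] == src0->ne[d] + lp + rp);
    }
}
```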
@@ -3250,7 +3301,7 @@ static void ggml_vk_load_shaders(vk_device& device) {

    ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

-    ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

@@ -7829,6 +7880,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
    GGML_UNUSED(src2);
}

+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
+    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
+
+    p.misalign_offsets = (a_offset << 16) | d_offset;
+
+    GGML_UNUSED(src1);
+    GGML_UNUSED(src2);
+}
+
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
    const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
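Note: the pad specialization above packs both element offsets into a single 32-bit push constant, high half for the source, low half for the destination, so each offset has to fit in 16 bits. A standalone sketch (illustration only, not part of the PR) showing that the shader-side getters in pad.comp invert the packing exactly:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const uint32_t a_offset = 12;  // src element offset caused by misalignment
    const uint32_t d_offset = 34;  // dst element offset caused by misalignment

    // host side: init_pushconst_tensor_offsets
    const uint32_t misalign_offsets = (a_offset << 16) | d_offset;

    // shader side: get_aoffset() / get_doffset() recover the two halves
    assert((misalign_offsets >> 16)     == a_offset);
    assert((misalign_offsets & 0xFFFFu) == d_offset);
    return 0;
}
```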
@@ -8771,7 +8832,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con
}

static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
+    vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst);
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
}

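Note: `ggml_vk_op_f32` is templated on the push-constant type, so handing it a `vk_op_pad_push_constants` selects the `init_pushconst_tensor_offsets` specialization added above. A rough sketch of that flow, under the assumption that only the offset patching and the dispatch matter here (the real function also handles buffers, synchronization and dry runs):

```cpp
// Editor's sketch, not the actual ggml_vk_op_f32 implementation.
template <typename PC>
void vk_op_sketch(ggml_backend_vk_context * ctx, PC p,
                  const ggml_tensor * src0, ggml_tensor * dst) {
    // fills p.misalign_offsets via the vk_op_pad_push_constants overload
    init_pushconst_tensor_offsets(ctx, p, src0, nullptr, nullptr, dst);
    // ... bind pipeline_pad_f32, upload p as push constants, vkCmdDispatch ...
}
```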
@@ -12076,10 +12137,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_ACC:
        case GGML_OP_CONCAT:
        case GGML_OP_SCALE:
-            return true;
        case GGML_OP_PAD:
-            return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
-                   (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
        case GGML_OP_ROLL:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:
@@ -12520,7 +12578,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
        const float * params = (const float *)tensor->op_params;
        tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
    } else if (tensor->op == GGML_OP_PAD) {
-        tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
+        tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3],
+                                    tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
    } else if (tensor->op == GGML_OP_REPEAT) {
        tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
    } else if (tensor->op == GGML_OP_REPEAT_BACK) {
ggml/src/ggml-vulkan/vulkan-shaders/pad.comp (24 additions, 3 deletions)

@@ -1,7 +1,25 @@
#version 450

#include "types.comp"
#include "generic_unary_head.comp"

+layout (push_constant) uniform parameter
+{
+    uint ne;
+    uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
+    uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
+    uint misalign_offsets;
+
+    uint lp0; uint rp0;
+    uint lp1; uint rp1;
+    uint lp2; uint rp2;
+    uint lp3; uint rp3;
+} p;
+
+uint get_aoffset() { return p.misalign_offsets >> 16; }
+uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

@@ -19,10 +37,13 @@ void main() {
    const uint i1 = (idx - i3_offset - i2_offset) / p.ne10;
    const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;

-    const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
+    const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;
    const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;

-    const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
+    const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
+                         i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
+                         i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
+                         i3 >= p.lp3 && i3 < p.ne13 - p.rp3;

    data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
}
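Note: the subtractions in `src0_idx` wrap around under unsigned arithmetic whenever a coordinate lies in the left padding (`i0 < p.lp0` and so on), but `is_src0` is false in exactly those cases, so the wrapped index is never used to read `data_a`. The same per-element logic transcribed to C++ (editor's sketch, reusing `vk_op_pad_push_constants` from the host code above; assumed equivalent to the shader):

```cpp
static float pad_element(const float * src,
                         uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3,
                         const vk_op_pad_push_constants & p) {
    // inside the un-padded region: coordinate is past the left pad and
    // before (output extent - right pad) in every dimension
    const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
                         i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
                         i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
                         i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
    const uint32_t src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 +
                              (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;
    return is_src0 ? src[src0_idx] : 0.0f;
}
```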
tests/test-backend-ops.cpp (15 additions, 3 deletions)

@@ -4703,21 +4703,28 @@ struct test_pad_ext : public test_case {
    const int rp2;
    const int lp3;
    const int rp3;
+    const bool v;

    std::string vars() override {
-        return VARS_TO_STR10(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+        return VARS_TO_STR11(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v);
    }

    test_pad_ext(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_a = {512, 512, 3, 1},
            int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,
-            int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1)
-        : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3) {}
+            int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1,
+            bool v = false)
+        : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3), v(v) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
        ggml_set_name(a, "a");

+        if (v) {
+            a = ggml_view_4d(ctx, a, (a->ne[0] + 1) / 2, (a->ne[1] + 1) / 2, (a->ne[2] + 1) / 2, (a->ne[3] + 1) / 2, a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view of a");
+        }
+
        ggml_tensor * out = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
        ggml_set_name(out, "out");

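Note: with `v == true` the test pads a non-contiguous view rather than a freshly allocated tensor: every extent is halved (rounding up) while the parent's strides `a->nb[1..3]` are kept, which exercises the misalign-offset path in the Vulkan backend. A hypothetical standalone equivalent of that graph:

```cpp
ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 11, 22, 33, 44);
// half-extent view with the parent's strides and zero byte offset
ggml_tensor * view = ggml_view_4d(ctx, a,
                                  (a->ne[0] + 1) / 2, (a->ne[1] + 1) / 2,
                                  (a->ne[2] + 1) / 2, (a->ne[3] + 1) / 2,
                                  a->nb[1], a->nb[2], a->nb[3], 0);
ggml_tensor * out = ggml_pad_ext(ctx, view, 1, 2, 3, 4, 5, 6, 7, 8);
```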
@@ -6458,6 +6465,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_timestep_embedding());
    test_cases.emplace_back(new test_leaky_relu());

+    for (bool v : {false, true}) {
+        test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
+        test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
+    }
+
    for (int hsk : { 40, 64, 80, 128, 192, 256, 576 }) {
        for (int hsv : { 40, 64, 80, 128, 192, 256, 512 }) {
            if (hsk != 192 && hsk != 576 && hsk != hsv) continue;