Skip to content

Commit ec0b188

Browse files
authored
vulkan: Support ne[3]>1 in noncontig matrix-vector multiply (#15015)
1 parent 339bd02 commit ec0b188

File tree

3 files changed

+30
-19
lines changed

3 files changed

+30
-19
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2885,7 +2885,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
28852885
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
28862886
}
28872887
}
2888-
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
2888+
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 12 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
28892889

28902890
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
28912891
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -5821,7 +5821,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
58215821
const uint64_t ne00 = src0->ne[0];
58225822
const uint64_t ne01 = src0->ne[1];
58235823
const uint64_t ne02 = src0->ne[2];
5824-
// const uint64_t ne03 = src0->ne[3];
5824+
const uint64_t ne03 = src0->ne[3];
58255825

58265826
const uint64_t nb01 = src0->nb[1];
58275827
const uint64_t nb02 = src0->nb[2];
@@ -5833,7 +5833,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
58335833
const uint64_t ne12 = src1->ne[2];
58345834
// const uint64_t ne13 = src1->ne[3];
58355835

5836+
const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t));
5837+
const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float));
5838+
const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float));
5839+
58365840
GGML_ASSERT(ne11 == 1);
5841+
GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op
58375842

58385843
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
58395844
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
@@ -5849,7 +5854,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
58495854
src1_uma = d_Qy != nullptr;
58505855
}
58515856

5852-
const uint64_t d_ne = ne01 * ne11 * ne12;
5857+
const uint64_t d_ne = ne01 * ne11 * ne12 * ne03;
58535858

58545859
const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
58555860
const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
@@ -5884,10 +5889,10 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
58845889
const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
58855890

58865891
// compute
5887-
const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
5892+
const std::array<uint32_t, 12> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 };
58885893
ggml_vk_sync_buffers(subctx);
58895894
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
5890-
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
5895+
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
58915896
}
58925897

58935898
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ layout (push_constant) uniform parameter
2626
uint ne12;
2727
uint b_offset;
2828
uint d_offset;
29+
uint nb03;
30+
uint nb13;
31+
uint nb23;
2932
} p;
3033

3134
shared FLOAT_TYPE tmp[BLOCK_SIZE];
@@ -34,14 +37,15 @@ void main() {
3437
const uint tid = gl_LocalInvocationID.x;
3538
const uint row_x = gl_GlobalInvocationID.y;
3639
const uint channel = gl_GlobalInvocationID.z;
40+
const uint i3 = gl_WorkGroupID.x;
3741
const uint channel_x = channel / p.channel_x_divisor;
3842
const uint channel_y = channel % p.ne12;
3943

4044
const uint nrows_y = p.ncols_x;
4145
const uint nrows_dst = p.nrows_x;
4246
const uint row_dst = row_x;
4347

44-
const uint idst = channel*nrows_dst + row_dst;
48+
const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst;
4549

4650
FLOAT_TYPE temp = 0.0f;
4751

@@ -58,8 +62,8 @@ void main() {
5862

5963
const uint row_y = col_x;
6064

61-
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
62-
const uint iy = channel_y*p.channel_stride_y + row_y;
65+
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
66+
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
6367

6468
const vec4 av4 = vec4(data_a_v4[ix / 4]);
6569
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -74,8 +78,8 @@ void main() {
7478

7579
const uint row_y = col_x;
7680

77-
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
78-
const uint iy = channel_y*p.channel_stride_y + row_y;
81+
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
82+
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
7983

8084
const vec4 av4 = vec4(data_a_v4[ix / 4]);
8185
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -91,8 +95,8 @@ void main() {
9195

9296
const uint row_y = col_x;
9397

94-
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
95-
const uint iy = channel_y*p.channel_stride_y + row_y;
98+
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
99+
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
96100

97101
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
98102

tests/test-backend-ops.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5592,13 +5592,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
55925592
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
55935593
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
55945594

5595-
for (auto bs : {1,2,4,8}) {
5596-
for (auto nr : {1,4}) {
5597-
for (uint32_t m = 0; m < 2; ++m) {
5598-
for (uint32_t k = 0; k < 2; ++k) {
5599-
for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
5600-
test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
5601-
test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
5595+
for (auto bs2 : {1,3}) {
5596+
for (auto bs : {1,2,4,8}) {
5597+
for (auto nr : {1,4}) {
5598+
for (uint32_t m = 0; m < 2; ++m) {
5599+
for (uint32_t k = 0; k < 2; ++k) {
5600+
for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
5601+
test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, bs2}, {nr, 1}, {0, 2, 1, 3}));
5602+
test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, bs2}, {nr, 1}, {0, 1, 2, 3}, true));
5603+
}
56025604
}
56035605
}
56045606
}

0 commit comments

Comments
 (0)