Skip to content

Commit b6c4a11

Browse files
committed
vulkan : support sum, sum_rows and mean with non-contiguous tensors
1 parent 0012b5c commit b6c4a11

File tree

2 files changed

+94
-10
lines changed

2 files changed

+94
-10
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,40 @@ struct vk_op_upscale_push_constants {
10141014
float sf0; float sf1; float sf2; float sf3;
10151015
};
10161016

1017+
struct vk_op_sum_rows_push_constants
1018+
{
1019+
uint32_t n_cols;
1020+
uint32_t ne01, ne02;
1021+
uint32_t nb00, nb01, nb02, nb03;
1022+
uint32_t nb11, nb12, nb13;
1023+
float weight;
1024+
uint32_t misalign_offsets;
1025+
uint32_t ne0_12mp, ne0_12L;
1026+
uint32_t ne0_1mp, ne0_1L;
1027+
};
1028+
1029+
// Populate the common push-constant fields for sum/sum_rows/mean from the
// source and destination tensor layouts. Byte strides are converted to
// element strides so the shader can index data_a/data_d directly.
vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
    const uint32_t ts = (uint32_t)ggml_type_size(src->type);

    vk_op_sum_rows_push_constants p = {};
    p.n_cols = (uint32_t)n_cols;

    p.ne01 = (uint32_t)src->ne[1];
    p.ne02 = (uint32_t)src->ne[2];

    // source strides, in elements
    p.nb00 = (uint32_t)src->nb[0] / ts;
    p.nb01 = (uint32_t)src->nb[1] / ts;
    p.nb02 = (uint32_t)src->nb[2] / ts;
    p.nb03 = (uint32_t)src->nb[3] / ts;

    // destination strides, in elements
    // NOTE(review): dst strides are divided by the *src* type size — fine while
    // these ops are f32->f32, but worth confirming if dst->type can ever differ.
    p.nb11 = (uint32_t)dst->nb[1] / ts;
    p.nb12 = (uint32_t)dst->nb[2] / ts;
    p.nb13 = (uint32_t)dst->nb[3] / ts;

    p.weight = 1.0f;  // callers override for GGML_OP_MEAN

    return p;
}
1045+
1046+
// Precompute the magic-multiply/shift constants the shader's fastdiv() uses to
// decode a flat row index into (i01, i02, i03) without integer division.
template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
    // dividing by ne01*ne02 yields i03; dividing the remainder by ne01 yields i02
    init_fastdiv_values(p.ne01 * p.ne02, p.ne0_12mp, p.ne0_12L);
    init_fastdiv_values(p.ne01,          p.ne0_1mp,  p.ne0_1L);
}
1050+
10171051
// Allow pre-recording command buffers
10181052
struct vk_staging_memcpy {
10191053
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -3122,7 +3156,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
31223156

31233157
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
31243158

3125-
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3159+
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
31263160

31273161
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
31283162

@@ -7340,6 +7374,9 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
73407374
case GGML_OP_CONV_2D_DW:
73417375
case GGML_OP_IM2COL:
73427376
case GGML_OP_SET_ROWS:
7377+
case GGML_OP_SUM:
7378+
case GGML_OP_SUM_ROWS:
7379+
case GGML_OP_MEAN:
73437380
return true;
73447381
default:
73457382
return false;
@@ -7374,6 +7411,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
73747411
GGML_UNUSED(src2);
73757412
}
73767413

7414+
// Pack the element-granularity misalignment offsets of src0 and dst into one
// push constant word: src offset in the high 16 bits, dst offset in the low 16.
// The shader unpacks them via get_aoffset()/get_doffset().
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
    const uint32_t src_off = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
    const uint32_t dst_off = get_misalign_bytes(ctx, dst)  / ggml_type_size(dst->type);

    p.misalign_offsets = (src_off << 16) | dst_off;

    GGML_UNUSED(src1);
    GGML_UNUSED(src2);
}
7423+
73777424
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
73787425
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
73797426
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
@@ -8542,15 +8589,20 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
85428589
}
85438590

85448591
// Reduce every element of src0 into a single scalar. Reuses the sum_rows
// shader by presenting src0 as one row containing all of its elements.
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
    pc.nb00 = 1; // treat src0 as flattened 1D tensor
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, pc, dryrun);
}
85478596

85488597
// One output element per row: reduce the ne[0] columns of each src0 row with
// unit weight (the default set by vk_op_sum_rows_push_constants_init).
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN == GGML_OP_MEAN ? GGML_OP_SUM_ROWS : GGML_OP_SUM_ROWS, pc, dryrun);
}
85518601

85528602
// Row-wise mean: identical dispatch to sum_rows, but the shader scales the
// reduced sum by `weight` = 1/ne0 before writing the result.
static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
    pc.weight = 1.0f / (float)src0->ne[0];
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, pc, dryrun);
}
85558607

85568608
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,59 @@
11
#version 450
22

3-
#include "generic_head.comp"
43
#include "types.comp"
54

65
#extension GL_EXT_control_flow_attributes : enable
6+
77
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
88

99
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
1010
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
1111

1212
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
1313

14+
// Push constants — field order must match struct vk_op_sum_rows_push_constants
// in ggml-vulkan.cpp. All strides are in elements, not bytes.
layout (push_constant) uniform parameter
{
    uint n_cols;                   // elements reduced per output value
    uint ne01, ne02;               // src extents in dims 1 and 2
    uint nb00, nb01, nb02, nb03;   // src strides, in elements
    uint nb11, nb12, nb13;         // dst strides, in elements
    float weight;                  // scale applied to the sum (1/ne0 for mean)
    uint misalign_offsets;         // packed (a_offset << 16) | d_offset
    uint ne0_12mp, ne0_12L;        // fastdiv constants for divisor ne01*ne02
    uint ne0_1mp, ne0_1L;          // fastdiv constants for divisor ne01
} p;

// Unpack the per-buffer misalignment offsets (in elements).
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
28+
29+
// Magic-number unsigned division: n / d == (mulhi(n, mp) + n) >> L, where mp
// and L were precomputed for divisor d on the host.
// See init_fastdiv_values in ggml-vulkan.cpp.
uint fastdiv(uint n, uint mp, uint L) {
    uint hi, lo;
    umulExtended(n, mp, hi, lo); // hi = high 32 bits of n * mp
    return (hi + n) >> L;
}
36+
37+
1438
shared FLOAT_TYPE tmp[BLOCK_SIZE];
1539

1640
void main() {
1741
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
1842
const uint col = gl_LocalInvocationID.x;
19-
const float weight = p.param1;
43+
const float weight = p.weight;
44+
45+
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
46+
const uint i03_offset = i03 * p.ne01*p.ne02;
47+
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
48+
const uint i01 = row - i03_offset - i02*p.ne01;
49+
50+
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
51+
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
2052

21-
tmp[col] = FLOAT_TYPE(0.0f);
53+
tmp[col] = FLOAT_TYPE(0.0);
2254

23-
for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
24-
tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
55+
for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) {
56+
tmp[col] += FLOAT_TYPE(data_a[src_idx + i * p.nb00]);
2557
}
2658

2759
barrier();
@@ -33,6 +65,6 @@ void main() {
3365
}
3466

3567
if (col == 0) {
36-
data_d[row] = D_TYPE(tmp[0] * weight);
68+
data_d[dst_idx] = D_TYPE(tmp[0] * weight);
3769
}
3870
}

0 commit comments

Comments
 (0)