Skip to content

Commit b3b03a7

Browse files
authored
vulkan: Implement GGML_OP_CUMSUM (#17479)
1 parent 583cb83 commit b3b03a7

File tree

5 files changed

+125
-24
lines changed

5 files changed

+125
-24
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,7 @@ struct vk_device_struct {
705705
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
706706
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
707707
vk_pipeline pipeline_sum_rows_f32;
708+
vk_pipeline pipeline_cumsum_f32;
708709
vk_pipeline pipeline_argmax_f32;
709710
vk_pipeline pipeline_count_equal_i32;
710711
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -3968,6 +3969,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
39683969

39693970
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
39703971

3972+
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
3973+
39713974
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
39723975

39733976
#define IM2COL(bda) \
@@ -8457,6 +8460,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
84578460
return ctx->device->pipeline_sum_rows_f32;
84588461
}
84598462
return nullptr;
8463+
case GGML_OP_CUMSUM:
8464+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8465+
return ctx->device->pipeline_cumsum_f32;
8466+
}
8467+
return nullptr;
84608468
case GGML_OP_ARGMAX:
84618469
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
84628470
return ctx->device->pipeline_argmax_f32;
@@ -8821,6 +8829,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
88218829
case GGML_OP_SOFT_MAX:
88228830
case GGML_OP_SOFT_MAX_BACK:
88238831
case GGML_OP_SUM_ROWS:
8832+
case GGML_OP_CUMSUM:
88248833
case GGML_OP_MEAN:
88258834
case GGML_OP_ARGMAX:
88268835
{
@@ -10150,6 +10159,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
1015010159
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
1015110160
}
1015210161

10162+
// Vulkan entry point for GGML_OP_CUMSUM: forwards `src0` -> `dst` to the shared
// ggml_vk_op_f32 path. The three tensor arguments after `src0` are passed as
// nullptr. The op reuses the sum_rows push-constant struct; the cumsum pipeline
// is created with sizeof(vk_op_sum_rows_push_constants) as its push-constant size.
static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10163+
// The third argument is src0->ne[0]; presumably it becomes n_cols, the length of
// each row being scanned -- TODO confirm against vk_op_sum_rows_push_constants_init.
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
10164+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
10165+
}
10166+
1015310167
// Vulkan entry point for GGML_OP_ARGMAX: forwards `src0` -> `dst` to the shared
// ggml_vk_op_f32 path, explicitly instantiated for vk_op_push_constants. The
// three tensor arguments after `src0` are passed as nullptr.
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
1015410168
// Push constants: src0->ne[0] and src0->ne[1] narrowed to uint32_t, followed by
// two float slots set to 0.0f. NOTE(review): the field names/meaning of
// vk_op_push_constants are not visible here -- confirm against its definition.
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
1015510169
}
@@ -11749,6 +11763,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1174911763
case GGML_OP_SUM_ROWS:
1175011764
ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
1175111765

11766+
break;
11767+
case GGML_OP_CUMSUM:
11768+
ggml_vk_cumsum(ctx, compute_ctx, src0, node);
11769+
1175211770
break;
1175311771
case GGML_OP_MEAN:
1175411772
ggml_vk_mean(ctx, compute_ctx, src0, node);
@@ -13786,6 +13804,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1378613804
case GGML_OP_SUM_ROWS:
1378713805
case GGML_OP_MEAN:
1378813806
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
13807+
case GGML_OP_CUMSUM:
13808+
{
13809+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13810+
auto device = ggml_vk_get_device(ctx->device);
13811+
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
13812+
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
13813+
}
13814+
return false;
13815+
}
1378913816
case GGML_OP_ARGMAX:
1379013817
case GGML_OP_COUNT_EQUAL:
1379113818
case GGML_OP_IM2COL:
@@ -14436,6 +14463,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1443614463
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
1443714464
} else if (tensor->op == GGML_OP_SUM_ROWS) {
1443814465
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
14466+
} else if (tensor->op == GGML_OP_CUMSUM) {
14467+
tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
1443914468
} else if (tensor->op == GGML_OP_MEAN) {
1444014469
tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
1444114470
} else if (tensor->op == GGML_OP_ARGMAX) {
ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#version 450
2+
3+
// Compute shader that writes the inclusive prefix sum of each input row
// (data_a) to the matching output row (data_d). Push constants and the
// fastdiv/offset helpers come from sum_rows.glsl.
#include "types.glsl"
4+
#include "sum_rows.glsl"
5+
6+
// NOTE(review): no control-flow attribute is used in the visible shader body;
// this enable looks unused here.
#extension GL_EXT_control_flow_attributes : enable
7+
// Needed for subgroupInclusiveAdd().
#extension GL_KHR_shader_subgroup_arithmetic : enable
8+
// Needed for gl_SubgroupInvocationID.
#extension GL_KHR_shader_subgroup_basic : enable
9+
10+
// Workgroup width comes from specialization constant 0, the same id as BLOCK_SIZE below.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
11+
12+
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
13+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
14+
15+
// Specialization constants; the values below are only defaults. The host-side
// pipeline creation for cumsum_f32 passes { 128, device->subgroup_size }.
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
16+
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
17+
18+
// Integer division rounding up.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
19+
20+
// One slot per subgroup in the workgroup: the subgroup's total for the current chunk.
shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
21+
// Running total of all columns processed in earlier loop iterations of main().
shared FLOAT_TYPE last_sum;
22+
23+
// One workgroup computes the inclusive prefix sum of one row.
// The row is processed in chunks of BLOCK_SIZE columns. Within a chunk:
//   1. each subgroup does an inclusive scan of its own values,
//   2. each thread adds the totals of all lower-numbered subgroups (partial[]),
//   3. each thread adds the running total of all earlier chunks (last_sum).
// The two barrier() calls per iteration order the shared-memory accesses; see
// the comments next to them before moving any statement.
void main() {
24+
// Flatten the 3D workgroup id into a row index (262144 == 512 * 512).
// Presumably this mirrors how the host splits the row count across the
// x/y/z dispatch dimensions -- TODO confirm against the dispatch code.
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
25+
const uint tid = gl_LocalInvocationID.x;
26+
27+
// Split `row` into (i01, i02, i03). The fastdiv constants are precomputed on the
// host; presumably ne0_12* divides by ne01*ne02 and ne0_1* divides by ne01.
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
28+
const uint i03_offset = i03 * p.ne01*p.ne02;
29+
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
30+
const uint i01 = row - i03_offset - i02*p.ne01; // remainder after removing the i03 and i02 parts
31+
32+
// Element index of the first column of this row in the source / destination
// buffer; both are used directly as indices into data_a[] / data_d[].
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
33+
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
34+
35+
// Index into partial[]. NOTE(review): this assumes subgroups are full and
// cover consecutive ranges of tid -- confirm with the pipeline creation flags.
uint subgroup_id = tid / SUBGROUP_SIZE;
36+
37+
// Reset the running total. No barrier follows directly; the first barrier()
// inside the loop runs before any thread reads last_sum.
if (tid == 0) {
38+
last_sum = 0;
39+
}
40+
41+
uint col = tid;
42+
uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
43+
// The trip count is the same for every thread, so all threads reach every barrier().
for (int i = 0; i < num_iter; ++i) {
44+
// Columns past the end of the row contribute 0 to the scan.
FLOAT_TYPE v = 0;
45+
if (col < p.n_cols) {
46+
v = FLOAT_TYPE(data_a[src_idx + col]);
47+
}
48+
// Inclusive scan within this subgroup only.
v = subgroupInclusiveAdd(v);
49+
50+
// Store the largest partial sum for each subgroup, then add the partials for all
51+
// lower subgroups and the final partial sum from the previous iteration.
52+
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
53+
partial[subgroup_id] = v;
54+
}
55+
// Barrier 1: partial[] (and, on the first iteration, last_sum = 0) must be
// written before any thread reads them below.
barrier();
56+
for (int j = 0; j < subgroup_id; ++j) {
57+
v += partial[j];
58+
}
59+
v += last_sum;
60+
// Barrier 2: every thread has finished reading last_sum and partial[] before
// last_sum is overwritten below and partial[] is rewritten next iteration.
barrier();
61+
// The last thread of the workgroup holds the sum through the end of this chunk;
// it becomes the carry-in for the next chunk.
if (tid == BLOCK_SIZE - 1) {
62+
last_sum = v;
63+
}
64+
if (col < p.n_cols) {
65+
data_d[dst_idx + col] = D_TYPE(v);
66+
}
67+
col += BLOCK_SIZE;
68+
}
69+
}

ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#version 450
22

33
#include "types.glsl"
4+
#include "sum_rows.glsl"
45

56
#extension GL_EXT_control_flow_attributes : enable
67

@@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
1112

1213
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
1314

14-
layout (push_constant) uniform parameter
15-
{
16-
uint n_cols;
17-
uint ne01, ne02;
18-
uint nb01, nb02, nb03;
19-
uint nb11, nb12, nb13;
20-
float weight;
21-
uint misalign_offsets;
22-
uint ne0_12mp, ne0_12L;
23-
uint ne0_1mp, ne0_1L;
24-
} p;
25-
26-
uint get_aoffset() { return p.misalign_offsets >> 16; }
27-
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
28-
29-
// see init_fastdiv_values in ggml-vulkan.cpp
30-
uint fastdiv(uint n, uint mp, uint L) {
31-
uint msbs, lsbs;
32-
// msbs = mulhi(n, mp)
33-
umulExtended(n, mp, msbs, lsbs);
34-
return (msbs + n) >> L;
35-
}
36-
37-
3815
shared FLOAT_TYPE tmp[BLOCK_SIZE];
3916

4017
void main() {
ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
2+
// vk_op_sum_rows_push_constants
// Push-constant block shared by the shaders that include this file
// (sum_rows.comp and cumsum.comp). Field notes below are based on how
// cumsum.comp uses them; sum_rows.comp's main() is not visible here.
3+
layout (push_constant) uniform parameter
4+
{
5+
uint n_cols; // number of columns in a row; upper bound for the column index
6+
uint ne01, ne02; // extents used to split a flat row index into (i01, i02, i03)
7+
uint nb01, nb02, nb03; // source strides for i01/i02/i03, added directly to the data_a[] index
8+
uint nb11, nb12, nb13; // destination strides for i01/i02/i03, added directly to the data_d[] index
9+
float weight; // not read by cumsum.comp; presumably a scale used by sum_rows -- TODO confirm
10+
uint misalign_offsets; // packed offsets: high 16 bits for buffer A, low 16 bits for buffer D
11+
uint ne0_12mp, ne0_12L; // fastdiv multiplier/shift used to compute i03 from the row index
12+
uint ne0_1mp, ne0_1L; // fastdiv multiplier/shift used to compute i02
13+
} p;
14+
15+
// Offset added to the source (A) index: the upper 16 bits of misalign_offsets.
uint get_aoffset() { return p.misalign_offsets >> 16; }
16+
// Offset added to the destination (D) index: the lower 16 bits of misalign_offsets.
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
17+
18+
// see init_fastdiv_values in ggml-vulkan.cpp
// Division without a divide instruction: returns (mulhi(n, mp) + n) >> L.
// `mp` and `L` are expected to be the multiplier/shift pair precomputed on the
// host for a fixed divisor (init_fastdiv_values is not shown in this file).
// NOTE(review): msbs + n is a 32-bit add -- confirm the host-side constants and
// the range of n keep it from wrapping.
19+
uint fastdiv(uint n, uint mp, uint L) {
20+
uint msbs, lsbs; // high and low 32 bits of n * mp; only msbs is used
21+
// msbs = mulhi(n, mp)
22+
umulExtended(n, mp, msbs, lsbs);
23+
return (msbs + n) >> L;
24+
}
25+

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,7 @@ void process_shaders() {
916916
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
917917
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
918918
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
919+
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
919920

920921
for (std::string dim_str : {"", "_3d"}) {
921922
for (bool bda : {false, true}) {

0 commit comments

Comments
 (0)