Commit c523636

complete rebase against fused adds - multi_add shader can also compute partial sums
1 parent e4ec524 commit c523636
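
A hedged reading of what "partial sums" means here, based on the shader and host changes below (variable names mirror multi_add.comp):

    // For each output element of the fused add chain (ADD_RMS variant of the shader):
    //   sum     = src0[i] + src1[i] + ... + src[num_srcs-1][i];  // written to dst, as before
    //   sum_sq += sum * sum;                                     // new: accumulated per invocation
    // Each workgroup then reduces sum_sq and writes a single partial, so a following
    // RMS_NORM can presumably take the row's sum of squares from the partials buffer.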

File tree: 4 files changed, +89 -23 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 36 additions & 17 deletions
@@ -102,9 +102,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 
 struct ggml_backend_vk_context;
 
-#define MAX_PARAMETER_COUNT 8
+#define MAX_PARAMETER_COUNT 12
 // Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
-#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2)
+#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
 
 struct vk_pipeline_struct {
     std::string name;
@@ -459,6 +459,7 @@ struct vk_device_struct {
 
     // indexed by num_additional_fused_ops == num_adds - 1
     vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
+    vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS];
 
     vk_pipeline pipeline_add_id_f32;
 
@@ -819,8 +820,13 @@ struct vk_op_multi_add_push_constants {
     uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
 
     // strides for srcs+dst
-    uint32_t nb[8][4];
+    uint32_t nb[MAX_PARAMETER_COUNT][4];
+
+    uint32_t rms_partials;
 };
+// update multi_add.comp if this changes
+static_assert(MAX_PARAMETER_COUNT == 12);
+static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
 
 struct vk_op_add_id_push_constants {
     uint32_t ne0;
@@ -3032,7 +3038,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     if (device->multi_add) {
         for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
-            ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
+            ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
+            ggml_vk_create_pipeline(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
         }
     }
 
@@ -6912,7 +6919,7 @@ static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst)
     return elements;
 }
 
-static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
+static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_GET_ROWS:
         GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -7836,7 +7843,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
     const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
     uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
     uint32_t num_tensors = num_srcs + 1;
-    GGML_ASSERT(num_tensors <= MAX_PARAMETER_COUNT);
+    GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT);
 
     tensors[0] = first_node->src[0];
     tensors[1] = first_node->src[1];
@@ -7863,8 +7870,9 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
         pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
         pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
     }
+    pc.rms_partials = ctx->do_add_rms_partials;
 
-    vk_pipeline pipeline = ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op);
 
     if (pipeline == nullptr) {
         std::cerr << "ggml_vulkan: Error: Missing multi_add";
@@ -7902,6 +7910,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
         buf[i] = buf[0];
         offset[i] = 0;
     }
+    if (ctx->do_add_rms_partials) {
+        buf[num_tensors] = ctx->prealloc_add_rms_partials;
+        offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset;
+    }
 
     std::array<uint32_t, 3> elements;
 
@@ -7915,6 +7927,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
     }
 
     ggml_vk_sync_buffers(subctx);
+    static_assert(MAX_PARAMETER_COUNT == 12);
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
             vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
@@ -7925,6 +7938,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
             vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
             vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
             vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE },
         }, pc, elements);
 }
 
@@ -9771,17 +9788,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         }
         break;
     case GGML_OP_ADD:
-        if (node_idx + 1 < cgraph->n_nodes &&
-            cgraph->nodes[node_idx + 1]->op == GGML_OP_RMS_NORM &&
-            cgraph->nodes[node_idx + 1]->src[0] == cgraph->nodes[node_idx] &&
-            ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
-            ctx->device->add_rms_fusion) {
-            if (dryrun) {
-                ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
+        {
+            int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
+            if (next_node_idx < cgraph->n_nodes &&
+                cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
+                cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
+                ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
+                ctx->device->add_rms_fusion) {
+                if (dryrun) {
+                    ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
+                }
+                ctx->do_add_rms_partials = true;
             }
-            ctx->do_add_rms_partials = true;
-        }
-        break;
+        } break;
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
    case GGML_OP_GET_ROWS:
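
As a hedged restatement of the descriptor budget enforced by the GGML_ASSERT in ggml_vk_multi_add above (the arithmetic follows from this diff; nothing below is new API):

    uint32_t num_srcs    = ctx->num_additional_fused_ops + 2;      // 2 srcs for the first add, +1 per extra fused add
    uint32_t num_tensors = num_srcs + 1;                           // plus the destination
    uint32_t num_params  = num_tensors + ctx->do_add_rms_partials; // plus the optional partial-sums buffer
    // num_params must not exceed MAX_PARAMETER_COUNT (now 12), since all of these are bound
    // through the same descriptor array (binding 0 in multi_add.comp). In the worst case that
    // is num_adds + 3 <= 12, which is why MAX_FUSED_ADDS became MAX_PARAMETER_COUNT - 3.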

ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp

Lines changed: 41 additions & 1 deletion
@@ -3,6 +3,10 @@
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_nonuniform_qualifier : enable
 #extension GL_EXT_control_flow_attributes : require
+#if ADD_RMS
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#endif
 
 #include "rte.comp"
 #include "types.comp"
@@ -14,12 +18,16 @@ layout (push_constant) uniform parameter2
     uint ne20; uint ne21; uint ne22; uint ne23;
 
     // strides for srcs+dst
-    uint nb[8][4];
+    uint nb[12][4];
+
+    uint rms_partials;
 } p;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
 layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
 
+layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
+
 layout(constant_id = 0) const uint num_srcs = 2;
 
 uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
@@ -42,14 +50,22 @@ const uint num_threads = 256;
 
 layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
 
+#if ADD_RMS
+// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
+shared FLOAT_TYPE sumsh[num_threads];
+#endif
+
 void main() {
     uint idx = get_idx();
+    uint orig_idx = idx;
 
     uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
 
     // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
     const uint num_iter = 2;
 
+    FLOAT_TYPE sum_sq = 0;
+
     [[unroll]] for (uint i = 0; i < num_iter; ++i) {
         if (idx >= ne) {
             continue;
@@ -61,8 +77,32 @@ void main() {
         [[unroll]] for (uint s = 0; s < num_srcs; ++s) {
             sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
         }
+        sum_sq += sum*sum;
         d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
 
         idx += num_threads;
     }
+
+#if ADD_RMS
+    if (p.rms_partials != 0) {
+        // reduce the sum within each subgroup, then across subgroups
+        const uint NumSubgroups = num_threads / gl_SubgroupSize;
+        sum_sq = subgroupAdd(sum_sq);
+        if (gl_SubgroupInvocationID == 0) {
+            sumsh[gl_SubgroupID] = sum_sq;
+        }
+        barrier();
+        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
+            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
+                sum_sq += sumsh[gl_SubgroupID + s];
+                sumsh[gl_SubgroupID] = sum_sq;
+            }
+            barrier();
+        }
+
+        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
+            partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
+        }
+    }
+#endif
 }
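
For context on where those partials land, a minimal sketch of the layout implied by the shader above; treating orig_idx / (num_iter * num_threads) as the workgroup index is an assumption about get_idx(), and the ceil(n / 512) count is inferred rather than read from ggml_vk_rms_partials_size():

    // One workgroup covers num_threads * num_iter = 256 * 2 = 512 consecutive output elements.
    // After the subgroup and shared-memory reduction, one invocation per workgroup writes:
    //   partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
    // so partial_sums[] holds one sum of squares per 512-element block, and a row of n elements
    // should need roughly ceil(n / 512) floats of partials storage on the host side.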

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 1 deletion
@@ -680,7 +680,8 @@ void process_shaders() {
 
     string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 
-    string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
+    string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
+    string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
 
     for (auto &c : compiles) {
         c.wait();

tests/test-backend-ops.cpp

Lines changed: 10 additions & 4 deletions
@@ -2858,6 +2858,7 @@ struct test_rms_norm_mul_add : public test_case {
     const std::array<int64_t, 4> ne;
     const float eps;
     const bool broadcast;
+    const bool multi_add; // test a sequence of adds feeding into rms_norm
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -2867,13 +2868,13 @@
     bool run_whole_graph() override { return true; }
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, eps, broadcast);
+        return VARS_TO_STR5(type, ne, eps, broadcast, multi_add);
     }
 
     test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
-            float eps = 1e-6f, bool broadcast = false)
-        : type(type), ne(ne), eps(eps), broadcast(broadcast) {}
+            float eps = 1e-6f, bool broadcast = false, bool multi_add = false)
+        : type(type), ne(ne), eps(eps), broadcast(broadcast), multi_add(multi_add) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
@@ -2891,6 +2892,9 @@
 
         // Use a, b and c early, so we don't end up with an OP_NONE between rms_norm and mul
         a = ggml_add(ctx, ggml_add(ctx, a, b), c);
+        if (multi_add) {
+            a = ggml_add(ctx, ggml_add(ctx, a, b), c);
+        }
         ggml_tensor * out = ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b), c);
         ggml_set_name(out, "out");
 
@@ -5679,7 +5683,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
     }
     for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
-        test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f));
+        for (bool multi_add : {false, true}) {
+            test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false, multi_add));
+        }
     }
 
     test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
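
A hedged sketch of the graph built when multi_add == true, mirroring build_graph above (the temporaries t0..t3 are illustrative names, not from the test):

    ggml_tensor * t0  = ggml_add(ctx, a, b);   // first add chain
    ggml_tensor * t1  = ggml_add(ctx, t0, c);
    ggml_tensor * t2  = ggml_add(ctx, t1, b);  // second add chain, only when multi_add is set
    ggml_tensor * t3  = ggml_add(ctx, t2, c);
    ggml_tensor * out = ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, t3, eps), b), c);
    // Four consecutive GGML_OP_ADD nodes feeding an RMS_NORM exercise the multi_add fusion
    // together with the new rms partial-sums path.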
