
Commit 8a2e4f7

vulkan: fuse adds
Fuse adds that have the same shape, which are common in MoE models. It currently fuses up to 6 adds, because we assume no more than 8 descriptors per dispatch, but this could be changed.
1 parent be48528 commit 8a2e4f7
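
The 6-add limit follows from the descriptor budget: a chain of N fused adds reads N+1 distinct source tensors and writes one destination, so it needs N+2 descriptors, and with at most 8 descriptors per dispatch that caps N at 6. A minimal standalone C++ sketch of that arithmetic (the macro names mirror the patch below; the program itself is illustrative, not part of the change):

#include <cassert>

#define MAX_PARAMETER_COUNT 8
#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2)

int main() {
    for (int num_adds = 2; num_adds <= MAX_FUSED_ADDS; ++num_adds) {
        int descriptors = (num_adds + 1) + 1;   // N+1 sources plus 1 destination
        assert(descriptors <= MAX_PARAMETER_COUNT);
    }
    return 0;
}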

File tree

6 files changed: +299 −25 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 199 additions & 6 deletions
@@ -103,6 +103,8 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 struct ggml_backend_vk_context;
 
 #define MAX_PARAMETER_COUNT 8
+// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
+#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2)
 
 struct vk_pipeline_struct {
     std::string name;
@@ -368,6 +370,7 @@ struct vk_device_struct {
     bool float_controls_rte_fp16;
     bool subgroup_add;
     bool subgroup_shuffle;
+    bool multi_add;
 
     bool integer_dot_product;
 
@@ -449,6 +452,9 @@ struct vk_device_struct {
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];
 
+    // indexed by num_additional_fused_ops == num_adds - 1
+    vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
+
     vk_pipeline pipeline_add_id_f32;
 
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -800,6 +806,14 @@ struct vk_op_binary_push_constants {
     float param1; float param2; int32_t param3;
 };
 
+struct vk_op_multi_add_push_constants {
+    // shape for dst
+    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
+
+    // strides for srcs+dst
+    uint32_t nb[8][4];
+};
+
 struct vk_op_add_id_push_constants {
     uint32_t ne0;
     uint32_t ne1;
@@ -3011,6 +3025,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_BINARY(div, _norepeat, {1})
 #undef CREATE_BINARY
 
+    if (device->multi_add) {
+        for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
+            ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
+        }
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3409,6 +3429,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
     }
     device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
 
+    device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
+        device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
+        getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
+
     device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                            (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
 
@@ -6892,6 +6916,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     switch (op) {
     case GGML_OP_ADD:
     {
+        if (ctx->num_additional_fused_ops > 0) {
+            return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+        }
         auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
         return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
     }
@@ -7739,6 +7766,107 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
+    const ggml_tensor *first_node = cgraph->nodes[node_idx];
+    const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
+
+    // Make a list of all the tensors used by the op.
+    // Last element of the list is the dest tensor.
+    const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
+    uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
+    uint32_t num_tensors = num_srcs + 1;
+    GGML_ASSERT(num_tensors <= MAX_PARAMETER_COUNT);
+
+    tensors[0] = first_node->src[0];
+    tensors[1] = first_node->src[1];
+    for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {
+        // check whether the previous result is src[0] or src[1]
+        if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {
+            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];
+        } else {
+            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];
+        }
+    }
+    tensors[num_srcs] = dst;
+
+    vk_op_multi_add_push_constants pc;
+    pc.ne20 = (uint32_t)dst->ne[0];
+    pc.ne21 = (uint32_t)dst->ne[1];
+    pc.ne22 = (uint32_t)dst->ne[2];
+    pc.ne23 = (uint32_t)dst->ne[3];
+
+    for (uint32_t i = 0; i < num_tensors; ++i) {
+        const ggml_tensor *t = tensors[i];
+        pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);
+        pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float);
+        pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
+        pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
+    }
+
+    vk_pipeline pipeline = ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+
+    if (pipeline == nullptr) {
+        std::cerr << "ggml_vulkan: Error: Missing multi_add";
+        GGML_ABORT("fatal error");
+    }
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];
+    vk_buffer buf[MAX_PARAMETER_COUNT];
+    size_t offset[MAX_PARAMETER_COUNT];
+    bool uma[MAX_PARAMETER_COUNT];
+
+    for (uint32_t i = 0; i < num_tensors; ++i) {
+        buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
+        buf[i] = nullptr;
+        offset[i] = 0;
+        uma[i] = false;
+
+        if (ctx->device->uma) {
+            ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
+            uma[i] = buf[i] != nullptr;
+        }
+        if (!uma[i]) {
+            buf[i] = buf_ctx[i]->dev_buffer;
+            offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
+        }
+        GGML_ASSERT(buf[i] != nullptr);
+    }
+    // If any remaining descriptors are unused, just point them at src[0]
+    for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {
+        buf[i] = buf[0];
+        offset[i] = 0;
+    }
+
+    std::array<uint32_t, 3> elements;
+
+    uint32_t ne = ggml_nelements(dst);
+    if (ne > 262144) {
+        elements = { 512, 512, CEIL_DIV(ne, 262144) };
+    } else if (ne > 512) {
+        elements = { 512, CEIL_DIV(ne, 512), 1 };
+    } else {
+        elements = { ne, 1, 1 };
+    }
+
+    ggml_vk_sync_buffers(subctx);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {
+            vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
+        }, pc, elements);
+}
+
 static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
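
To make the tensor-list construction in ggml_vk_multi_add concrete, here is a small standalone model (hypothetical Node struct and tensor names, not the ggml API) of how a fused chain out = ((a + b) + c) + d, i.e. num_additional_fused_ops == 2, collapses into the descriptor list {a, b, c, d, out}, skipping the intermediate results:

#include <cassert>
#include <vector>

struct Node { const char *name; Node *src0; Node *src1; };

int main() {
    Node a{"a", nullptr, nullptr}, b{"b", nullptr, nullptr};
    Node c{"c", nullptr, nullptr}, d{"d", nullptr, nullptr};
    Node t0 {"t0", &a, &b};    // t0  = a + b
    Node t1 {"t1", &t0, &c};   // t1  = t0 + c
    Node out{"out", &t1, &d};  // out = t1 + d
    Node *nodes[] = {&t0, &t1, &out};
    int num_additional_fused_ops = 2;   // three adds fused

    // sources first: the two operands of the first add, then, for each later add,
    // whichever operand is not the previous add's result; the destination goes last
    std::vector<Node*> tensors = {nodes[0]->src0, nodes[0]->src1};
    for (int i = 0; i < num_additional_fused_ops; ++i) {
        Node *next = nodes[i + 1];
        tensors.push_back(next->src0 == nodes[i] ? next->src1 : next->src0);
    }
    tensors.push_back(nodes[num_additional_fused_ops]);

    assert(tensors.size() == 5);   // {a, b, c, d, out}
    assert(tensors[0] == &a && tensors[3] == &d && tensors[4] == &out);
    return 0;
}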
@@ -9692,8 +9820,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_ADD:
-        ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
-
+        if (ctx->num_additional_fused_ops) {
+            ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx, dryrun);
+        } else {
+            ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
+        }
         break;
     case GGML_OP_SUB:
         ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -10570,6 +10701,58 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
     return true;
 }
 
+static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
+
+    if (!ctx->device->multi_add) {
+        return 0;
+    }
+
+    const ggml_tensor *first_node = cgraph->nodes[node_idx];
+    if (first_node->op != GGML_OP_ADD) {
+        return 0;
+    }
+
+    int32_t num_adds = 1;
+    while (node_idx + num_adds < cgraph->n_nodes &&
+           cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&
+           num_adds < MAX_FUSED_ADDS) {
+        num_adds++;
+    }
+
+    // The shader currently requires same shapes (but different strides are allowed),
+    // everything f32, and no misalignment
+    for (int32_t i = 0; i < num_adds; ++i) {
+        const ggml_tensor *next_node = cgraph->nodes[node_idx + i];
+        if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
+            !ggml_are_same_shape(first_node, next_node->src[1]) ||
+            next_node->type != GGML_TYPE_F32 ||
+            next_node->src[0]->type != GGML_TYPE_F32 ||
+            next_node->src[1]->type != GGML_TYPE_F32 ||
+            get_misalign_bytes(ctx, next_node) ||
+            get_misalign_bytes(ctx, next_node->src[0]) ||
+            get_misalign_bytes(ctx, next_node->src[1])) {
+            num_adds = i;
+        }
+    }
+
+    // Verify we can fuse these
+    ggml_op adds[MAX_FUSED_ADDS];
+    for (int32_t i = 0; i < num_adds; ++i) {
+        adds[i] = GGML_OP_ADD;
+    }
+
+    // decrease num_adds if they can't all be fused
+    while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
+        num_adds--;
+    }
+
+    // a single add is not "fused", so just return zero
+    if (num_adds == 1) {
+        return 0;
+    }
+    return num_adds;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -10583,8 +10766,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ctx->num_additional_fused_ops = 1;
+        if (!ctx->device->disable_fusion) {
+            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
+            if (num_adds) {
+                ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
+            }
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
@@ -10659,8 +10847,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
-        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ctx->num_additional_fused_ops = 1;
+        if (!ctx->device->disable_fusion) {
+            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
+            if (num_adds) {
+                ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
+            }
         }
 
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)

ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp

Lines changed: 2 additions & 17 deletions
@@ -2,6 +2,7 @@
 #extension GL_EXT_control_flow_attributes : require
 
 #include "rte.comp"
+#include "utils.comp"
 
 layout (push_constant) uniform parameter
 {
@@ -28,25 +29,9 @@ uint get_aoffset() { return p.misalign_offsets >> 16; }
 uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
 uint get_doffset() { return p.misalign_offsets & 0xFF; }
 
-// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
-uint fastmod(uint a, uint b) {
-    if ((b & (b-1)) == 0) {
-        return a & (b-1);
-    }
-    return a % b;
-}
-
-uint fastdiv(uint a, uint b) {
-    return (a < b) ? 0 : (a / b);
-}
 
 void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
-    i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
-    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
-    const uint i02_offset = i02*p.ne01*p.ne00;
-    i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+    get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03);
 }
 
 uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_nonuniform_qualifier : enable
+#extension GL_EXT_control_flow_attributes : require
+
+#include "rte.comp"
+#include "types.comp"
+#include "utils.comp"
+
+layout (push_constant) uniform parameter2
+{
+    // shape for dst
+    uint ne20; uint ne21; uint ne22; uint ne23;
+
+    // strides for srcs+dst
+    uint nb[8][4];
+} p;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
+layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
+
+layout(constant_id = 0) const uint num_srcs = 2;
+
+uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
+    return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
+}
+
+uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
+    uint nb20 = p.nb[num_srcs][0];
+    uint nb21 = p.nb[num_srcs][1];
+    uint nb22 = p.nb[num_srcs][2];
+    uint nb23 = p.nb[num_srcs][3];
+    return i03*nb23 + i02*nb22 + i01*nb21 + i00*nb20;
+}
+
+uint get_idx() {
+    return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+}
+
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    uint idx = get_idx();
+
+    uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
+
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
+
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23);
+
+        FLOAT_TYPE sum = FLOAT_TYPE(0);
+        [[unroll]] for (uint s = 0; s < num_srcs; ++s) {
+            sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
+        }
+        d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
+
+        idx += num_threads;
+    }
+}
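
A standalone sanity check of the sizing logic (illustrative only, not part of the patch): each workgroup runs num_threads = 256 invocations for num_iter = 2 iterations, covering 512 elements, which is what wg_denoms {512, 1, 1} and the strides in get_idx() (512 per y step, 262144 = 512*512 per z step) assume. The element counts chosen in ggml_vk_multi_add then always cover ggml_nelements(dst):

#include <cassert>
#include <cstdint>
#include <initializer_list>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    const uint32_t num_threads = 256, num_iter = 2;
    assert(num_threads * num_iter == 512);   // matches the shader comment

    for (uint32_t ne : {1u, 300u, 513u, 100000u, 262145u, 5000000u}) {
        uint32_t x, y, z;
        if (ne > 262144)   { x = 512; y = 512;               z = ceil_div(ne, 262144); }
        else if (ne > 512) { x = 512; y = ceil_div(ne, 512); z = 1; }
        else               { x = ne;  y = 1;                 z = 1; }
        // one workgroup per 512 elements in x; every dst element must be reachable
        uint64_t covered = (uint64_t)ceil_div(x, 512) * 512 * y * z;
        assert(covered >= ne);
    }
    return 0;
}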
