vulkan: fuse adds #15252

Open · wants to merge 2 commits into base: master

206 changes: 200 additions & 6 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -103,6 +103,8 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
struct ggml_backend_vk_context;

#define MAX_PARAMETER_COUNT 8
// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2)

struct vk_pipeline_struct {
std::string name;
@@ -368,6 +370,7 @@ struct vk_device_struct {
bool float_controls_rte_fp16;
bool subgroup_add;
bool subgroup_shuffle;
bool multi_add;

bool integer_dot_product;

@@ -449,6 +452,9 @@ struct vk_device_struct {
vk_pipeline pipeline_div[2][2][2];
vk_pipeline pipeline_div_norepeat[2][2][2];

// indexed by num_additional_fused_ops == num_adds - 1
vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];

vk_pipeline pipeline_add_id_f32;

vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -800,6 +806,14 @@ struct vk_op_binary_push_constants {
float param1; float param2; int32_t param3;
};

struct vk_op_multi_add_push_constants {
// shape for dst
uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;

// strides for srcs+dst
uint32_t nb[8][4];
};

struct vk_op_add_id_push_constants {
uint32_t ne0;
uint32_t ne1;
@@ -3011,6 +3025,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_BINARY(div, _norepeat, {1})
#undef CREATE_BINARY

if (device->multi_add) {
for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
}
}

ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3548,6 +3568,11 @@ static vk_device ggml_vk_get_device(size_t idx) {

device->pipeline_robustness = pl_robustness_features.pipelineRobustness;

device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
vk12_features.runtimeDescriptorArray &&
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;

if (device->subgroup_size_control) {
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -6892,6 +6917,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
switch (op) {
case GGML_OP_ADD:
{
if (ctx->num_additional_fused_ops > 0) {
return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
}
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
}
@@ -7739,6 +7767,107 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
}, dryrun);
}

static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
const ggml_tensor *first_node = cgraph->nodes[node_idx];
const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];

// Make a list of all the tensors used by the op.
// Last element of the list is the dest tensor.
const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
uint32_t num_tensors = num_srcs + 1;
GGML_ASSERT(num_tensors <= MAX_PARAMETER_COUNT);

tensors[0] = first_node->src[0];
tensors[1] = first_node->src[1];
for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {
// check whether the previous result is src[0] or src[1]
if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {
tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];
} else {
tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];
}
}
tensors[num_srcs] = dst;

vk_op_multi_add_push_constants pc;
pc.ne20 = (uint32_t)dst->ne[0];
pc.ne21 = (uint32_t)dst->ne[1];
pc.ne22 = (uint32_t)dst->ne[2];
pc.ne23 = (uint32_t)dst->ne[3];

for (uint32_t i = 0; i < num_tensors; ++i) {
const ggml_tensor *t = tensors[i];
pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);
pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float);
pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
}

vk_pipeline pipeline = ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];

if (pipeline == nullptr) {
std::cerr << "ggml_vulkan: Error: Missing multi_add";
GGML_ABORT("fatal error");
}

if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}

ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];
vk_buffer buf[MAX_PARAMETER_COUNT];
size_t offset[MAX_PARAMETER_COUNT];
bool uma[MAX_PARAMETER_COUNT];

for (uint32_t i = 0; i < num_tensors; ++i) {
buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
buf[i] = nullptr;
offset[i] = 0;
uma[i] = false;

if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
uma[i] = buf[i] != nullptr;
}
if (!uma[i]) {
buf[i] = buf_ctx[i]->dev_buffer;
offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
}
GGML_ASSERT(buf[i] != nullptr);
}
// If any remaining descriptors are unused, just point them at src[0]
for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {
buf[i] = buf[0];
offset[i] = 0;
}

std::array<uint32_t, 3> elements;

uint32_t ne = ggml_nelements(dst);
if (ne > 262144) {
elements = { 512, 512, CEIL_DIV(ne, 262144) };
} else if (ne > 512) {
elements = { 512, CEIL_DIV(ne, 512), 1 };
} else {
elements = { ne, 1, 1 };
}

ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE },
vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE },
vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE },
vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE },
vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
}, pc, elements);
}

static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -9692,8 +9821,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

break;
case GGML_OP_ADD:
if (ctx->num_additional_fused_ops) {
ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx, dryrun);
} else {
ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
}
break;
case GGML_OP_SUB:
ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -10570,6 +10702,58 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
return true;
}

static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {

if (!ctx->device->multi_add) {
return 0;
}

const ggml_tensor *first_node = cgraph->nodes[node_idx];
if (first_node->op != GGML_OP_ADD) {
return 0;
}

int32_t num_adds = 1;
while (node_idx + num_adds < cgraph->n_nodes &&
cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&
num_adds < MAX_FUSED_ADDS) {
num_adds++;
}

// The shader currently requires same shapes (but different strides are allowed),
// everything f32, and no misalignment
for (int32_t i = 0; i < num_adds; ++i) {
const ggml_tensor *next_node = cgraph->nodes[node_idx + i];
if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
!ggml_are_same_shape(first_node, next_node->src[1]) ||
next_node->type != GGML_TYPE_F32 ||
next_node->src[0]->type != GGML_TYPE_F32 ||
next_node->src[1]->type != GGML_TYPE_F32 ||
get_misalign_bytes(ctx, next_node) ||
get_misalign_bytes(ctx, next_node->src[0]) ||
get_misalign_bytes(ctx, next_node->src[1])) {
num_adds = i;
}
}

// Verify we can fuse these
ggml_op adds[MAX_FUSED_ADDS];
for (int32_t i = 0; i < num_adds; ++i) {
adds[i] = GGML_OP_ADD;
}

// decrease num_adds if they can't all be fused
while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
num_adds--;
}

// a single add is not "fused", so just return zero
if (num_adds == 1) {
return 0;
}
return num_adds;
}

static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -10583,8 +10767,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg

uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (!ctx->device->disable_fusion) {
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
if (num_adds) {
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
@@ -10659,8 +10848,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}

if (!ctx->device->disable_fusion) {
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
if (num_adds) {
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}
}

// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
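
To see how this fused path is reached from the graph side: ggml_vk_fuse_multi_add detects a chain of f32 adds over same-shape tensors, and the chain is then dispatched as a single multi_add shader invocation. A minimal sketch of such a chain using the ggml API (the names ctx0, a, b, c, d are hypothetical; fusion additionally assumes the intermediate results are not consumed elsewhere in the graph, which is what the ggml_can_fuse check verifies):

// Hypothetical graph fragment: dst = a + b + c + d, all tensors f32 and the same shape.
struct ggml_tensor * t0  = ggml_add(ctx0, a,  b);   // node i     : GGML_OP_ADD
struct ggml_tensor * t1  = ggml_add(ctx0, t0, c);   // node i + 1 : GGML_OP_ADD
struct ggml_tensor * dst = ggml_add(ctx0, t1, d);   // node i + 2 : GGML_OP_ADD
// ggml_vk_fuse_multi_add() finds 3 consecutive adds, so num_additional_fused_ops = 2.
// ggml_vk_multi_add() then binds {a, b, c, d, dst} (5 of the 8 descriptors; unused
// slots alias buf[0]) and selects pipeline_multi_add[2], whose spec constant num_srcs is 4.
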
19 changes: 2 additions & 17 deletions ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp
@@ -2,6 +2,7 @@
#extension GL_EXT_control_flow_attributes : require

#include "rte.comp"
#include "utils.comp"

layout (push_constant) uniform parameter
{
@@ -28,25 +29,9 @@ uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
uint get_doffset() { return p.misalign_offsets & 0xFF; }

// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
uint fastmod(uint a, uint b) {
if ((b & (b-1)) == 0) {
return a & (b-1);
}
return a % b;
}

uint fastdiv(uint a, uint b) {
return (a < b) ? 0 : (a / b);
}

void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
const uint i02_offset = i02*p.ne01*p.ne00;
i01 = (idx - i03_offset - i02_offset) / p.ne00;
i00 = idx - i03_offset - i02_offset - i01*p.ne00;
get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03);
}

uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
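
For reference, the shared index unflattening that generic_binary_head.comp and multi_add.comp now both call from utils.comp, transcribed to C++ (a sketch inferred from the removed GLSL above and the new call site; whether utils.comp keeps the fastmod/fastdiv shortcuts is not shown in this diff, and plain division is used here since only the mapping matters):

#include <cstdint>

// Unflatten a linear element index into (i00, i01, i02, i03) coordinates for a tensor
// of logical shape ne00 x ne01 x ne02 x ne03. ne03 is accepted for symmetry with the
// GLSL helper but is unused; callers bound-check idx against the total element count.
static void get_indices(uint32_t idx,
                        uint32_t & i00, uint32_t & i01, uint32_t & i02, uint32_t & i03,
                        uint32_t ne00, uint32_t ne01, uint32_t ne02, uint32_t /*ne03*/) {
    i03 = idx / (ne02 * ne01 * ne00);
    const uint32_t i03_offset = i03 * ne02 * ne01 * ne00;
    i02 = (idx - i03_offset) / (ne01 * ne00);
    const uint32_t i02_offset = i02 * ne01 * ne00;
    i01 = (idx - i03_offset - i02_offset) / ne00;
    i00 = idx - i03_offset - i02_offset - i01 * ne00;
}
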
68 changes: 68 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp
@@ -0,0 +1,68 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_nonuniform_qualifier : enable
#extension GL_EXT_control_flow_attributes : require

#include "rte.comp"
#include "types.comp"
#include "utils.comp"

layout (push_constant) uniform parameter2
{
// shape for dst
uint ne20; uint ne21; uint ne22; uint ne23;

// strides for srcs+dst
uint nb[8][4];
} p;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];

Comment on lines +20 to +21

Collaborator:

Does this not require any support checks on the API side? There seems to be an endless source of obscure GLSL extensions that solve very specific issues. Neat.

Collaborator (Author):

It requires the runtimeDescriptorArray feature from the descriptor_indexing extension. This is required to be supported in Vulkan 1.2 so I don't think the check is strictly required, but I'll add it since it's simple.
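
For readers unfamiliar with the feature being discussed: runtimeDescriptorArray is a boolean in VkPhysicalDeviceVulkan12Features, so it can be queried through the usual features2 chain. A minimal standalone sketch of such a check (illustrative only, not the exact code added in this PR, which reads the feature from its existing vk12_features query; physical_device is an assumed handle):

#include <vulkan/vulkan.h>

// Query whether the device exposes runtimeDescriptorArray (promoted to core in
// Vulkan 1.2, originally from VK_EXT_descriptor_indexing).
static bool supports_runtime_descriptor_array(VkPhysicalDevice physical_device) {
    VkPhysicalDeviceVulkan12Features vk12_features = {};
    vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;

    VkPhysicalDeviceFeatures2 features2 = {};
    features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
    features2.pNext = &vk12_features;

    vkGetPhysicalDeviceFeatures2(physical_device, &features2);
    return vk12_features.runtimeDescriptorArray == VK_TRUE;
}
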


layout(constant_id = 0) const uint num_srcs = 2;

uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
}

uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
uint nb20 = p.nb[num_srcs][0];
uint nb21 = p.nb[num_srcs][1];
uint nb22 = p.nb[num_srcs][2];
uint nb23 = p.nb[num_srcs][3];
return i03*nb23 + i02*nb22 + i01*nb21 + i00*nb20;
}

uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}

const uint num_threads = 256;

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

void main() {
uint idx = get_idx();

uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;

// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;

[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23);

FLOAT_TYPE sum = FLOAT_TYPE(0);
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
}
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);

idx += num_threads;
}
}
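
The host-side grid computation in ggml_vk_multi_add and get_idx() above have to agree: each workgroup runs 256 threads, each thread handles num_iter = 2 elements, so one workgroup covers 512 consecutive elements, and the y and z grid dimensions then step by 512 and 512 * 512 = 262144 elements respectively. A small sanity-check sketch of that mapping (illustrative only; it assumes the dispatch helper divides the element counts by the pipeline's wg_denoms of {512, 1, 1} to obtain the workgroup grid, which is not shown in this diff):

#include <cassert>
#include <cstdint>
#include <vector>

// Model of the multi_add dispatch: check that idx = z*262144 + y*512 + x, with
// 256 threads per workgroup and two elements per thread, touches every destination
// element in [0, ne) exactly once for the grids chosen in ggml_vk_multi_add.
int main() {
    for (uint32_t ne : {100u, 512u, 513u, 262144u, 1000000u}) {
        // Element counts, as computed in ggml_vk_multi_add.
        uint32_t ex, ey, ez;
        if (ne > 262144)   { ex = 512; ey = 512; ez = (ne + 262143) / 262144; }
        else if (ne > 512) { ex = 512; ey = (ne + 511) / 512; ez = 1; }
        else               { ex = ne;  ey = 1;   ez = 1; }

        // Workgroup grid, assuming the element counts are divided by wg_denoms {512, 1, 1}.
        const uint32_t gx = (ex + 511) / 512, gy = ey, gz = ez;

        std::vector<uint32_t> hits(ne, 0);
        for (uint32_t z = 0; z < gz; ++z)
        for (uint32_t y = 0; y < gy; ++y)
        for (uint32_t wx = 0; wx < gx; ++wx)
        for (uint32_t tx = 0; tx < 256; ++tx) {      // local_size_x = 256
            uint32_t idx = z * 262144 + y * 512 + (wx * 256 + tx);
            for (uint32_t it = 0; it < 2; ++it) {    // num_iter = 2
                if (idx < ne) { hits[idx]++; }
                idx += 256;                          // idx += num_threads
            }
        }
        for (uint32_t v : hits) { assert(v == 1); }
    }
    return 0;
}
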