vulkan: optimize rms_norm, and allow the work to spread across multiple SMs #15281

Draft: wants to merge 2 commits into master
155 changes: 132 additions & 23 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -369,6 +369,9 @@ struct vk_device_struct {
bool subgroup_add;
bool subgroup_shuffle;

bool add_rms_fusion;
uint32_t partials_binding_alignment;

bool integer_dot_product;

bool subgroup_size_control;
@@ -448,6 +451,8 @@ struct vk_device_struct {
vk_pipeline pipeline_mul_norepeat[2][2][2];
vk_pipeline pipeline_div[2][2][2];
vk_pipeline pipeline_div_norepeat[2][2][2];
vk_pipeline pipeline_add_rms[2][2][2];
vk_pipeline pipeline_add_rms_norepeat[2][2][2];

vk_pipeline pipeline_add_id_f32;

@@ -470,6 +475,8 @@ struct vk_device_struct {
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
vk_pipeline pipeline_rms_norm_mul_f32;
vk_pipeline pipeline_rms_norm_partials_f32;
vk_pipeline pipeline_rms_norm_mul_partials_f32;
vk_pipeline pipeline_rms_norm_back_f32;
vk_pipeline pipeline_l2_norm_f32;

@@ -1144,6 +1151,12 @@ class vk_perf_logger {
timings[name].push_back(time);
return;
}
if (node->op == GGML_OP_RMS_NORM) {
std::string name = ggml_op_name(node->op);
name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
timings[name].push_back(time);
return;
}
timings[ggml_op_name(node->op)].push_back(time);
}
private:
@@ -1158,10 +1171,13 @@ struct ggml_backend_vk_context {

size_t semaphore_idx, event_idx;
ggml_vk_garbage_collector gc;
size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials;
vk::Fence fence, almost_ready_fence;
bool almost_ready_fence_pending {};
// Set before op_add and unset after op_rms_norm to indicate that the add should
// write partial sums to accumulate the square of the vector components
bool do_add_rms_partials;

vk_buffer buffer_pool[MAX_VK_BUFFERS];

@@ -2924,8 +2940,12 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);

ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);

ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

@@ -2995,20 +3015,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
};

bool rte = device->float_controls_rte_fp16;
#define CREATE_BINARY(name, namemod, spec) \
#define CREATE_BINARY(name, namemod, spec, bindings) \
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
"main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);

CREATE_BINARY(add, , {0})
CREATE_BINARY(add, _norepeat, {1})
CREATE_BINARY(sub, , {0})
CREATE_BINARY(sub, _norepeat, {1})
CREATE_BINARY(mul, , {0})
CREATE_BINARY(mul, _norepeat, {1})
CREATE_BINARY(div, , {0})
CREATE_BINARY(div, _norepeat, {1})
"main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);

CREATE_BINARY(add, , {0}, 4)
CREATE_BINARY(add, _norepeat, {1}, 4)
CREATE_BINARY(sub, , {0}, 3)
CREATE_BINARY(sub, _norepeat, {1}, 3)
CREATE_BINARY(mul, , {0}, 3)
CREATE_BINARY(mul, _norepeat, {1}, 3)
CREATE_BINARY(div, , {0}, 3)
CREATE_BINARY(div, _norepeat, {1}, 3)
CREATE_BINARY(add_rms, , {0}, 4)
CREATE_BINARY(add_rms, _norepeat, {1}, 4)
#undef CREATE_BINARY

ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
@@ -3861,6 +3883,11 @@ static vk_device ggml_vk_get_device(size_t idx) {

device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;

device->add_rms_fusion = !device->disable_fusion &&
device->subgroup_add;
device->partials_binding_alignment =
std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);

return device;
}

@@ -6892,8 +6919,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
switch (op) {
case GGML_OP_ADD:
{
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
if (ctx->do_add_rms_partials) {
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
} else {
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
}
}
case GGML_OP_SUB:
{
@@ -7011,7 +7043,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_RMS_NORM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
if (ctx->do_add_rms_partials) {
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
} else {
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
}
}
return nullptr;
case GGML_OP_RMS_NORM_BACK:
@@ -7494,7 +7530,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
} break;
case GGML_OP_RMS_NORM:
elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
if (ctx->do_add_rms_partials) {
// Run one element per thread, 128 threads per workgroup
elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
} else {
elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
}
break;

case GGML_OP_SUM:
@@ -7642,7 +7683,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}

if (op == GGML_OP_GLU) {
if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
vk_buffer d_A = ctx->prealloc_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
size_t a_buf_offset = ctx->prealloc_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{ vk_subbuffer{ d_X, x_buf_offset, x_sz },
vk_subbuffer{ d_Y, y_buf_offset, y_sz },
vk_subbuffer{ d_D, d_buf_offset, d_sz },
vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
}, pc, elements);
} else if (op == GGML_OP_GLU) {
// Empty src1 is possible in glu, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
@@ -7750,7 +7801,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
0.0f, 0.0f, ctx->do_add_rms_partials,
}, dryrun);
}

@@ -8202,19 +8253,39 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}

static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
const uint32_t ne = (uint32_t)node->ne[0];
const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
const uint32_t num_partials = CEIL_DIV(ne, denom);
return num_partials;
}

static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
return num_bytes;
}
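// Worked example (illustration only, not part of the change): the add_rms pipelines
// above are created with wg_denoms {512, 1, 1}, so a row with ne[0] = 4096 produces
// CEIL_DIV(4096, 512) = 8 partials, i.e. 32 bytes, which ROUNDUP_POW2 then pads up to
// partials_binding_alignment (for example 64 bytes when minStorageBufferOffsetAlignment is 64).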

static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);

uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;

ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
op_params[0], 0.0f, 0,
op_params[0], 0.0f, (int32_t)param3,
}, dryrun);

if (ctx->do_add_rms_partials) {
ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
ctx->do_add_rms_partials = false;
}
}

static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9492,6 +9563,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
}
ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
}
if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_size_add_rms_partials << ")");
// Resize buffer
if (ctx->prealloc_add_rms_partials != nullptr) {
ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
}
ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
}
}

static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
@@ -9547,10 +9626,21 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
return false;
}
break;
case GGML_OP_ADD:
if (node_idx + 1 < cgraph->n_nodes &&
cgraph->nodes[node_idx + 1]->op == GGML_OP_RMS_NORM &&
cgraph->nodes[node_idx + 1]->src[0] == cgraph->nodes[node_idx] &&
ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
ctx->device->add_rms_fusion) {
if (dryrun) {
ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
}
ctx->do_add_rms_partials = true;
}
break;
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ACC:
case GGML_OP_SUB:
@@ -9667,6 +9757,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
// do the only thing needed for the dryrun.
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (node->op == GGML_OP_RMS_NORM) {
ctx->do_add_rms_partials = false;
}
return false;
}
default:
@@ -10581,6 +10674,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
}

ctx->prealloc_size_add_rms_partials = 0;
ctx->prealloc_size_add_rms_partials_offset = 0;
ctx->do_add_rms_partials = false;

uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
@@ -10641,6 +10738,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
}

if (ctx->prealloc_size_add_rms_partials) {
if (ctx->compute_ctx.expired()) {
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = compute_ctx;
ggml_vk_ctx_begin(ctx->device, compute_ctx);
} else {
compute_ctx = ctx->compute_ctx.lock();
}
// initialize partial sums to zero.
ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
}

// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
// (and scaled down based on model size, so smaller models submit earlier).
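Taken together, the host-side changes implement a two-pass scheme: the dryrun over the graph grows prealloc_size_add_rms_partials by one aligned slice per fused add+rms_norm pair, the shared buffer is then allocated and zeroed once per graph, and during execution each rms_norm advances prealloc_size_add_rms_partials_offset past the slice its preceding add wrote into. A minimal standalone sketch of that bookkeeping, using hypothetical names (plan_partials, fused_add_rms) rather than code from the patch:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
// round v up to a multiple of a, where a is a power of two
static size_t round_up_pow2(size_t v, size_t a) { return (v + a - 1) & ~(a - 1); }

struct fused_add_rms {
    uint32_t ne0;             // row length shared by the fused add and rms_norm
    size_t   partials_offset; // byte offset of this pair's slice in the shared buffer
};

// Dryrun-style pass: give each fused add+rms_norm pair its own aligned slice of a
// single shared partials buffer and return the total number of bytes to preallocate.
size_t plan_partials(std::vector<fused_add_rms> &nodes, uint32_t wg_denom, size_t align) {
    size_t total = 0;
    for (auto &n : nodes) {
        n.partials_offset = total;
        const uint32_t num_partials = ceil_div(n.ne0, wg_denom);  // one float per add workgroup
        total += round_up_pow2(num_partials * sizeof(float), std::max<size_t>(4, align));
    }
    return total; // allocate this many bytes and zero them before running the graph
}

With wg_denom = 512 this matches ggml_vk_rms_partials_size above; the patch itself reuses that helper in both the sizing pass and the execution pass, so the accumulated offsets always line up.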
42 changes: 41 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/add.comp
@@ -1,29 +1,69 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require
#if ADD_RMS
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif

#include "types.comp"
#include "generic_binary_head.comp"

const uint num_threads = 256;

layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

#if ADD_RMS
// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
shared FLOAT_TYPE sumsh[num_threads];
#endif

void main() {
uint idx = get_idx();
uint orig_idx = idx;

// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;

FLOAT_TYPE sum_sq = 0;

[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);

data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
sum_sq += sum*sum;

data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);

idx += num_threads;
}

#if ADD_RMS
if (p.param3 != 0) {
// reduce the sum within each subgroup, then across subgroups
const uint NumSubgroups = num_threads / gl_SubgroupSize;
sum_sq = subgroupAdd(sum_sq);
if (gl_SubgroupInvocationID == 0) {
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
sum_sq += sumsh[gl_SubgroupID + s];
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
}

if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
}
}
#endif
}
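For reference, the end-to-end computation the fused path performs can be modelled on the CPU as below. Only the producer (this add shader) appears in the diff; the consumer side of the model is an assumption about what rms_norm_partials.comp does with the partial sums, and add_rms_reference is a hypothetical name, not code from the patch.

#include <cmath>
#include <cstddef>
#include <vector>

// CPU reference model of the fused add -> rms_norm path: the add stage emits one
// partial sum of squares per 512-element chunk (one value per workgroup above),
// and the norm stage reduces those partials instead of re-reading the whole row.
void add_rms_reference(const float *a, const float *b, float *dst, size_t ne0, float eps) {
    const size_t chunk = 512; // 512 = num_threads * num_iter, the per-workgroup element count above
    std::vector<float> partials((ne0 + chunk - 1) / chunk, 0.0f);

    // "add" stage
    for (size_t i = 0; i < ne0; ++i) {
        dst[i] = a[i] + b[i];
        partials[i / chunk] += dst[i] * dst[i];
    }

    // "rms_norm" stage: finish the reduction and apply the usual RMS scaling
    float sum_sq = 0.0f;
    for (float p : partials) {
        sum_sq += p;
    }
    const float scale = 1.0f / std::sqrt(sum_sq / (float)ne0 + eps);
    for (size_t i = 0; i < ne0; ++i) {
        dst[i] *= scale; // the fused rms_norm_mul variant would also multiply by the norm weight here
    }
}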