Commit 97ec47a

vulkan: optimize rms_norm, and allow the work to spread across multiple SMs

There are really two parts to this change:

(1) Some optimizations similar to what we have in soft_max, to unroll with different numbers of iterations.

(2) A fusion optimization where we detect an add followed by an rms_norm, and make the add shader atomically accumulate the values^2 into memory. The rms_norm shader can then just load that sum, so the rms_norm can be parallelized across multiple workgroups; it becomes a simple per-element multiply.

The fusion optimization is currently only applied when the rms_norm operates on a single vector, which previously always ran on a single SM. It could apply more broadly, but when there are other dimensions the work can already spread across SMs, and there would be some complexity to tracking multiple atomic sums.

1 parent be48528 commit 97ec47a
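
To make the fusion concrete: rms_norm scales each element of x by 1/sqrt(mean(x^2) + eps), so once sum(x^2) is known the normalization is embarrassingly parallel. Below is a minimal CPU-side sketch of the dataflow the two fused passes implement (illustrative names, single row; not the backend's actual code):

#include <cmath>
#include <cstddef>
#include <vector>

// Pass 1 mirrors the fused add shader: write the elementwise sum and
// accumulate sum(x^2) into one shared accumulator (done with
// subgroupAdd + atomicAdd on the GPU). Pass 2 mirrors what rms_norm
// degenerates to once the sum is precomputed: a per-element multiply
// that can be split across any number of workgroups.
void fused_add_rms_norm(const std::vector<float> & a, const std::vector<float> & b,
                        std::vector<float> & dst, float eps) {
    const size_t n = a.size();
    float sum_sq = 0.0f; // stands in for the atomic accumulator

    for (size_t i = 0; i < n; ++i) {
        const float s = a[i] + b[i];
        sum_sq += s * s;
        dst[i] = s;
    }

    const float scale = 1.0f / std::sqrt(sum_sq / n + eps);
    for (size_t i = 0; i < n; ++i) {
        dst[i] *= scale;
    }
}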

File tree

5 files changed (+246 -51 lines)


ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 120 additions & 22 deletions
@@ -369,6 +369,10 @@ struct vk_device_struct {
     bool subgroup_add;
     bool subgroup_shuffle;
 
+    bool atomic_float_add;
+    bool add_rms_fusion;
+    uint32_t atomic_binding_alignment;
+
     bool integer_dot_product;
 
     bool subgroup_size_control;
@@ -448,6 +452,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_mul_norepeat[2][2][2];
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];
+    vk_pipeline pipeline_add_rms[2][2][2];
+    vk_pipeline pipeline_add_rms_norepeat[2][2][2];
 
     vk_pipeline pipeline_add_id_f32;
 
@@ -1144,6 +1150,12 @@ class vk_perf_logger {
             timings[name].push_back(time);
             return;
         }
+        if (node->op == GGML_OP_RMS_NORM) {
+            std::string name = ggml_op_name(node->op);
+            name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
+            timings[name].push_back(time);
+            return;
+        }
         timings[ggml_op_name(node->op)].push_back(time);
     }
 private:
@@ -1158,10 +1170,13 @@ struct ggml_backend_vk_context {
 
     size_t semaphore_idx, event_idx;
     ggml_vk_garbage_collector gc;
-    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
-    vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
+    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset;
+    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add;
     vk::Fence fence, almost_ready_fence;
     bool almost_ready_fence_pending {};
+    // Set before op_add and unset after op_rms_norm to indicate that the add should
+    // use atomics to accumulate the square of the vector components
+    bool do_add_rms_atomic;
 
     vk_buffer buffer_pool[MAX_VK_BUFFERS];
 
@@ -2924,8 +2939,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -2995,20 +3010,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
     };
 
     bool rte = device->float_controls_rte_fp16;
-#define CREATE_BINARY(name, namemod, spec) \
+#define CREATE_BINARY(name, namemod, spec, bindings) \
     for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
         ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
             #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
-            "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
-
-    CREATE_BINARY(add, , {0})
-    CREATE_BINARY(add, _norepeat, {1})
-    CREATE_BINARY(sub, , {0})
-    CREATE_BINARY(sub, _norepeat, {1})
-    CREATE_BINARY(mul, , {0})
-    CREATE_BINARY(mul, _norepeat, {1})
-    CREATE_BINARY(div, , {0})
-    CREATE_BINARY(div, _norepeat, {1})
+            "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
+
+    CREATE_BINARY(add, , {0}, 4)
+    CREATE_BINARY(add, _norepeat, {1}, 4)
+    CREATE_BINARY(sub, , {0}, 3)
+    CREATE_BINARY(sub, _norepeat, {1}, 3)
+    CREATE_BINARY(mul, , {0}, 3)
+    CREATE_BINARY(mul, _norepeat, {1}, 3)
+    CREATE_BINARY(div, , {0}, 3)
+    CREATE_BINARY(div, _norepeat, {1}, 3)
+    CREATE_BINARY(add_rms, , {0}, 4)
+    CREATE_BINARY(add_rms, _norepeat, {1}, 4)
 #undef CREATE_BINARY
 
     ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
@@ -3281,6 +3298,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->coopmat_support = false;
         device->integer_dot_product = false;
         bool bfloat16_support = false;
+        bool atomic_float_support = false;
 
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -3320,6 +3338,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
                        !getenv("GGML_VK_DISABLE_BFLOAT16")) {
                 bfloat16_support = true;
 #endif
+            } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) {
+                atomic_float_support = true;
             }
         }
 
@@ -3536,6 +3556,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_KHR_shader_integer_dot_product");
         }
 
+        VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
+        atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
+        if (atomic_float_support) {
+            last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features;
+            last_struct = (VkBaseOutStructure *)&atomic_float_features;
+            device_extensions.push_back("VK_EXT_shader_atomic_float");
+        }
+
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
         device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -3547,6 +3575,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 #endif
 
         device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
+        device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd;
 
         if (device->subgroup_size_control) {
             device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@@ -3861,6 +3890,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
 
+        device->add_rms_fusion = !device->disable_fusion &&
+                                 device->subgroup_add &&
+                                 device->atomic_float_add;
+        device->atomic_binding_alignment =
+            std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
+
         return device;
     }
 
@@ -6892,8 +6927,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     switch (op) {
     case GGML_OP_ADD:
         {
-            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
-            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+            if (ctx->do_add_rms_atomic) {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+            } else {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+            }
         }
     case GGML_OP_SUB:
         {
@@ -7494,7 +7534,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     } break;
     case GGML_OP_RMS_NORM:
-        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        if (ctx->do_add_rms_atomic) {
+            // Run one element per thread, 128 threads per workgroup
+            elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
+        } else {
+            elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        }
         break;
 
     case GGML_OP_SUM:
@@ -7642,7 +7687,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     }
 
-    if (op == GGML_OP_GLU) {
+    if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
+        vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
+        size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            { vk_subbuffer{ d_X, x_buf_offset, x_sz },
+              vk_subbuffer{ d_Y, y_buf_offset, y_sz },
+              vk_subbuffer{ d_D, d_buf_offset, d_sz },
+              vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
+            }, pc, elements);
+    } else if (op == GGML_OP_GLU) {
         // Empty src1 is possible in glu, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
@@ -7750,7 +7805,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        0.0f, 0.0f, 0,
+        0.0f, 0.0f, ctx->do_add_rms_atomic,
     }, dryrun);
 }
 
@@ -8213,8 +8268,13 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f, 0,
+        op_params[0], 0.0f, ctx->do_add_rms_atomic,
     }, dryrun);
+
+    if (ctx->do_add_rms_atomic) {
+        ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment;
+        ctx->do_add_rms_atomic = false;
+    }
 }
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9492,6 +9552,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         }
         ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
     }
+    if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_atomic_add << ")");
+        // Resize buffer
+        if (ctx->prealloc_atomic_add != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_atomic_add);
+        }
+        ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add);
+    }
 }
 
 static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
@@ -9547,10 +9615,21 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             return false;
         }
         break;
+    case GGML_OP_ADD:
+        if (node_idx + 1 < cgraph->n_nodes &&
+            cgraph->nodes[node_idx + 1]->op == GGML_OP_RMS_NORM &&
+            cgraph->nodes[node_idx + 1]->src[0] == cgraph->nodes[node_idx] &&
+            ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
+            ctx->device->add_rms_fusion) {
+            if (dryrun) {
+                ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment;
+            }
+            ctx->do_add_rms_atomic = true;
+        }
+        break;
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD:
     case GGML_OP_ADD_ID:
     case GGML_OP_ACC:
     case GGML_OP_SUB:
@@ -9667,6 +9746,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             // do the only thing needed for the dryrun.
             vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
             ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+            if (node->op == GGML_OP_RMS_NORM) {
+                ctx->do_add_rms_atomic = false;
+            }
             return false;
         }
     default:
@@ -10581,6 +10663,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
     }
 
+    ctx->prealloc_size_atomic_add = 0;
+    ctx->prealloc_size_atomic_add_offset = 0;
+    ctx->do_add_rms_atomic = false;
+
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
@@ -10641,6 +10727,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
     }
 
+    if (ctx->prealloc_size_atomic_add) {
+        if (ctx->compute_ctx.expired()) {
+            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+            ctx->compute_ctx = compute_ctx;
+            ggml_vk_ctx_begin(ctx->device, compute_ctx);
+        } else {
+            compute_ctx = ctx->compute_ctx.lock();
+        }
+        // initialize atomic sums to zero.
+        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add);
+    }
+
     // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
     // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
    // (and scaled down based on model size, so smaller models submit earlier).
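
A host-side detail worth noting in the diff above: each fused add+rms_norm pair gets its own float accumulator in prealloc_atomic_add, and because the accumulator is bound at a buffer offset, consecutive slots are spaced by atomic_binding_alignment, i.e. max(4, minStorageBufferOffsetAlignment). A rough sketch of that bookkeeping, using a hypothetical struct (the real code keeps these fields on the device and context structs):

#include <cstddef>
#include <cstdint>

struct atomic_slot_plan {
    uint32_t alignment;  // max(4, minStorageBufferOffsetAlignment)
    size_t   total = 0;  // grown once per fused pair during the dryrun
    size_t   offset = 0; // advanced after each rms_norm during execution

    // dryrun: reserve one aligned slot per detected add+rms_norm pair
    void count_fusion() { total += alignment; }
    // execution: hand out the next slot's byte offset into the buffer
    size_t next_slot() { size_t o = offset; offset += alignment; return o; }
};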
ggml/src/ggml-vulkan/vulkan-shaders/add.comp

Lines changed: 22 additions & 1 deletion
@@ -1,12 +1,19 @@
 #version 450
 
 #extension GL_EXT_shader_16bit_storage : require
+#if ADD_RMS
+#extension GL_EXT_shader_atomic_float : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#endif
 
 #include "types.comp"
 #include "generic_binary_head.comp"
 
 const uint num_threads = 256;
 
+layout (binding = 3) buffer AtomBuf {float data_atom;};
+
 layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
 
 void main() {
@@ -15,15 +22,29 @@ void main() {
     // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
     const uint num_iter = 2;
 
+    FLOAT_TYPE sum_sq = 0;
+
     [[unroll]] for (uint i = 0; i < num_iter; ++i) {
         if (idx >= p.ne) {
             continue;
         }
         uint i00, i01, i02, i03;
         get_indices(idx, i00, i01, i02, i03);
 
-        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+        FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
+        sum_sq += sum*sum;
+
+        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
 
         idx += num_threads;
     }
+
+#if ADD_RMS
+    if (p.param3 != 0) {
+        sum_sq = subgroupAdd(sum_sq);
+        if (sum_sq != 0 && gl_SubgroupInvocationID == 0) {
+            atomicAdd(data_atom, sum_sq);
+        }
+    }
+#endif
 }
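
The tail of the shader is a two-level reduction: subgroupAdd folds the per-thread partial sums within each subgroup, then only one invocation per subgroup issues an atomicAdd, so contention on data_atom is one atomic per subgroup rather than one per thread (the sum_sq != 0 check also skips subgroups whose threads were all out of range). A rough CPU analogue of the pattern, assuming a hypothetical fixed 32-lane subgroup:

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <numeric>
#include <vector>

// Sketch only, not the GPU code: reduce each "subgroup" locally, then
// publish one atomic add per subgroup into the shared accumulator.
float reduce_two_level(const std::vector<float> & per_thread_sums) {
    constexpr size_t subgroup_size = 32;
    std::atomic<float> data_atom{0.0f}; // plays the role of the binding-3 buffer

    for (size_t base = 0; base < per_thread_sums.size(); base += subgroup_size) {
        const size_t end = std::min(base + subgroup_size, per_thread_sums.size());
        // subgroupAdd: one subgroup-wide sum of the partial sums
        const float sum_sq = std::accumulate(per_thread_sums.begin() + base,
                                             per_thread_sums.begin() + end, 0.0f);
        // only "invocation 0" publishes, and only if there is something to add
        if (sum_sq != 0.0f) {
            data_atom.fetch_add(sum_sq); // C++20 std::atomic<float>::fetch_add
        }
    }
    return data_atom.load();
}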
