Commit e4ec524

Change add+rms_norm optimization to write out an array of partial sums
rather than using atomic add, to make it deterministic. The rms_norm shader fetches a subgroup's worth in parallel and uses subgroupAdd to add them up.
1 parent d6da225 commit e4ec524
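A note on why this change helps: atomicAdd on floats commits each workgroup's contribution in whatever order the workgroups happen to run, so the rounding of the accumulated sum can differ from run to run. Writing each workgroup's partial sum to its own slot and then reducing the slots in a fixed order makes the rounding sequence, and therefore the result, bitwise reproducible. A minimal CPU-side sketch of the scheme (standalone C++ with hypothetical names, not code from this commit):

    #include <cstdio>
    #include <vector>

    // Pass 1 models the "add" shader: each workgroup writes the sum of squares
    // over its slice into its own slot -- one writer per slot, no atomics.
    static std::vector<float> write_partials(const std::vector<float> & x, size_t wg_elems) {
        std::vector<float> partials((x.size() + wg_elems - 1) / wg_elems, 0.0f);
        for (size_t g = 0; g < partials.size(); g++) {
            float sum_sq = 0.0f;
            for (size_t i = g * wg_elems; i < x.size() && i < (g + 1) * wg_elems; i++) {
                sum_sq += x[i] * x[i];
            }
            partials[g] = sum_sq;
        }
        return partials;
    }

    // Pass 2 models the rms_norm side: the partials are combined in a fixed
    // order, so the floating-point rounding sequence is identical every run.
    static float reduce_partials(const std::vector<float> & partials) {
        float total = 0.0f;
        for (float p : partials) {
            total += p;
        }
        return total;
    }

    int main() {
        std::vector<float> x(4096, 0.5f);
        printf("sum of squares: %f\n", reduce_partials(write_partials(x, 512)));
    }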

File tree: 6 files changed (+188 -125 lines)


ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 57 additions & 46 deletions
@@ -372,9 +372,8 @@ struct vk_device_struct {
     bool subgroup_shuffle;
     bool multi_add;
 
-    bool atomic_float_add;
     bool add_rms_fusion;
-    uint32_t atomic_binding_alignment;
+    uint32_t partials_binding_alignment;
 
     bool integer_dot_product;
 
@@ -482,6 +481,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_mul_f32;
+    vk_pipeline pipeline_rms_norm_partials_f32;
+    vk_pipeline pipeline_rms_norm_mul_partials_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
 
@@ -1191,13 +1192,13 @@ struct ggml_backend_vk_context {
 
     size_t semaphore_idx, event_idx;
    ggml_vk_garbage_collector gc;
-    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset;
-    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add;
+    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
+    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials;
     vk::Fence fence, almost_ready_fence;
     bool almost_ready_fence_pending {};
     // Set before op_add and unset after op_rms_norm to indicate that the add should
-    // use atomics to accumulate the square of the vector components
-    bool do_add_rms_atomic;
+    // write partial sums to accumulate the square of the vector components
+    bool do_add_rms_partials;
 
     vk_buffer buffer_pool[MAX_VK_BUFFERS];
 
@@ -2936,8 +2937,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -3303,7 +3308,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->coopmat_support = false;
         device->integer_dot_product = false;
         bool bfloat16_support = false;
-        bool atomic_float_support = false;
 
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -3343,8 +3347,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
                        !getenv("GGML_VK_DISABLE_BFLOAT16")) {
                 bfloat16_support = true;
 #endif
-            } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) {
-                atomic_float_support = true;
             }
         }
 
@@ -3561,14 +3563,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_KHR_shader_integer_dot_product");
         }
 
-        VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
-        atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
-        if (atomic_float_support) {
-            last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features;
-            last_struct = (VkBaseOutStructure *)&atomic_float_features;
-            device_extensions.push_back("VK_EXT_shader_atomic_float");
-        }
-
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
         device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -3580,7 +3574,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
 #endif
 
         device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
-        device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd;
 
         device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
                             device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
@@ -3902,9 +3895,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
 
         device->add_rms_fusion = !device->disable_fusion &&
-                                 device->subgroup_add &&
-                                 device->atomic_float_add;
-        device->atomic_binding_alignment =
+                                 device->subgroup_add;
+        device->partials_binding_alignment =
             std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
 
         return device;
@@ -6949,13 +6941,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_ADD:
     {
         if (ctx->num_additional_fused_ops > 0) {
-            if (ctx->do_add_rms_atomic) {
+            if (ctx->do_add_rms_partials) {
                 return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
             } else {
                 return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
             }
         }
-        if (ctx->do_add_rms_atomic) {
+        if (ctx->do_add_rms_partials) {
            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         } else {
@@ -7079,7 +7071,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_RMS_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+            if (ctx->do_add_rms_partials) {
+                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
+            } else {
+                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+            }
         }
         return nullptr;
     case GGML_OP_RMS_NORM_BACK:
@@ -7567,7 +7563,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     } break;
     case GGML_OP_RMS_NORM:
-        if (ctx->do_add_rms_atomic) {
+        if (ctx->do_add_rms_partials) {
             // Run one element per thread, 128 threads per workgroup
             elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
         } else {
@@ -7721,8 +7717,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     }
 
     if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
-        vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
-        size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;
+        vk_buffer d_A = ctx->prealloc_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
+        size_t a_buf_offset = ctx->prealloc_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             { vk_subbuffer{ d_X, x_buf_offset, x_sz },
@@ -7943,7 +7939,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        0.0f, 0.0f, ctx->do_add_rms_atomic,
+        0.0f, 0.0f, ctx->do_add_rms_partials,
     }, dryrun);
 }
 
@@ -8401,23 +8397,38 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
 
+static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+    const uint32_t ne = (uint32_t)node->ne[0];
+    const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
+    const uint32_t num_partials = CEIL_DIV(ne, denom);
+    return num_partials;
+}
+
+static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+    const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
+    const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
+    return num_bytes;
+}
+
 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
+    uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
+
     ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f, ctx->do_add_rms_atomic,
+        op_params[0], 0.0f, (int32_t)param3,
     }, dryrun);
 
-    if (ctx->do_add_rms_atomic) {
-        ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment;
-        ctx->do_add_rms_atomic = false;
+    if (ctx->do_add_rms_partials) {
+        ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
+        ctx->do_add_rms_partials = false;
     }
 }
 
@@ -9696,13 +9707,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         }
         ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
     }
-    if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) {
-        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_atomic_add << ")");
+    if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")");
         // Resize buffer
-        if (ctx->prealloc_atomic_add != nullptr) {
-            ggml_vk_destroy_buffer(ctx->prealloc_atomic_add);
+        if (ctx->prealloc_add_rms_partials != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
         }
-        ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add);
+        ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
     }
 }
 
@@ -9766,9 +9777,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
             ctx->device->add_rms_fusion) {
             if (dryrun) {
-                ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment;
+                ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
             }
-            ctx->do_add_rms_atomic = true;
+            ctx->do_add_rms_partials = true;
         }
         break;
     case GGML_OP_REPEAT:
@@ -9892,7 +9903,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
         ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         if (node->op == GGML_OP_RMS_NORM) {
-            ctx->do_add_rms_atomic = false;
+            ctx->do_add_rms_partials = false;
         }
         return false;
     }
@@ -10868,9 +10879,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
     }
 
-    ctx->prealloc_size_atomic_add = 0;
-    ctx->prealloc_size_atomic_add_offset = 0;
-    ctx->do_add_rms_atomic = false;
+    ctx->prealloc_size_add_rms_partials = 0;
+    ctx->prealloc_size_add_rms_partials_offset = 0;
+    ctx->do_add_rms_partials = false;
 
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -10937,16 +10948,16 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
         }
 
-        if (ctx->prealloc_size_atomic_add) {
+        if (ctx->prealloc_size_add_rms_partials) {
             if (ctx->compute_ctx.expired()) {
                 compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
                 ctx->compute_ctx = compute_ctx;
                 ggml_vk_ctx_begin(ctx->device, compute_ctx);
             } else {
                 compute_ctx = ctx->compute_ctx.lock();
             }
-            // initialize atomic sums to zero.
-            ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add);
+            // initialize partial sums to zero.
+            ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
         }
 
         // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
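The new helpers size the partials buffer once per fused add+rms_norm: one 4-byte partial per workgroup of the add shader (whose workgroup covers 512 elements, per the shader comment in add.comp below), with each fused op's slice rounded up to the storage-buffer offset alignment so consecutive offsets stay valid bind points. A standalone sketch of that arithmetic, assuming a power-of-two alignment as Vulkan guarantees for minStorageBufferOffsetAlignment (names here are hypothetical stand-ins for CEIL_DIV/ROUNDUP_POW2 and the device state):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b)         { return (a + b - 1) / b; }
    static uint32_t roundup_pow2(uint32_t x, uint32_t pow2)  { return (x + pow2 - 1) & ~(pow2 - 1); }

    // wg_denom stands in for pipeline_add_rms[0][0][0]->wg_denoms[0];
    // min_align for minStorageBufferOffsetAlignment.
    static uint32_t partials_size(uint32_t ne0, uint32_t wg_denom, uint32_t min_align) {
        const uint32_t num_partials = ceil_div(ne0, wg_denom);     // one float per workgroup
        const uint32_t align = std::max(4u, min_align);            // partials_binding_alignment
        return roundup_pow2(num_partials * sizeof(float), align);  // aligned slice per fused op
    }

    int main() {
        // e.g. a 4096-wide row, 512 elements per workgroup, 64-byte alignment:
        // 8 partials -> 32 bytes -> rounded up to 64 bytes.
        printf("%u\n", partials_size(4096, 512, 64));
    }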

ggml/src/ggml-vulkan/vulkan-shaders/add.comp

Lines changed: 23 additions & 4 deletions
@@ -2,7 +2,6 @@
 
 #extension GL_EXT_shader_16bit_storage : require
 #if ADD_RMS
-#extension GL_EXT_shader_atomic_float : enable
 #extension GL_KHR_shader_subgroup_arithmetic : enable
 #extension GL_KHR_shader_subgroup_basic : enable
 #endif
@@ -12,12 +11,18 @@
 
 const uint num_threads = 256;
 
-layout (binding = 3) buffer AtomBuf {float data_atom;};
+layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
 
 layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
 
+#if ADD_RMS
+// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
+shared FLOAT_TYPE sumsh[num_threads];
+#endif
+
 void main() {
     uint idx = get_idx();
+    uint orig_idx = idx;
 
     // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
     const uint num_iter = 2;
@@ -41,9 +46,23 @@ void main() {
 
 #if ADD_RMS
     if (p.param3 != 0) {
+        // reduce the sum within each subgroup, then across subgroups
+        const uint NumSubgroups = num_threads / gl_SubgroupSize;
         sum_sq = subgroupAdd(sum_sq);
-        if (sum_sq != 0 && gl_SubgroupInvocationID == 0) {
-            atomicAdd(data_atom, sum_sq);
+        if (gl_SubgroupInvocationID == 0) {
+            sumsh[gl_SubgroupID] = sum_sq;
+        }
+        barrier();
+        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
+            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
+                sum_sq += sumsh[gl_SubgroupID + s];
+                sumsh[gl_SubgroupID] = sum_sq;
+            }
+            barrier();
+        }
+
+        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
+            partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
         }
     }
 #endif
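Two details worth noting in the reduction above. The write index orig_idx / (num_iter * num_threads) is simply the workgroup index, since each 256-thread workgroup covers num_iter = 2 elements per thread; that keeps each workgroup's partial in its own slot and matches ggml_vk_rms_num_partials on the host side. The cross-subgroup step is a standard shared-memory tree fold; a scalar model of that loop (plain C++ for illustration, not the shader itself):

    #include <cstdio>

    // slots[] stands in for sumsh[]: one subgroupAdd result per subgroup leader.
    // Each step folds the upper half of the live slots into the lower half,
    // mirroring: for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { ... }
    // The shader's barrier() corresponds to finishing one step before the next.
    static float tree_reduce(float * slots, unsigned n) { // n: subgroup count, power of two
        for (unsigned s = n / 2; s > 0; s >>= 1) {
            for (unsigned i = 0; i < s; i++) {
                slots[i] += slots[i + s];
            }
        }
        return slots[0];
    }

    int main() {
        float sums[8] = {1, 2, 3, 4, 5, 6, 7, 8}; // e.g. 256 threads / 32-wide subgroups
        printf("workgroup partial: %f\n", tree_reduce(sums, 8));
    }

Per the commit message, the corresponding rms_norm.comp change (not shown in this excerpt) consumes these slots by fetching a subgroup's worth of partials in parallel and combining them with subgroupAdd.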
