
Commit 075dac2

Change add+rms_norm optimization to write out an array of partial sums
rather than using atomic adds, to make it deterministic. The rms_norm shader fetches a subgroup's worth of partial sums in parallel and uses subgroupAdd to add them up.
1 parent 97ec47a commit 075dac2
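A float atomicAdd accumulates the workgroup results in whatever order they happen to land, and floating-point addition is not associative, so the total can differ from run to run. Giving each workgroup its own slot in a partials array and reducing the slots in a fixed order later makes the result bit-exact. A minimal sketch of that producer pattern, with illustrative buffer names and bindings (the real shader change is the add.comp diff below):

#version 450
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable

layout(local_size_x = 256) in;

layout(binding = 0, std430) readonly  buffer Src      { float src[]; };
layout(binding = 1, std430) writeonly buffer Partials { float partial_sums[]; };

shared float sumsh[256];

void main() {
    // each invocation contributes the square of one element (bounds checks omitted)
    float v = src[gl_GlobalInvocationID.x];
    float sum_sq = v * v;

    // reduce within each subgroup, then across subgroups via shared memory
    sum_sq = subgroupAdd(sum_sq);
    if (gl_SubgroupInvocationID == 0) {
        sumsh[gl_SubgroupID] = sum_sq;
    }
    barrier();
    for (uint s = (256 / gl_SubgroupSize) / 2; s > 0; s >>= 1) {
        if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
            sumsh[gl_SubgroupID] += sumsh[gl_SubgroupID + s];
        }
        barrier();
    }

    // one fixed slot per workgroup instead of a racy float atomicAdd
    if (gl_LocalInvocationID.x == 0) {
        partial_sums[gl_WorkGroupID.x] = sumsh[0];
    }
}

The actual add.comp change below follows the same subgroup-then-shared-memory reduction, but derives the slot index from the element index, orig_idx / (num_iter * num_threads), rather than gl_WorkGroupID.x.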

6 files changed, +187 −124 lines changed
ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 56 additions & 45 deletions
@@ -369,9 +369,8 @@ struct vk_device_struct {
     bool subgroup_add;
     bool subgroup_shuffle;

-    bool atomic_float_add;
     bool add_rms_fusion;
-    uint32_t atomic_binding_alignment;
+    uint32_t partials_binding_alignment;

     bool integer_dot_product;

@@ -476,6 +475,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_mul_f32;
+    vk_pipeline pipeline_rms_norm_partials_f32;
+    vk_pipeline pipeline_rms_norm_mul_partials_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;

@@ -1170,13 +1171,13 @@ struct ggml_backend_vk_context {

     size_t semaphore_idx, event_idx;
     ggml_vk_garbage_collector gc;
-    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset;
-    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add;
+    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
+    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials;
     vk::Fence fence, almost_ready_fence;
     bool almost_ready_fence_pending {};
     // Set before op_add and unset after op_rms_norm to indicate that the add should
-    // use atomics to accumulate the square of the vector components
-    bool do_add_rms_atomic;
+    // write partial sums to accumulate the square of the vector components
+    bool do_add_rms_partials;

     vk_buffer buffer_pool[MAX_VK_BUFFERS];

@@ -2939,8 +2940,12 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

@@ -3298,7 +3303,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
     device->coopmat_support = false;
     device->integer_dot_product = false;
     bool bfloat16_support = false;
-    bool atomic_float_support = false;

     for (const auto& properties : ext_props) {
         if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -3338,8 +3342,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
                    !getenv("GGML_VK_DISABLE_BFLOAT16")) {
             bfloat16_support = true;
 #endif
-        } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) {
-            atomic_float_support = true;
         }
     }

@@ -3556,14 +3558,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device_extensions.push_back("VK_KHR_shader_integer_dot_product");
     }

-    VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
-    atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
-    if (atomic_float_support) {
-        last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features;
-        last_struct = (VkBaseOutStructure *)&atomic_float_features;
-        device_extensions.push_back("VK_EXT_shader_atomic_float");
-    }
-
     vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);

     device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -3575,7 +3569,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
 #endif

     device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
-    device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd;

     if (device->subgroup_size_control) {
         device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@@ -3891,9 +3884,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
     device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;

     device->add_rms_fusion = !device->disable_fusion &&
-                             device->subgroup_add &&
-                             device->atomic_float_add;
-    device->atomic_binding_alignment =
+                             device->subgroup_add;
+    device->partials_binding_alignment =
         std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);

     return device;
@@ -6927,7 +6919,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     switch (op) {
     case GGML_OP_ADD:
         {
-            if (ctx->do_add_rms_atomic) {
+            if (ctx->do_add_rms_partials) {
                 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
                 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
             } else {
@@ -7051,7 +7043,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_RMS_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+            if (ctx->do_add_rms_partials) {
+                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
+            } else {
+                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+            }
         }
         return nullptr;
     case GGML_OP_RMS_NORM_BACK:
@@ -7534,7 +7530,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             }
         } break;
     case GGML_OP_RMS_NORM:
-        if (ctx->do_add_rms_atomic) {
+        if (ctx->do_add_rms_partials) {
             // Run one element per thread, 128 threads per workgroup
             elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
         } else {
@@ -7688,8 +7684,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     }

     if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
-        vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
-        size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;
+        vk_buffer d_A = ctx->prealloc_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
+        size_t a_buf_offset = ctx->prealloc_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             { vk_subbuffer{ d_X, x_buf_offset, x_sz },
@@ -7805,7 +7801,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        0.0f, 0.0f, ctx->do_add_rms_atomic,
+        0.0f, 0.0f, ctx->do_add_rms_partials,
     }, dryrun);
 }

@@ -8257,23 +8253,38 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }

+static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+    const uint32_t ne = (uint32_t)node->ne[0];
+    const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
+    const uint32_t num_partials = CEIL_DIV(ne, denom);
+    return num_partials;
+}
+
+static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+    const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
+    const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
+    return num_bytes;
+}
+
 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);

+    uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
+
     ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f, ctx->do_add_rms_atomic,
+        op_params[0], 0.0f, (int32_t)param3,
     }, dryrun);

-    if (ctx->do_add_rms_atomic) {
-        ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment;
-        ctx->do_add_rms_atomic = false;
+    if (ctx->do_add_rms_partials) {
+        ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
+        ctx->do_add_rms_partials = false;
     }
 }

@@ -9552,13 +9563,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         }
         ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
     }
-    if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) {
-        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_atomic_add << ")");
+    if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")");
         // Resize buffer
-        if (ctx->prealloc_atomic_add != nullptr) {
-            ggml_vk_destroy_buffer(ctx->prealloc_atomic_add);
+        if (ctx->prealloc_add_rms_partials != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
         }
-        ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add);
+        ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
     }
 }

@@ -9622,9 +9633,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
             ctx->device->add_rms_fusion) {
             if (dryrun) {
-                ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment;
+                ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
             }
-            ctx->do_add_rms_atomic = true;
+            ctx->do_add_rms_partials = true;
         }
         break;
     case GGML_OP_REPEAT:
@@ -9747,7 +9758,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
         ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         if (node->op == GGML_OP_RMS_NORM) {
-            ctx->do_add_rms_atomic = false;
+            ctx->do_add_rms_partials = false;
         }
         return false;
     }
@@ -10663,9 +10674,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
     }

-    ctx->prealloc_size_atomic_add = 0;
-    ctx->prealloc_size_atomic_add_offset = 0;
-    ctx->do_add_rms_atomic = false;
+    ctx->prealloc_size_add_rms_partials = 0;
+    ctx->prealloc_size_add_rms_partials_offset = 0;
+    ctx->do_add_rms_partials = false;

     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -10727,16 +10738,16 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
     }

-    if (ctx->prealloc_size_atomic_add) {
+    if (ctx->prealloc_size_add_rms_partials) {
         if (ctx->compute_ctx.expired()) {
             compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
             ctx->compute_ctx = compute_ctx;
             ggml_vk_ctx_begin(ctx->device, compute_ctx);
         } else {
             compute_ctx = ctx->compute_ctx.lock();
         }
-        // initialize atomic sums to zero.
-        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add);
+        // initialize partial sums to zero.
+        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
     }

     // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
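To make the new sizing helpers concrete with a hypothetical row length: for a fused add+rms_norm over a row of ne[0] = 4096 floats, the add_rms pipeline's wg_denoms[0] is 512 (num_threads * num_iter, per the comment in add.comp below), so ggml_vk_rms_num_partials returns CEIL_DIV(4096, 512) = 8 and ggml_vk_rms_partials_size rounds 8 * sizeof(uint32_t) = 32 bytes up to partials_binding_alignment (at least 4, typically the device's minStorageBufferOffsetAlignment, e.g. 64). That rounded size is what the dryrun pass accumulates into prealloc_size_add_rms_partials, one stride per fused node, so each add+rms_norm pair gets its own aligned slice of the preallocated buffer.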

ggml/src/ggml-vulkan/vulkan-shaders/add.comp

Lines changed: 23 additions & 4 deletions
@@ -2,7 +2,6 @@

 #extension GL_EXT_shader_16bit_storage : require
 #if ADD_RMS
-#extension GL_EXT_shader_atomic_float : enable
 #extension GL_KHR_shader_subgroup_arithmetic : enable
 #extension GL_KHR_shader_subgroup_basic : enable
 #endif
@@ -12,12 +11,18 @@

 const uint num_threads = 256;

-layout (binding = 3) buffer AtomBuf {float data_atom;};
+layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};

 layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

+#if ADD_RMS
+// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
+shared FLOAT_TYPE sumsh[num_threads];
+#endif
+
 void main() {
     uint idx = get_idx();
+    uint orig_idx = idx;

     // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
     const uint num_iter = 2;
@@ -41,9 +46,23 @@ void main() {

 #if ADD_RMS
     if (p.param3 != 0) {
+        // reduce the sum within each subgroup, then across subgroups
+        const uint NumSubgroups = num_threads / gl_SubgroupSize;
         sum_sq = subgroupAdd(sum_sq);
-        if (sum_sq != 0 && gl_SubgroupInvocationID == 0) {
-            atomicAdd(data_atom, sum_sq);
+        if (gl_SubgroupInvocationID == 0) {
+            sumsh[gl_SubgroupID] = sum_sq;
+        }
+        barrier();
+        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
+            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
+                sum_sq += sumsh[gl_SubgroupID + s];
+                sumsh[gl_SubgroupID] = sum_sq;
+            }
+            barrier();
+        }
+
+        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
+            partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
         }
     }
 #endif
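The matching rms_norm shader change is among the files not shown in this excerpt. Per the commit message it fetches a subgroup's worth of partials in parallel and reduces them with subgroupAdd; a hedged sketch of what that read side can look like, with illustrative names, bindings, and push constants:

#version 450
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable

layout(local_size_x = 128) in;

layout(binding = 3, std430) readonly buffer Partials { float partial_sums[]; };

layout(push_constant) uniform P {
    uint ne00;           // row length
    uint num_partials;   // passed as param3 by ggml_vk_rms_norm in the diff above
    float eps;
} p;

shared float total_sum_sq;

void main() {
    // the first subgroup loads the partials a subgroup's worth at a time and
    // reduces them with subgroupAdd; the order is fixed, so the result is the
    // same on every run
    if (gl_SubgroupID == 0) {
        float s = 0.0;
        for (uint i = gl_SubgroupInvocationID; i < p.num_partials; i += gl_SubgroupSize) {
            s += partial_sums[i];
        }
        s = subgroupAdd(s);
        if (gl_SubgroupInvocationID == 0) {
            total_sum_sq = s;
        }
    }
    barrier();

    // every invocation then scales its slice of the row by the same factor
    const float scale = inversesqrt(total_sum_sq / float(p.ne00) + p.eps);
    // ... multiply the row (and the fused mul operand, if any) by scale
}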
