@@ -372,9 +372,8 @@ struct vk_device_struct {
bool subgroup_shuffle;
bool multi_add;

- bool atomic_float_add;
bool add_rms_fusion;
- uint32_t atomic_binding_alignment;
+ uint32_t partials_binding_alignment;

bool integer_dot_product;

@@ -482,6 +481,8 @@ struct vk_device_struct {
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
vk_pipeline pipeline_rms_norm_mul_f32;
+ vk_pipeline pipeline_rms_norm_partials_f32;
+ vk_pipeline pipeline_rms_norm_mul_partials_f32;
vk_pipeline pipeline_rms_norm_back_f32;
vk_pipeline pipeline_l2_norm_f32;

@@ -1191,13 +1192,13 @@ struct ggml_backend_vk_context {

size_t semaphore_idx, event_idx;
ggml_vk_garbage_collector gc;
- size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset;
- vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add;
+ size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
+ vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials;
vk::Fence fence, almost_ready_fence;
bool almost_ready_fence_pending {};
// Set before op_add and unset after op_rms_norm to indicate that the add should
- // use atomics to accumulate the square of the vector components
- bool do_add_rms_atomic;
+ // write partial sums to accumulate the square of the vector components
+ bool do_add_rms_partials;

vk_buffer buffer_pool[MAX_VK_BUFFERS];

@@ -2936,8 +2937,12 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

@@ -3303,7 +3308,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->coopmat_support = false;
device->integer_dot_product = false;
bool bfloat16_support = false;
- bool atomic_float_support = false;

for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -3343,8 +3347,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
bfloat16_support = true;
#endif
- } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) {
- atomic_float_support = true;
}
}

@@ -3561,14 +3563,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_KHR_shader_integer_dot_product");
}

- VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
- atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
- if (atomic_float_support) {
- last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features;
- last_struct = (VkBaseOutStructure *)&atomic_float_features;
- device_extensions.push_back("VK_EXT_shader_atomic_float");
- }
-
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);

device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -3580,7 +3574,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
#endif

device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
- device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd;

device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
@@ -3902,9 +3895,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;

device->add_rms_fusion = !device->disable_fusion &&
- device->subgroup_add &&
- device->atomic_float_add;
- device->atomic_binding_alignment =
+ device->subgroup_add;
+ device->partials_binding_alignment =
std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);

return device;
@@ -6949,13 +6941,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
case GGML_OP_ADD:
{
if (ctx->num_additional_fused_ops > 0) {
- if (ctx->do_add_rms_atomic) {
+ if (ctx->do_add_rms_partials) {
return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
} else {
return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
}
}
- if (ctx->do_add_rms_atomic) {
+ if (ctx->do_add_rms_partials) {
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
} else {
@@ -7079,7 +7071,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_RMS_NORM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+ if (ctx->do_add_rms_partials) {
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
+ } else {
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+ }
}
return nullptr;
case GGML_OP_RMS_NORM_BACK:
@@ -7567,7 +7563,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
} break;
case GGML_OP_RMS_NORM:
- if (ctx->do_add_rms_atomic) {
+ if (ctx->do_add_rms_partials) {
// Run one element per thread, 128 threads per workgroup
elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
} else {
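For reference, a minimal standalone sketch of the dispatch arithmetic in this hunk, taking the comment above at face value and assuming CEIL_DIV is the usual round-up integer division used in the ggml sources: with one element per thread and 128 threads per workgroup, a row of ne00 elements needs CEIL_DIV(ne00, 128) workgroups. The row length 4096 below is only an illustrative value, not taken from the change.

    // Sketch only: mirrors the elements = { CEIL_DIV(ne00, 128), 1, 1 } line above.
    // Assumes CEIL_DIV(a, b) == ((a) + (b) - 1) / (b).
    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    int main() {
        const uint32_t ne00 = 4096;          // hypothetical row length
        const uint32_t threads_per_wg = 128; // one element per thread, per the comment above
        printf("workgroups per row: %u\n", ceil_div(ne00, threads_per_wg)); // prints 32
        return 0;
    }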
@@ -7721,8 +7717,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}

if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
- vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
- size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;
+ vk_buffer d_A = ctx->prealloc_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
+ size_t a_buf_offset = ctx->prealloc_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{ vk_subbuffer{ d_X, x_buf_offset, x_sz },
@@ -7943,7 +7939,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
- 0.0f, 0.0f, ctx->do_add_rms_atomic,
+ 0.0f, 0.0f, ctx->do_add_rms_partials,
}, dryrun);
}

@@ -8401,23 +8397,38 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}

+ static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+ const uint32_t ne = (uint32_t)node->ne[0];
+ const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
+ const uint32_t num_partials = CEIL_DIV(ne, denom);
+ return num_partials;
+ }
+
+ static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+ const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
+ const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
+ return num_bytes;
+ }
+
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);

+ uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
- op_params[0], 0.0f, ctx->do_add_rms_atomic,
+ op_params[0], 0.0f, (int32_t)param3,
}, dryrun);

- if (ctx->do_add_rms_atomic) {
- ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment;
- ctx->do_add_rms_atomic = false;
+ if (ctx->do_add_rms_partials) {
+ ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
+ ctx->do_add_rms_partials = false;
}
}

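For concreteness, a standalone sketch of the sizing math the two new helpers implement. The workgroup denominator (128) and the alignment (64) are assumed example values; in the backend they come from pipeline_add_rms[0][0][0]->wg_denoms[0] and from partials_binding_alignment, and ROUNDUP_POW2 is assumed to round up to a power-of-two boundary as its name suggests.

    // Sketch of ggml_vk_rms_num_partials / ggml_vk_rms_partials_size with made-up inputs.
    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b)     { return (a + b - 1) / b; }
    static uint32_t roundup_pow2(uint32_t v, uint32_t a) { return (v + a - 1) & ~(a - 1); } // a: power of two

    int main() {
        const uint32_t ne0       = 4096; // hypothetical row length (node->ne[0])
        const uint32_t wg_denom  = 128;  // assumed wg_denoms[0] of the add_rms pipeline
        const uint32_t alignment = 64;   // assumed partials_binding_alignment

        const uint32_t num_partials = ceil_div(ne0, wg_denom);                   // 32 partial sums
        const uint32_t num_bytes    = roundup_pow2(num_partials * 4, alignment); // 128 bytes, 4 bytes each
        printf("partials: %u, bytes per fused op: %u\n", num_partials, num_bytes);
        return 0;
    }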
@@ -9696,13 +9707,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
}
ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
}
- if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) {
- VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_atomic_add << ")");
+ if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
+ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")");
// Resize buffer
- if (ctx->prealloc_atomic_add != nullptr) {
- ggml_vk_destroy_buffer(ctx->prealloc_atomic_add);
+ if (ctx->prealloc_add_rms_partials != nullptr) {
+ ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
}
- ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add);
+ ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
}
}

@@ -9766,9 +9777,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
ctx->device->add_rms_fusion) {
if (dryrun) {
- ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment;
+ ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
}
- ctx->do_add_rms_atomic = true;
+ ctx->do_add_rms_partials = true;
}
break;
case GGML_OP_REPEAT:
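The partials buffer appears to be sized and sliced in two passes: during the dryrun, each fused add + rms_norm adds its ggml_vk_rms_partials_size() to prealloc_size_add_rms_partials so a single buffer can be allocated, and during recording ggml_vk_rms_norm advances prealloc_size_add_rms_partials_offset by the same amount so every fusion binds its own aligned slice. A minimal host-side sketch of that pattern (illustrative sizes only, not the backend's actual state):

    // Two-pass suballocation sketch: size everything first, then hand out slices.
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<size_t> fused_sizes = {128, 256, 128}; // per-fusion partials sizes, already aligned

        size_t total = 0;                        // "dryrun" pass: accumulate the buffer size
        for (size_t s : fused_sizes) total += s;

        size_t offset = 0;                       // "record" pass: each fusion gets a distinct slice
        for (size_t s : fused_sizes) {
            printf("slice at offset %zu, size %zu (of %zu)\n", offset, s, total);
            offset += s;                         // mirrors prealloc_size_add_rms_partials_offset
        }
        return 0;
    }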
@@ -9892,7 +9903,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (node->op == GGML_OP_RMS_NORM) {
- ctx->do_add_rms_atomic = false;
+ ctx->do_add_rms_partials = false;
}
return false;
}
@@ -10868,9 +10879,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
}

- ctx->prealloc_size_atomic_add = 0;
- ctx->prealloc_size_atomic_add_offset = 0;
- ctx->do_add_rms_atomic = false;
+ ctx->prealloc_size_add_rms_partials = 0;
+ ctx->prealloc_size_add_rms_partials_offset = 0;
+ ctx->do_add_rms_partials = false;

uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -10937,16 +10948,16 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
}

- if (ctx->prealloc_size_atomic_add) {
+ if (ctx->prealloc_size_add_rms_partials) {
if (ctx->compute_ctx.expired()) {
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = compute_ctx;
ggml_vk_ctx_begin(ctx->device, compute_ctx);
} else {
compute_ctx = ctx->compute_ctx.lock();
}
- // initialize atomic sums to zero.
- ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add);
+ // initialize partial sums to zero.
+ ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
}

// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.