@@ -369,9 +369,8 @@ struct vk_device_struct {
bool subgroup_add;
bool subgroup_shuffle;

- bool atomic_float_add;
bool add_rms_fusion;
- uint32_t atomic_binding_alignment;
+ uint32_t partials_binding_alignment;

bool integer_dot_product;

@@ -476,6 +475,8 @@ struct vk_device_struct {
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
vk_pipeline pipeline_rms_norm_mul_f32;
+ vk_pipeline pipeline_rms_norm_partials_f32;
+ vk_pipeline pipeline_rms_norm_mul_partials_f32;
vk_pipeline pipeline_rms_norm_back_f32;
vk_pipeline pipeline_l2_norm_f32;

@@ -1170,13 +1171,13 @@ struct ggml_backend_vk_context {

size_t semaphore_idx, event_idx;
ggml_vk_garbage_collector gc;
- size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset;
- vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add;
+ size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
+ vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials;
vk::Fence fence, almost_ready_fence;
bool almost_ready_fence_pending {};
// Set before op_add and unset after op_rms_norm to indicate that the add should
- // use atomics to accumulate the square of the vector components
- bool do_add_rms_atomic;
+ // write partial sums to accumulate the square of the vector components
+ bool do_add_rms_partials;

vk_buffer buffer_pool[MAX_VK_BUFFERS];

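For context on the flag and buffers above: when an add feeds an rms_norm, the fused add shader writes one partial sum of squares per workgroup into the preallocated scratch buffer, and the rms_norm that follows reduces those partials instead of accumulating with float atomics (which is why the VK_EXT_shader_atomic_float plumbing is removed below). A minimal CPU-side sketch of the idea, not the actual GLSL shaders; the 128-element chunk size is an assumption standing in for the add workgroup width:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Fused add: produce out = a + b and one sum-of-squares "partial" per 128-element chunk.
    static std::vector<float> fused_add_with_partials(const std::vector<float>& a,
                                                      const std::vector<float>& b,
                                                      std::vector<float>& partials,
                                                      std::size_t chunk = 128) {
        std::vector<float> out(a.size());
        partials.assign((a.size() + chunk - 1) / chunk, 0.0f);
        for (std::size_t i = 0; i < a.size(); ++i) {
            out[i] = a[i] + b[i];
            partials[i / chunk] += out[i] * out[i]; // per-chunk partial sum of squares
        }
        return out;
    }

    // RMS norm: reduce the partials instead of re-reading the whole row.
    static void rms_norm_from_partials(std::vector<float>& x,
                                       const std::vector<float>& partials,
                                       float eps = 1e-6f) {
        float sum_sq = 0.0f;
        for (float p : partials) {
            sum_sq += p;
        }
        const float scale = 1.0f / std::sqrt(sum_sq / (float)x.size() + eps);
        for (float& v : x) {
            v *= scale;
        }
    }
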
@@ -2939,8 +2940,12 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

@@ -3298,7 +3303,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->coopmat_support = false;
device->integer_dot_product = false;
bool bfloat16_support = false;
- bool atomic_float_support = false;

for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -3338,8 +3342,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
bfloat16_support = true;
#endif
- } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) {
- atomic_float_support = true;
}
}

@@ -3556,14 +3558,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_KHR_shader_integer_dot_product");
}

- VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
- atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
- if (atomic_float_support) {
- last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features;
- last_struct = (VkBaseOutStructure *)&atomic_float_features;
- device_extensions.push_back("VK_EXT_shader_atomic_float");
- }
-
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);

device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -3575,7 +3569,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
#endif

device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
- device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd;

if (device->subgroup_size_control) {
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@@ -3891,9 +3884,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;

device->add_rms_fusion = !device->disable_fusion &&
- device->subgroup_add &&
- device->atomic_float_add;
- device->atomic_binding_alignment =
+ device->subgroup_add;
+ device->partials_binding_alignment =
std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);

return device;
@@ -6927,7 +6919,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
switch (op) {
case GGML_OP_ADD:
{
- if (ctx->do_add_rms_atomic) {
+ if (ctx->do_add_rms_partials) {
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
} else {
@@ -7051,7 +7043,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_RMS_NORM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+ if (ctx->do_add_rms_partials) {
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
+ } else {
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+ }
}
return nullptr;
case GGML_OP_RMS_NORM_BACK:
@@ -7534,7 +7530,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
} break;
case GGML_OP_RMS_NORM:
- if (ctx->do_add_rms_atomic) {
+ if (ctx->do_add_rms_partials) {
// Run one element per thread, 128 threads per workgroup
elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
} else {
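To make this dispatch concrete (assumed numbers, not taken from a real run): given the {1, 1, 1} workgroup denominators used when the partials pipelines are created above, elements.x here effectively becomes the workgroup count, so a fused row of 4096 elements launches 32 workgroups of 128 threads, one output element per thread:

    #include <cstdint>
    #include <cstdio>

    #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

    int main() {
        const uint32_t ne00 = 4096;                     // assumed fused row length
        const uint32_t wgs  = CEIL_DIV(ne00, 128u);     // 32 workgroups, 128 threads each
        std::printf("elements = { %u, 1, 1 }\n", wgs);
        return 0;
    }
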
@@ -7688,8 +7684,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}

if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
- vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
- size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;
+ vk_buffer d_A = ctx->prealloc_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
+ size_t a_buf_offset = ctx->prealloc_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{ vk_subbuffer{ d_X, x_buf_offset, x_sz },
@@ -7805,7 +7801,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
- 0.0f, 0.0f, ctx->do_add_rms_atomic,
+ 0.0f, 0.0f, ctx->do_add_rms_partials,
}, dryrun);
}

@@ -8257,23 +8253,38 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}

+ static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+ const uint32_t ne = (uint32_t)node->ne[0];
+ const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
+ const uint32_t num_partials = CEIL_DIV(ne, denom);
+ return num_partials;
+ }
+
+ static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+ const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
+ const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
+ return num_bytes;
+ }
+
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);

+ uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
- op_params[0], 0.0f, ctx->do_add_rms_atomic,
+ op_params[0], 0.0f, (int32_t)param3,
}, dryrun);

- if (ctx->do_add_rms_atomic) {
- ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment;
- ctx->do_add_rms_atomic = false;
+ if (ctx->do_add_rms_partials) {
+ ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
+ ctx->do_add_rms_partials = false;
}
}

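A worked example of the two helpers above (the row length, workgroup width, and alignment are assumptions; the real values come from the add pipeline's wg_denoms and the device's minStorageBufferOffsetAlignment): a 4096-element row with a 128-wide add workgroup needs 32 partial sums, i.e. 128 bytes, which already satisfies a 64-byte binding alignment:

    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
    static uint32_t roundup_pow2(uint32_t x, uint32_t align) { return (x + align - 1) & ~(align - 1); }

    int main() {
        const uint32_t ne00  = 4096; // assumed ne[0] of the fused tensor
        const uint32_t denom = 128;  // assumed wg_denoms[0] of pipeline_add_rms
        const uint32_t align = 64;   // assumed partials_binding_alignment
        const uint32_t num_partials = ceil_div(ne00, denom);                                        // 32
        const uint32_t num_bytes    = roundup_pow2(num_partials * (uint32_t)sizeof(uint32_t), align); // 128
        std::printf("%u partials, %u bytes\n", num_partials, num_bytes);
        return 0;
    }
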
@@ -9552,13 +9563,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
}
ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
}
- if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) {
- VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_atomic_add << ")");
+ if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
+ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")");
// Resize buffer
- if (ctx->prealloc_atomic_add != nullptr) {
- ggml_vk_destroy_buffer(ctx->prealloc_atomic_add);
+ if (ctx->prealloc_add_rms_partials != nullptr) {
+ ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
}
- ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add);
+ ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
}
}

@@ -9622,9 +9633,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
ctx->device->add_rms_fusion) {
if (dryrun) {
- ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment;
+ ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
}
- ctx->do_add_rms_atomic = true;
+ ctx->do_add_rms_partials = true;
}
break;
case GGML_OP_REPEAT:
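The dryrun branch above only has to make the scratch buffer large enough; the matching offset bookkeeping happens later, in ggml_vk_rms_norm, where prealloc_size_add_rms_partials_offset advances by the same per-pair size. A rough sketch of that two-pass pattern in isolation (names and structure assumed, simplified from the real flow):

    #include <cstdint>

    struct partials_plan {
        uint32_t total_size = 0; // grown during the dryrun walk over the graph
        uint32_t offset     = 0; // consumed during the execution walk
    };

    // Dryrun: reserve room for one fused add/rms_norm pair.
    static void plan_fused_pair(partials_plan& plan, uint32_t partials_bytes) {
        plan.total_size += partials_bytes;
    }

    // Execution: hand this pair its own aligned slice and advance past it.
    static uint32_t take_fused_slice(partials_plan& plan, uint32_t partials_bytes) {
        const uint32_t my_offset = plan.offset;
        plan.offset += partials_bytes;
        return my_offset;
    }
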
@@ -9747,7 +9758,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (node->op == GGML_OP_RMS_NORM) {
- ctx->do_add_rms_atomic = false;
+ ctx->do_add_rms_partials = false;
}
return false;
}
@@ -10663,9 +10674,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
}

- ctx->prealloc_size_atomic_add = 0;
- ctx->prealloc_size_atomic_add_offset = 0;
- ctx->do_add_rms_atomic = false;
+ ctx->prealloc_size_add_rms_partials = 0;
+ ctx->prealloc_size_add_rms_partials_offset = 0;
+ ctx->do_add_rms_partials = false;

uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -10727,16 +10738,16 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
}

- if (ctx->prealloc_size_atomic_add) {
+ if (ctx->prealloc_size_add_rms_partials) {
if (ctx->compute_ctx.expired()) {
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = compute_ctx;
ggml_vk_ctx_begin(ctx->device, compute_ctx);
} else {
compute_ctx = ctx->compute_ctx.lock();
}
- // initialize atomic sums to zero.
- ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add);
+ // initialize partial sums to zero.
+ ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
}

// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.