@@ -369,6 +369,10 @@ struct vk_device_struct {
     bool subgroup_add;
     bool subgroup_shuffle;

+    bool atomic_float_add;
+    bool add_rms_fusion;
+    uint32_t atomic_binding_alignment;
+
     bool integer_dot_product;

     bool subgroup_size_control;
@@ -448,6 +452,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_mul_norepeat[2][2][2];
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];
+    vk_pipeline pipeline_add_rms[2][2][2];
+    vk_pipeline pipeline_add_rms_norepeat[2][2][2];

     vk_pipeline pipeline_add_id_f32;

@@ -1144,6 +1150,12 @@ class vk_perf_logger {
             timings[name].push_back(time);
             return;
         }
+        if (node->op == GGML_OP_RMS_NORM) {
+            std::string name = ggml_op_name(node->op);
+            name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
+            timings[name].push_back(time);
+            return;
+        }
         timings[ggml_op_name(node->op)].push_back(time);
     }
 private:
@@ -1158,10 +1170,13 @@ struct ggml_backend_vk_context {

     size_t semaphore_idx, event_idx;
     ggml_vk_garbage_collector gc;
-    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
-    vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
+    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset;
+    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add;
     vk::Fence fence, almost_ready_fence;
     bool almost_ready_fence_pending {};
+    // Set before op_add and unset after op_rms_norm to indicate that the add should
+    // use atomics to accumulate the square of the vector components
+    bool do_add_rms_atomic;

     vk_buffer buffer_pool[MAX_VK_BUFFERS];

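The comment on `do_add_rms_atomic` summarizes the fusion this change adds: when a `GGML_OP_ADD` feeds directly into a single-row `GGML_OP_RMS_NORM`, the add shader also accumulates the sum of squares of its output into one float slot of `prealloc_atomic_add` (via atomic float adds), so the rms_norm shader can reuse that sum instead of re-reducing the row. A minimal host-side reference of what the fused pair computes, assuming `eps` corresponds to the rms_norm's `op_params[0]` (the real work happens in the GLSL shaders; the names below are purely illustrative):

```cpp
#include <cmath>
#include <cstddef>

// Illustrative only: dst = a + b, and the squared values are accumulated into a
// single float (the per-fusion slot in prealloc_atomic_add, written with
// atomicAdd on the GPU). rms_norm then scales dst by 1/sqrt(mean(dst^2) + eps).
static void fused_add_rms_reference(const float * a, const float * b,
                                    float * dst, float * out, size_t n, float eps) {
    float sum_sq = 0.0f;                         // the atomic accumulator slot
    for (size_t i = 0; i < n; ++i) {
        dst[i] = a[i] + b[i];
        sum_sq += dst[i] * dst[i];               // shader side: atomicAdd of partial sums
    }
    const float scale = 1.0f / std::sqrt(sum_sq / (float) n + eps);
    for (size_t i = 0; i < n; ++i) {
        out[i] = dst[i] * scale;                 // RMS norm reusing the shared sum
    }
}
```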
@@ -2924,8 +2939,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

@@ -2995,20 +3010,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
     };

     bool rte = device->float_controls_rte_fp16;
-#define CREATE_BINARY(name, namemod, spec) \
+#define CREATE_BINARY(name, namemod, spec, bindings) \
     for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
         ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
                                 #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
-                                "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
-
-    CREATE_BINARY(add, , {0})
-    CREATE_BINARY(add, _norepeat, {1})
-    CREATE_BINARY(sub, , {0})
-    CREATE_BINARY(sub, _norepeat, {1})
-    CREATE_BINARY(mul, , {0})
-    CREATE_BINARY(mul, _norepeat, {1})
-    CREATE_BINARY(div, , {0})
-    CREATE_BINARY(div, _norepeat, {1})
+                                "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
+
+    CREATE_BINARY(add, , {0}, 4)
+    CREATE_BINARY(add, _norepeat, {1}, 4)
+    CREATE_BINARY(sub, , {0}, 3)
+    CREATE_BINARY(sub, _norepeat, {1}, 3)
+    CREATE_BINARY(mul, , {0}, 3)
+    CREATE_BINARY(mul, _norepeat, {1}, 3)
+    CREATE_BINARY(div, , {0}, 3)
+    CREATE_BINARY(div, _norepeat, {1}, 3)
+    CREATE_BINARY(add_rms, , {0}, 4)
+    CREATE_BINARY(add_rms, _norepeat, {1}, 4)
 #undef CREATE_BINARY

     ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
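For reference, this is how one instantiation of the extended macro expands for a single `(s0, s1, d)` combination, written out by hand from the token pasting shown above (not copied from the repository); the new `bindings` argument simply becomes the descriptor-count parameter of `ggml_vk_create_pipeline`:

```cpp
// Hand expansion of CREATE_BINARY(add_rms, _norepeat, {1}, 4) for one (s0, s1, d):
ggml_vk_create_pipeline(device, device->pipeline_add_rms_norepeat[s0][s1][d],
                        "add_rms" + get_suffix(s0, s1, d) + "_norepeat",
                        add_rms_len[s0][s1][d][rte], add_rms_data[s0][s1][d][rte],
                        "main", (4), sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
```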
@@ -3281,6 +3298,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->coopmat_support = false;
         device->integer_dot_product = false;
         bool bfloat16_support = false;
+        bool atomic_float_support = false;

         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -3320,6 +3338,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
                        !getenv("GGML_VK_DISABLE_BFLOAT16")) {
                 bfloat16_support = true;
 #endif
+            } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) {
+                atomic_float_support = true;
             }
         }

@@ -3536,6 +3556,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_KHR_shader_integer_dot_product");
         }

+        VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
+        atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
+        if (atomic_float_support) {
+            last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features;
+            last_struct = (VkBaseOutStructure *)&atomic_float_features;
+            device_extensions.push_back("VK_EXT_shader_atomic_float");
+        }
+
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);

         device->fp16 = device->fp16 && vk12_features.shaderFloat16;
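A standalone sketch of the feature query used here, with assumed boilerplate (`physical_device` is a placeholder): `VkPhysicalDeviceShaderAtomicFloatFeaturesEXT` is chained into `VkPhysicalDeviceFeatures2` before calling `vkGetPhysicalDeviceFeatures2`, and `shaderBufferFloat32AtomicAdd` reports whether float adds on storage buffers are available. In the hunk above the struct sits in the backend's existing `last_struct` pNext chain, which is presumably also what gets handed to device creation so the feature is actually enabled.

```cpp
#include <vulkan/vulkan.h>

// Minimal sketch (not the backend's device-init code): query support for
// 32-bit float atomic adds on storage buffers from VK_EXT_shader_atomic_float.
bool supports_buffer_float32_atomic_add(VkPhysicalDevice physical_device) {
    VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
    atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;

    VkPhysicalDeviceFeatures2 features2 {};
    features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
    features2.pNext = &atomic_float_features;

    vkGetPhysicalDeviceFeatures2(physical_device, &features2);
    return atomic_float_features.shaderBufferFloat32AtomicAdd == VK_TRUE;
}
```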
@@ -3547,6 +3575,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 #endif

         device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
+        device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd;

         if (device->subgroup_size_control) {
             device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@@ -3861,6 +3890,12 @@ static vk_device ggml_vk_get_device(size_t idx) {

         device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;

+        device->add_rms_fusion = !device->disable_fusion &&
+                                 device->subgroup_add &&
+                                 device->atomic_float_add;
+        device->atomic_binding_alignment =
+            std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
+
         return device;
     }

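The fusion is therefore gated on three things visible above: fusion not disabled via `GGML_VK_DISABLE_FUSION`, subgroup add support, and float atomic add support. The alignment value exists because each fused add+rms_norm pair gets its own accumulator slot bound at an offset into `prealloc_atomic_add`, and storage-buffer bindings must start at multiples of `minStorageBufferOffsetAlignment`. A small sketch of the slot layout this implies (the helper name is hypothetical):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Each fused pair owns one 4-byte float accumulator, but slots are spaced at the
// device's storage-buffer offset alignment so every slot can be bound on its own.
static size_t atomic_slot_offset(size_t fusion_index, uint32_t min_storage_offset_alignment) {
    const uint32_t alignment = std::max(4u, min_storage_offset_alignment);
    return fusion_index * (size_t) alignment;   // matches how the running offset is advanced
}
```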
@@ -6892,8 +6927,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     switch (op) {
     case GGML_OP_ADD:
         {
-            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
-            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+            if (ctx->do_add_rms_atomic) {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+            } else {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+            }
         }
     case GGML_OP_SUB:
         {
@@ -7494,7 +7534,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             }
         } break;
     case GGML_OP_RMS_NORM:
-        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        if (ctx->do_add_rms_atomic) {
+            // Run one element per thread, 128 threads per workgroup
+            elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
+        } else {
+            elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        }
         break;

     case GGML_OP_SUM:
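In the fused path the rms_norm node is known to cover a single row (the fusion check requires `ggml_nrows(...) == 1`), so the grid is spread across the row's `ne00` elements rather than across rows. `CEIL_DIV` already exists in ggml; its arithmetic, spelled out for clarity:

```cpp
#include <cstdint>

// Integer ceiling division, the same idiom the CEIL_DIV macro uses.
// e.g. ne00 = 4096 -> ceil_div_u32(4096, 128) = 32 workgroups of 128 threads,
// one element per thread across the single fused row.
static inline uint32_t ceil_div_u32(uint32_t x, uint32_t y) {
    return (x + y - 1) / y;
}
```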
@@ -7642,7 +7687,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     }

-    if (op == GGML_OP_GLU) {
+    if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
+        vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
+        size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            { vk_subbuffer{ d_X, x_buf_offset, x_sz },
+              vk_subbuffer{ d_Y, y_buf_offset, y_sz },
+              vk_subbuffer{ d_D, d_buf_offset, d_sz },
+              vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
+            }, pc, elements);
+    } else if (op == GGML_OP_GLU) {
         // Empty src1 is possible in glu, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
@@ -7750,7 +7805,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        0.0f, 0.0f, 0,
+        0.0f, 0.0f, ctx->do_add_rms_atomic,
     }, dryrun);
 }

@@ -8213,8 +8268,13 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f, 0,
+        op_params[0], 0.0f, ctx->do_add_rms_atomic,
     }, dryrun);
+
+    if (ctx->do_add_rms_atomic) {
+        ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment;
+        ctx->do_add_rms_atomic = false;
+    }
 }

 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
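Taken together with the ADD-side change above, the per-pair lifecycle is: the flag is raised when a fusible ADD is found while building the graph, both dispatches pass it through the push constants and bind the same accumulator slot, and the rms_norm call retires the slot by advancing the running offset and clearing the flag. A condensed, hypothetical restatement of that control flow (names follow the diff, but this is not code from the backend):

```cpp
#include <cstddef>
#include <cstdint>

struct add_rms_fusion_state {
    bool   do_add_rms_atomic = false;   // raised in ggml_vk_build_graph for a fusible ADD
    size_t offset            = 0;       // mirrors prealloc_size_atomic_add_offset
};

// Called conceptually where ggml_vk_rms_norm finishes recording its dispatch.
static void retire_fused_slot(add_rms_fusion_state & st, uint32_t binding_alignment) {
    if (st.do_add_rms_atomic) {
        st.offset += binding_alignment;  // the next fused pair uses the next slot
        st.do_add_rms_atomic = false;    // stays off until another fusible ADD is seen
    }
}
```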
@@ -9492,6 +9552,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         }
         ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
     }
+    if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_size_atomic_add << ")");
+        // Resize buffer
+        if (ctx->prealloc_atomic_add != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_atomic_add);
+        }
+        ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add);
+    }
 }

 static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
@@ -9547,10 +9615,21 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             return false;
         }
         break;
+    case GGML_OP_ADD:
+        if (node_idx + 1 < cgraph->n_nodes &&
+            cgraph->nodes[node_idx + 1]->op == GGML_OP_RMS_NORM &&
+            cgraph->nodes[node_idx + 1]->src[0] == cgraph->nodes[node_idx] &&
+            ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
+            ctx->device->add_rms_fusion) {
+            if (dryrun) {
+                ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment;
+            }
+            ctx->do_add_rms_atomic = true;
+        }
+        break;
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD:
     case GGML_OP_ADD_ID:
     case GGML_OP_ACC:
     case GGML_OP_SUB:
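Restated as a standalone predicate (the helper is hypothetical and assumes the internal `ggml_cgraph` definition the backend already uses), the conditions are: the next node is an RMS_NORM, it consumes this ADD's result as `src[0]`, it spans a single row so one scalar sum suffices, and the device-level gate from `ggml_vk_get_device` is set:

```cpp
// Hypothetical restatement of the inline fusion test above.
static bool is_fusible_add_rms(const struct ggml_cgraph * g, int i, bool add_rms_fusion) {
    return add_rms_fusion &&
           i + 1 < g->n_nodes &&
           g->nodes[i + 1]->op == GGML_OP_RMS_NORM &&   // next op is RMS_NORM...
           g->nodes[i + 1]->src[0] == g->nodes[i] &&    // ...fed directly by this ADD
           ggml_nrows(g->nodes[i + 1]) == 1;            // one row -> one scalar accumulator
}
```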
@@ -9667,6 +9746,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         // do the only thing needed for the dryrun.
         vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
         ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        if (node->op == GGML_OP_RMS_NORM) {
+            ctx->do_add_rms_atomic = false;
+        }
         return false;
     }
     default:
@@ -10581,6 +10663,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
     }

+    ctx->prealloc_size_atomic_add = 0;
+    ctx->prealloc_size_atomic_add_offset = 0;
+    ctx->do_add_rms_atomic = false;
+
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
@@ -10641,6 +10727,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
     }

+    if (ctx->prealloc_size_atomic_add) {
+        if (ctx->compute_ctx.expired()) {
+            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+            ctx->compute_ctx = compute_ctx;
+            ggml_vk_ctx_begin(ctx->device, compute_ctx);
+        } else {
+            compute_ctx = ctx->compute_ctx.lock();
+        }
+        // initialize atomic sums to zero.
+        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add);
+    }
+
     // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
     // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
     // (and scaled down based on model size, so smaller models submit earlier).
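The accumulator slots have to read as zero before the first fused ADD of the graph runs, because the shaders only ever add into them; the buffer fill recorded here takes care of that once per graph. A byte-wise fill with zero is valid for float accumulators since IEEE-754 +0.0f is the all-zero bit pattern, which a quick host-side check confirms:

```cpp
#include <cassert>
#include <cstring>

int main() {
    unsigned char zero_bytes[sizeof(float)] = {};
    float f;
    std::memcpy(&f, zero_bytes, sizeof(float));   // reinterpret all-zero bytes as a float
    assert(f == 0.0f);                            // decodes to +0.0f, a valid starting sum
    return 0;
}
```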