@@ -372,6 +372,10 @@ struct vk_device_struct {
     bool subgroup_shuffle;
     bool multi_add;
 
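+    // Support for fusing GGML_OP_ADD + GGML_OP_RMS_NORM: float atomic adds from
+    // VK_EXT_shader_atomic_float, the fusion switch itself, and the per-slot
+    // alignment used for the atomic partial-sum buffer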
+    bool atomic_float_add;
+    bool add_rms_fusion;
+    uint32_t atomic_binding_alignment;
+
     bool integer_dot_product;
 
     bool subgroup_size_control;
@@ -451,6 +455,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_mul_norepeat[2][2][2];
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];
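+    // fused add + partial RMS pipelines, indexed like pipeline_add by
+    // [src0 == f16][src1 == f16][dst == f16]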
+    vk_pipeline pipeline_add_rms[2][2][2];
+    vk_pipeline pipeline_add_rms_norepeat[2][2][2];
 
     // indexed by num_additional_fused_ops == num_adds - 1
     vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
@@ -1165,6 +1171,12 @@ class vk_perf_logger {
             timings[name].push_back(time);
             return;
         }
+        if (node->op == GGML_OP_RMS_NORM) {
+            std::string name = ggml_op_name(node->op);
+            name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
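+            // key includes the shape so differently-sized RMS_NORM nodes are
+            // reported separately, e.g. "RMS_NORM(4096,1,1,1)"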
+            timings[name].push_back(time);
+            return;
+        }
         timings[ggml_op_name(node->op)].push_back(time);
     }
 private:
@@ -1179,10 +1191,13 @@ struct ggml_backend_vk_context {
 
     size_t semaphore_idx, event_idx;
     ggml_vk_garbage_collector gc;
-    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
-    vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
+    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset;
+    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add;
     vk::Fence fence, almost_ready_fence;
     bool almost_ready_fence_pending {};
+    // Set before op_add and unset after op_rms_norm to indicate that the add should
+    // use atomics to accumulate the square of the vector components
+    bool do_add_rms_atomic;
 
     vk_buffer buffer_pool[MAX_VK_BUFFERS];
 
@@ -2921,8 +2936,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
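+    // rms_norm now declares 4 bindings: src0, src1, dst, plus the atomic
+    // partial-sum buffer written by a preceding fused add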
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -2992,20 +3007,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
     };
 
     bool rte = device->float_controls_rte_fp16;
-#define CREATE_BINARY(name, namemod, spec) \
+#define CREATE_BINARY(name, namemod, spec, bindings) \
     for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
         ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
             #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
-            "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
-
-    CREATE_BINARY(add, , {0})
-    CREATE_BINARY(add, _norepeat, {1})
-    CREATE_BINARY(sub, , {0})
-    CREATE_BINARY(sub, _norepeat, {1})
-    CREATE_BINARY(mul, , {0})
-    CREATE_BINARY(mul, _norepeat, {1})
-    CREATE_BINARY(div, , {0})
-    CREATE_BINARY(div, _norepeat, {1})
+            "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
+
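+    // add and add_rms take a 4th binding for the atomic partial-sum buffer;
+    // the remaining binary ops keep the original 3 bindings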
+    CREATE_BINARY(add, , {0}, 4)
+    CREATE_BINARY(add, _norepeat, {1}, 4)
+    CREATE_BINARY(sub, , {0}, 3)
+    CREATE_BINARY(sub, _norepeat, {1}, 3)
+    CREATE_BINARY(mul, , {0}, 3)
+    CREATE_BINARY(mul, _norepeat, {1}, 3)
+    CREATE_BINARY(div, , {0}, 3)
+    CREATE_BINARY(div, _norepeat, {1}, 3)
+    CREATE_BINARY(add_rms, , {0}, 4)
+    CREATE_BINARY(add_rms, _norepeat, {1}, 4)
 #undef CREATE_BINARY
 
     if (device->multi_add) {
@@ -3286,6 +3303,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
     device->coopmat_support = false;
     device->integer_dot_product = false;
     bool bfloat16_support = false;
+    bool atomic_float_support = false;
 
     for (const auto& properties : ext_props) {
         if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -3325,6 +3343,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
                    !getenv("GGML_VK_DISABLE_BFLOAT16")) {
             bfloat16_support = true;
 #endif
+        } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) {
+            atomic_float_support = true;
         }
     }
 
@@ -3541,6 +3561,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device_extensions.push_back("VK_KHR_shader_integer_dot_product");
     }
 
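+    // chain the atomic-float feature struct into the pNext list so the
+    // vkGetPhysicalDeviceFeatures2 call below fills it in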
+    VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
+    atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
+    if (atomic_float_support) {
+        last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features;
+        last_struct = (VkBaseOutStructure *)&atomic_float_features;
+        device_extensions.push_back("VK_EXT_shader_atomic_float");
+    }
+
     vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
     device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -3552,6 +3580,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 #endif
 
     device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
+    device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd;
 
     device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
                         device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
@@ -3872,6 +3901,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
     device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
 
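+    // add+rms_norm fusion requires subgroup adds and buffer float atomics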
+    device->add_rms_fusion = !device->disable_fusion &&
+                             device->subgroup_add &&
+                             device->atomic_float_add;
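+    // each fusion owns one float of partial sums, padded to the minimum
+    // storage-buffer offset alignment so every slot can be bound at an offset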
+    device->atomic_binding_alignment =
+        std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
+
     return device;
 }
 
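A quick sketch of the slot arithmetic this sets up (illustrative only, not part
of the patch; `min_align` stands in for
device->properties.limits.minStorageBufferOffsetAlignment):

    #include <algorithm>
    #include <cstdint>

    // Size of one partial-sum slot: a single float, padded so the next slot
    // still starts at a valid storage-buffer binding offset.
    uint32_t atomic_binding_alignment(uint32_t min_align) {
        return std::max(4u, min_align);
    }

    // e.g. min_align == 64: three fused add+rms_norm pairs reserve offsets
    // 0, 64 and 128, and the dryrun grows prealloc_size_atomic_add to 192.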
@@ -6914,10 +6949,19 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_ADD:
         {
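+            // when the add is fused with a following RMS_NORM, pick the *_rms
+            // variant, which also accumulates the sum of squares via atomics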
             if (ctx->num_additional_fused_ops > 0) {
-                return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+                if (ctx->do_add_rms_atomic) {
+                    return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
+                } else {
+                    return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+                }
+            }
+            if (ctx->do_add_rms_atomic) {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+            } else {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
             }
-            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
-            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         }
     case GGML_OP_SUB:
         {
@@ -7523,7 +7567,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             }
         } break;
     case GGML_OP_RMS_NORM:
-        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        if (ctx->do_add_rms_atomic) {
+            // Run one element per thread, 128 threads per workgroup
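+            // (e.g. ne00 == 4096 -> 32 workgroups)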
+            elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
+        } else {
+            elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        }
         break;
 
     case GGML_OP_SUM:
@@ -7671,7 +7720,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     }
 
-    if (op == GGML_OP_GLU) {
+    if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
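+        // bind the atomic partial-sum buffer as the 4th descriptor; when none
+        // was preallocated, rebind d_X so the descriptor set stays complete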
+        vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
+        size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            { vk_subbuffer{ d_X, x_buf_offset, x_sz },
+              vk_subbuffer{ d_Y, y_buf_offset, y_sz },
+              vk_subbuffer{ d_D, d_buf_offset, d_sz },
+              vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
+            }, pc, elements);
+    } else if (op == GGML_OP_GLU) {
         // Empty src1 is possible in glu, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
@@ -7884,7 +7943,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        0.0f, 0.0f, 0,
+        0.0f, 0.0f, ctx->do_add_rms_atomic,
     }, dryrun);
 }
 
@@ -8353,8 +8412,13 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f, 0,
+        op_params[0], 0.0f, ctx->do_add_rms_atomic,
     }, dryrun);
+
+    if (ctx->do_add_rms_atomic) {
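+        // this fusion has consumed its slot in the atomic buffer: advance the
+        // offset and re-arm detection for the next fused add+rms_norm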
+        ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment;
+        ctx->do_add_rms_atomic = false;
+    }
 }
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9632,6 +9696,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         }
         ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
     }
+    if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_size_atomic_add << ")");
+        // Resize buffer
+        if (ctx->prealloc_atomic_add != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_atomic_add);
+        }
+        ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add);
+    }
 }
 
 static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
@@ -9687,10 +9759,21 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             return false;
         }
         break;
+    case GGML_OP_ADD:
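+        // fuse when the next node is an RMS_NORM over a single row that consumes
+        // this add's result; during the dryrun, reserve one slot per fusion in
+        // the atomic partial-sum buffer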
+        if (node_idx + 1 < cgraph->n_nodes &&
+            cgraph->nodes[node_idx + 1]->op == GGML_OP_RMS_NORM &&
+            cgraph->nodes[node_idx + 1]->src[0] == cgraph->nodes[node_idx] &&
+            ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
+            ctx->device->add_rms_fusion) {
+            if (dryrun) {
+                ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment;
+            }
+            ctx->do_add_rms_atomic = true;
+        }
+        break;
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD:
     case GGML_OP_ADD_ID:
     case GGML_OP_ACC:
     case GGML_OP_SUB:
@@ -9808,6 +9891,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             // do the only thing needed for the dryrun.
             vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
             ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
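+            // the dryrun returns early, so clear the fusion flag here the same
+            // way the real ggml_vk_rms_norm path does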
+            if (node->op == GGML_OP_RMS_NORM) {
+                ctx->do_add_rms_atomic = false;
+            }
             return false;
         }
     default:
@@ -10782,6 +10868,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
     }
 
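+    // reset the per-graph fused add+rms_norm bookkeeping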
+    ctx->prealloc_size_atomic_add = 0;
+    ctx->prealloc_size_atomic_add_offset = 0;
+    ctx->do_add_rms_atomic = false;
+
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (!ctx->device->disable_fusion) {
@@ -10847,6 +10937,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
     }
 
+    if (ctx->prealloc_size_atomic_add) {
+        if (ctx->compute_ctx.expired()) {
+            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+            ctx->compute_ctx = compute_ctx;
+            ggml_vk_ctx_begin(ctx->device, compute_ctx);
+        } else {
+            compute_ctx = ctx->compute_ctx.lock();
+        }
+        // initialize atomic sums to zero.
+        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add);
+    }
+
     // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
     // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
     // (and scaled down based on model size, so smaller models submit earlier).