@@ -102,9 +102,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
102
102
103
103
struct ggml_backend_vk_context;
104
104
105
- #define MAX_PARAMETER_COUNT 8
105
+ #define MAX_PARAMETER_COUNT 12
106
106
// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
107
- #define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2 )
107
+ #define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3 )
108
108
109
109
struct vk_pipeline_struct {
110
110
std::string name;
@@ -459,6 +459,7 @@ struct vk_device_struct {
459
459
460
460
// indexed by num_additional_fused_ops == num_adds - 1
461
461
vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
462
+ vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS];
462
463
463
464
vk_pipeline pipeline_add_id_f32;
464
465
@@ -819,8 +820,13 @@ struct vk_op_multi_add_push_constants {
819
820
uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
820
821
821
822
// strides for srcs+dst
822
- uint32_t nb[8][4];
823
+ uint32_t nb[MAX_PARAMETER_COUNT][4];
824
+
825
+ uint32_t rms_partials;
823
826
};
827
+ // update multi_add.comp if this changes
828
+ static_assert(MAX_PARAMETER_COUNT == 12);
829
+ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
824
830
825
831
struct vk_op_add_id_push_constants {
826
832
uint32_t ne0;
@@ -3032,7 +3038,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
3032
3038
3033
3039
if (device->multi_add) {
3034
3040
for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
3035
- ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
3041
+ ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
3042
+ ggml_vk_create_pipeline(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
3036
3043
}
3037
3044
}
3038
3045
@@ -6912,7 +6919,7 @@ static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst)
6912
6919
return elements;
6913
6920
}
6914
6921
6915
- static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
6922
+ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
6916
6923
switch (op) {
6917
6924
case GGML_OP_GET_ROWS:
6918
6925
GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -7836,7 +7843,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
7836
7843
const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
7837
7844
uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
7838
7845
uint32_t num_tensors = num_srcs + 1;
7839
- GGML_ASSERT(num_tensors <= MAX_PARAMETER_COUNT);
7846
+ GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT);
7840
7847
7841
7848
tensors[0] = first_node->src[0];
7842
7849
tensors[1] = first_node->src[1];
@@ -7863,8 +7870,9 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
7863
7870
pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
7864
7871
pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
7865
7872
}
7873
+ pc.rms_partials = ctx->do_add_rms_partials;
7866
7874
7867
- vk_pipeline pipeline = ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops] ;
7875
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline( ctx, tensors[0], tensors[1], nullptr, dst, dst->op) ;
7868
7876
7869
7877
if (pipeline == nullptr) {
7870
7878
std::cerr << "ggml_vulkan: Error: Missing multi_add";
@@ -7902,6 +7910,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
7902
7910
buf[i] = buf[0];
7903
7911
offset[i] = 0;
7904
7912
}
7913
+ if (ctx->do_add_rms_partials) {
7914
+ buf[num_tensors] = ctx->prealloc_add_rms_partials;
7915
+ offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset;
7916
+ }
7905
7917
7906
7918
std::array<uint32_t, 3> elements;
7907
7919
@@ -7915,6 +7927,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
7915
7927
}
7916
7928
7917
7929
ggml_vk_sync_buffers(subctx);
7930
+ static_assert(MAX_PARAMETER_COUNT == 12);
7918
7931
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
7919
7932
{
7920
7933
vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
@@ -7925,6 +7938,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
7925
7938
vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
7926
7939
vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
7927
7940
vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
7941
+ vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE },
7942
+ vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE },
7943
+ vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE },
7944
+ vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE },
7928
7945
}, pc, elements);
7929
7946
}
7930
7947
@@ -9771,17 +9788,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9771
9788
}
9772
9789
break;
9773
9790
case GGML_OP_ADD:
9774
- if (node_idx + 1 < cgraph->n_nodes &&
9775
- cgraph->nodes[node_idx + 1]->op == GGML_OP_RMS_NORM &&
9776
- cgraph->nodes[node_idx + 1]->src[0] == cgraph->nodes[node_idx] &&
9777
- ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
9778
- ctx->device->add_rms_fusion) {
9779
- if (dryrun) {
9780
- ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
9791
+ {
9792
+ int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
9793
+ if (next_node_idx < cgraph->n_nodes &&
9794
+ cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
9795
+ cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
9796
+ ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
9797
+ ctx->device->add_rms_fusion) {
9798
+ if (dryrun) {
9799
+ ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
9800
+ }
9801
+ ctx->do_add_rms_partials = true;
9781
9802
}
9782
- ctx->do_add_rms_partials = true;
9783
- }
9784
- break;
9803
+ } break;
9785
9804
case GGML_OP_REPEAT:
9786
9805
case GGML_OP_REPEAT_BACK:
9787
9806
case GGML_OP_GET_ROWS:
0 commit comments