@@ -1014,6 +1014,44 @@ struct vk_op_upscale_push_constants {
     float sf0; float sf1; float sf2; float sf3;
 };
 
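+// Push constants shared by SUM, SUM_ROWS and MEAN; all strides are in elements, not bytes.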
+struct vk_op_sum_rows_push_constants
+{
+    uint32_t n_cols;
+    uint32_t ne01, ne02;
+    uint32_t nb00, nb01, nb02, nb03;
+    uint32_t nb11, nb12, nb13;
+    float weight;
+    uint32_t misalign_offsets;
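+    // fastdiv magic multiplier/shift pairs for ne01*ne02 and ne01 (filled in by init_pushconst_fastdiv)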
+    uint32_t ne0_12mp, ne0_12L;
+    uint32_t ne0_1mp, ne0_1L;
+};
+
+vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
+    uint32_t type_size = (uint32_t)ggml_type_size(src->type);
+    vk_op_sum_rows_push_constants p = {};
+    p.n_cols = (uint32_t)n_cols;
+    p.ne01 = (uint32_t)src->ne[1];
+    p.ne02 = (uint32_t)src->ne[2];
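+    // convert byte strides to element strides for the shader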
+    p.nb00 = (uint32_t)src->nb[0] / type_size;
+    p.nb01 = (uint32_t)src->nb[1] / type_size;
+    p.nb02 = (uint32_t)src->nb[2] / type_size;
+    p.nb03 = (uint32_t)src->nb[3] / type_size;
+    p.nb11 = (uint32_t)dst->nb[1] / type_size;
+    p.nb12 = (uint32_t)dst->nb[2] / type_size;
+    p.nb13 = (uint32_t)dst->nb[3] / type_size;
+    p.weight = 1.0f;
+    return p;
+}
+
+template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
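+    // precompute magic multiplier + shift so the shader can divide by ne01*ne02 and ne01 with a multiply and a shift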
+    init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L);
+    init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L);
+}
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -3122,7 +3156,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
 
@@ -7340,6 +7374,10 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_IM2COL:
     case GGML_OP_SET_ROWS:
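+    // these reductions now read through full per-dimension strides, so non-contiguous sources are supported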
+    case GGML_OP_SUM:
+    case GGML_OP_SUM_ROWS:
+    case GGML_OP_MEAN:
        return true;
    default:
        return false;
@@ -7374,6 +7411,17 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
     GGML_UNUSED(src2);
 }
 
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
+    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
+
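+    // pack both element offsets into a single push constant: src0 in the high 16 bits, dst in the low 16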
+    p.misalign_offsets = (a_offset << 16) | d_offset;
+
+    GGML_UNUSED(src1);
+    GGML_UNUSED(src2);
+}
+
 template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
     const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
     const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
@@ -8542,15 +8589,21 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
 }
 
 static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 1.0f, 0.0f }, dryrun);
+    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
+    p.nb00 = 1; // treat src0 as flattened 1D tensor
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
 }
 
 static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 1.0f, 0.0f }, dryrun);
+    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
 }
 
 static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, { (uint32_t)src0->ne[0], 0, 1.0f / (float)src0->ne[0], 0.0f }, dryrun);
+    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
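+    // MEAN is SUM_ROWS with each row sum scaled by 1.0f/ne0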
+    p.weight = 1.0f / (float)src0->ne[0];
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
 }
 
 static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {