@@ -1018,7 +1018,7 @@ struct vk_op_sum_rows_push_constants
10181018{
10191019    uint32_t n_cols;
10201020    uint32_t ne01, ne02;
1021-     uint32_t nb00,  nb01, nb02, nb03;
1021+     uint32_t nb01, nb02, nb03;
10221022    uint32_t nb11, nb12, nb13;
10231023    float weight;
10241024    uint32_t misalign_offsets;
@@ -1032,7 +1032,6 @@ vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tens
10321032    p.n_cols = (uint32_t)n_cols;
10331033    p.ne01 = (uint32_t)src->ne[1];
10341034    p.ne02 = (uint32_t)src->ne[2];
1035-     p.nb00 = (uint32_t)src->nb[0] / type_size;
10361035    p.nb01 = (uint32_t)src->nb[1] / type_size;
10371036    p.nb02 = (uint32_t)src->nb[2] / type_size;
10381037    p.nb03 = (uint32_t)src->nb[3] / type_size;
@@ -8590,7 +8589,6 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
85908589
85918590static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
85928591    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
8593-     p.nb00 = 1; // treat src0 as flattened 1D tensor
85948592    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
85958593}
85968594
@@ -11491,9 +11489,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1149111489        case GGML_OP_DIAG_MASK_INF:
1149211490        case GGML_OP_SOFT_MAX:
1149311491        case GGML_OP_SOFT_MAX_BACK:
11492+             return true;
1149411493        case GGML_OP_SUM:
1149511494        case GGML_OP_SUM_ROWS:
1149611495        case GGML_OP_MEAN:
11496+             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
1149711497        case GGML_OP_ARGMAX:
1149811498        case GGML_OP_COUNT_EQUAL:
1149911499        case GGML_OP_IM2COL:
0 commit comments