@@ -7207,6 +7207,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
72077207 return nullptr;
72087208 case GGML_OP_SUM:
72097209 case GGML_OP_SUM_ROWS:
7210+ case GGML_OP_MEAN:
72107211 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
72117212 return ctx->device->pipeline_sum_rows_f32;
72127213 }
@@ -7554,6 +7555,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
75547555 case GGML_OP_SOFT_MAX:
75557556 case GGML_OP_SOFT_MAX_BACK:
75567557 case GGML_OP_SUM_ROWS:
7558+ case GGML_OP_MEAN:
75577559 case GGML_OP_ARGMAX:
75587560 {
75597561 const uint32_t nr = ggml_nrows(src0);
@@ -8540,11 +8542,15 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
85408542}
85418543
85428544static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
8543- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0 .0f, 0.0f }, dryrun);
8545+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 1 .0f, 0.0f }, dryrun);
85448546}
85458547
85468548static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
8547- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
8549+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 1.0f, 0.0f }, dryrun);
8550+ }
8551+
8552+ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
8553+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, { (uint32_t)src0->ne[0], 0, 1.0f / (float)src0->ne[0], 0.0f }, dryrun);
85488554}
85498555
85508556static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9766,6 +9772,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
97669772 case GGML_OP_ARGSORT:
97679773 case GGML_OP_SUM:
97689774 case GGML_OP_SUM_ROWS:
9775+ case GGML_OP_MEAN:
97699776 case GGML_OP_ARGMAX:
97709777 case GGML_OP_COUNT_EQUAL:
97719778 case GGML_OP_IM2COL:
@@ -9835,6 +9842,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
98359842 case GGML_OP_ARGSORT:
98369843 case GGML_OP_SUM:
98379844 case GGML_OP_SUM_ROWS:
9845+ case GGML_OP_MEAN:
98389846 case GGML_OP_ARGMAX:
98399847 case GGML_OP_COUNT_EQUAL:
98409848 case GGML_OP_IM2COL:
@@ -10037,6 +10045,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1003710045 case GGML_OP_SUM_ROWS:
1003810046 ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
1003910047
10048+ break;
10049+ case GGML_OP_MEAN:
10050+ ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun);
10051+
1004010052 break;
1004110053 case GGML_OP_ARGMAX:
1004210054 ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
@@ -10196,6 +10208,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1019610208 case GGML_OP_ARGSORT:
1019710209 case GGML_OP_SUM:
1019810210 case GGML_OP_SUM_ROWS:
10211+ case GGML_OP_MEAN:
1019910212 case GGML_OP_ARGMAX:
1020010213 case GGML_OP_COUNT_EQUAL:
1020110214 case GGML_OP_IM2COL:
@@ -11428,6 +11441,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1142811441 case GGML_OP_SOFT_MAX_BACK:
1142911442 case GGML_OP_SUM:
1143011443 case GGML_OP_SUM_ROWS:
11444+ case GGML_OP_MEAN:
1143111445 case GGML_OP_ARGMAX:
1143211446 case GGML_OP_COUNT_EQUAL:
1143311447 case GGML_OP_IM2COL:
@@ -11983,6 +11997,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1198311997 tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
1198411998 } else if (tensor->op == GGML_OP_SUM_ROWS) {
1198511999 tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
12000+ } else if (tensor->op == GGML_OP_MEAN) {
12001+ tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
1198612002 } else if (tensor->op == GGML_OP_ARGMAX) {
1198712003 tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
1198812004 } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
0 commit comments