Skip to content

Commit 0012b5c

Browse files
committed
vulkan : support ggml_mean
1 parent 21c17b5 commit 0012b5c

File tree

2 files changed: 20 additions and 3 deletions.

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7207,6 +7207,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
72077207
return nullptr;
72087208
case GGML_OP_SUM:
72097209
case GGML_OP_SUM_ROWS:
7210+
case GGML_OP_MEAN:
72107211
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
72117212
return ctx->device->pipeline_sum_rows_f32;
72127213
}
@@ -7554,6 +7555,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
75547555
case GGML_OP_SOFT_MAX:
75557556
case GGML_OP_SOFT_MAX_BACK:
75567557
case GGML_OP_SUM_ROWS:
7558+
case GGML_OP_MEAN:
75577559
case GGML_OP_ARGMAX:
75587560
{
75597561
const uint32_t nr = ggml_nrows(src0);
@@ -8540,11 +8542,15 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
85408542
}
85418543

85428544
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
8543-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
8545+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 1.0f, 0.0f }, dryrun);
85448546
}
85458547

85468548
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
8547-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
8549+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 1.0f, 0.0f }, dryrun);
8550+
}
8551+
8552+
static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
8553+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, { (uint32_t)src0->ne[0], 0, 1.0f / (float)src0->ne[0], 0.0f }, dryrun);
85488554
}
85498555

85508556
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9766,6 +9772,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
97669772
case GGML_OP_ARGSORT:
97679773
case GGML_OP_SUM:
97689774
case GGML_OP_SUM_ROWS:
9775+
case GGML_OP_MEAN:
97699776
case GGML_OP_ARGMAX:
97709777
case GGML_OP_COUNT_EQUAL:
97719778
case GGML_OP_IM2COL:
@@ -9835,6 +9842,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
98359842
case GGML_OP_ARGSORT:
98369843
case GGML_OP_SUM:
98379844
case GGML_OP_SUM_ROWS:
9845+
case GGML_OP_MEAN:
98389846
case GGML_OP_ARGMAX:
98399847
case GGML_OP_COUNT_EQUAL:
98409848
case GGML_OP_IM2COL:
@@ -10037,6 +10045,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1003710045
case GGML_OP_SUM_ROWS:
1003810046
ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
1003910047

10048+
break;
10049+
case GGML_OP_MEAN:
10050+
ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun);
10051+
1004010052
break;
1004110053
case GGML_OP_ARGMAX:
1004210054
ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
@@ -10196,6 +10208,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1019610208
case GGML_OP_ARGSORT:
1019710209
case GGML_OP_SUM:
1019810210
case GGML_OP_SUM_ROWS:
10211+
case GGML_OP_MEAN:
1019910212
case GGML_OP_ARGMAX:
1020010213
case GGML_OP_COUNT_EQUAL:
1020110214
case GGML_OP_IM2COL:
@@ -11428,6 +11441,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1142811441
case GGML_OP_SOFT_MAX_BACK:
1142911442
case GGML_OP_SUM:
1143011443
case GGML_OP_SUM_ROWS:
11444+
case GGML_OP_MEAN:
1143111445
case GGML_OP_ARGMAX:
1143211446
case GGML_OP_COUNT_EQUAL:
1143311447
case GGML_OP_IM2COL:
@@ -11983,6 +11997,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1198311997
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
1198411998
} else if (tensor->op == GGML_OP_SUM_ROWS) {
1198511999
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
12000+
} else if (tensor->op == GGML_OP_MEAN) {
12001+
tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
1198612002
} else if (tensor->op == GGML_OP_ARGMAX) {
1198712003
tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
1198812004
} else if (tensor->op == GGML_OP_COUNT_EQUAL) {

ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ shared FLOAT_TYPE tmp[BLOCK_SIZE];
1616
void main() {
1717
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
1818
const uint col = gl_LocalInvocationID.x;
19+
const float weight = p.param1;
1920

2021
tmp[col] = FLOAT_TYPE(0.0f);
2122

@@ -32,6 +33,6 @@ void main() {
3233
}
3334

3435
if (col == 0) {
35-
data_d[row] = D_TYPE(tmp[0]);
36+
data_d[row] = D_TYPE(tmp[0] * weight);
3637
}
3738
}

0 commit comments

Comments (0)