@@ -705,6 +705,7 @@ struct vk_device_struct {
705705 vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
706706 vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
707707 vk_pipeline pipeline_sum_rows_f32;
708+ vk_pipeline pipeline_cumsum_f32;
708709 vk_pipeline pipeline_argmax_f32;
709710 vk_pipeline pipeline_count_equal_i32;
710711 vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -3968,6 +3969,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
39683969
39693970 ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
39703971
3972+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
3973+
39713974 ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
39723975
39733976#define IM2COL(bda) \
@@ -8457,6 +8460,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
84578460 return ctx->device->pipeline_sum_rows_f32;
84588461 }
84598462 return nullptr;
8463+ case GGML_OP_CUMSUM:
8464+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8465+ return ctx->device->pipeline_cumsum_f32;
8466+ }
8467+ return nullptr;
84608468 case GGML_OP_ARGMAX:
84618469 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
84628470 return ctx->device->pipeline_argmax_f32;
@@ -8821,6 +8829,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
88218829 case GGML_OP_SOFT_MAX:
88228830 case GGML_OP_SOFT_MAX_BACK:
88238831 case GGML_OP_SUM_ROWS:
8832+ case GGML_OP_CUMSUM:
88248833 case GGML_OP_MEAN:
88258834 case GGML_OP_ARGMAX:
88268835 {
@@ -10150,6 +10159,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
1015010159 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
1015110160}
1015210161
10162+ static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // record a GGML_OP_CUMSUM dispatch (F32 -> F32, cumulative sum along each row)
10163+ vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); // reuses the sum_rows push-constant layout; ne[0] is the row length to scan
10164+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
10165+ }
10166+
1015310167static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // record a GGML_OP_ARGMAX dispatch (F32 src -> I32 dst per the pipeline-selection case above)
1015410168 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }); // push constants: presumably row length and row count; float params unused — confirm against vk_op_push_constants
1015510169}
@@ -11749,6 +11763,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1174911763 case GGML_OP_SUM_ROWS:
1175011764 ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
1175111765
11766+ break;
11767+ case GGML_OP_CUMSUM:
11768+ ggml_vk_cumsum(ctx, compute_ctx, src0, node);
11769+
1175211770 break;
1175311771 case GGML_OP_MEAN:
1175411772 ggml_vk_mean(ctx, compute_ctx, src0, node);
@@ -13786,6 +13804,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1378613804 case GGML_OP_SUM_ROWS:
1378713805 case GGML_OP_MEAN:
1378813806 return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
13807+ case GGML_OP_CUMSUM:
13808+ {
13809+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13810+ auto device = ggml_vk_get_device(ctx->device);
13811+ if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
13812+ return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
13813+ }
13814+ return false;
13815+ }
1378913816 case GGML_OP_ARGMAX:
1379013817 case GGML_OP_COUNT_EQUAL:
1379113818 case GGML_OP_IM2COL:
@@ -14436,6 +14463,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1443614463 tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
1443714464 } else if (tensor->op == GGML_OP_SUM_ROWS) {
1443814465 tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
14466+ } else if (tensor->op == GGML_OP_CUMSUM) {
14467+ tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
1443914468 } else if (tensor->op == GGML_OP_MEAN) {
1444014469 tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
1444114470 } else if (tensor->op == GGML_OP_ARGMAX) {
0 commit comments