
Commit d1e15c0

feat(ggml-cpu): Add dim arg to ggml_cumsum

With tests

Branch: Mamba2SSD

Signed-off-by: Gabe Goodhart <[email protected]>

1 parent 3336f3c commit d1e15c0

File tree

4 files changed: +103 -10 lines changed

ggml/include/ggml.h

Lines changed: 7 additions & 0 deletions

@@ -988,7 +988,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    // Cumulative sum along the specified dimension
     GGML_API struct ggml_tensor * ggml_cumsum(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   dim);
+
+    // Convenience function: cumulative sum along dimension 0
+    GGML_API struct ggml_tensor * ggml_cumsum_0(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
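
As a usage sketch (not part of this commit; the context and tensor setup here are assumed boilerplate), the new signature takes the scan dimension explicitly, while ggml_cumsum_0 keeps the old dim-0 behavior for existing call sites:

#include "ggml.h"

// Sketch: build cumsum nodes with the new API. `ctx` comes from ggml_init elsewhere.
static struct ggml_tensor * build_cumsum_nodes(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 5);

    struct ggml_tensor * along_dim0 = ggml_cumsum(ctx, a, 0);  // scan within each row (ne0)
    struct ggml_tensor * along_dim1 = ggml_cumsum(ctx, a, 1);  // scan across rows (ne1)
    struct ggml_tensor * legacy     = ggml_cumsum_0(ctx, a);   // equivalent to along_dim0

    (void) along_dim0; (void) legacy;
    return along_dim1;
}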

ggml/src/ggml-cpu/ops.cpp

Lines changed: 65 additions & 6 deletions

@@ -1397,6 +1397,50 @@ void ggml_compute_forward_sum(

 // ggml_compute_forward_cumsum

+// General implementation for arbitrary dimensions
+template<typename T>
+static void ggml_compute_forward_cumsum_general(
+        const ggml_compute_params * params,
+              ggml_tensor * dst,
+        int dim) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    for (int64_t i3 = 0; i3 < ne03; i3++) {
+        for (int64_t i2 = 0; i2 < ne02; i2++) {
+            for (int64_t i1 = 0; i1 < ne01; i1++) {
+                for (int64_t i0 = 0; i0 < ne00; i0++) {
+                    const T * src_ptr = (const T *)((const char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                    T * dst_ptr = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                    // Determine position in the cumsum dimension
+                    int64_t i_vals[4] = {i0, i1, i2, i3};
+                    int64_t i_dim = i_vals[dim];
+
+                    if (i_dim == 0) {
+                        // First element: just copy
+                        dst_ptr[0] = src_ptr[0];
+                    } else {
+                        // Accumulate: add current value to previous cumsum value
+                        const T * prev_dst_ptr = (const T *)((const char *) dst_ptr - dst->nb[dim]);
+                        dst_ptr[0] = type_conversion_table<T>::from_f32(
+                            type_conversion_table<T>::to_f32(prev_dst_ptr[0]) +
+                            type_conversion_table<T>::to_f32(src_ptr[0]));
+                    }
+                }
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_cumsum_f32(
         const ggml_compute_params * params,
               ggml_tensor * dst) {

@@ -1420,7 +1464,7 @@ static void ggml_compute_forward_cumsum_f32(
     for (int64_t i3 = 0; i3 < ne03; i3++) {
         for (int64_t i2 = 0; i2 < ne02; i2++) {
             for (int64_t i1 = 0; i1 < ne01; i1++) {
-                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                const float * src_row = (const float *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
                 float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
                 ggml_vec_cumsum_f32(ne00, dst_row, src_row);
             }

@@ -1451,7 +1495,7 @@ static void ggml_compute_forward_cumsum_f16(
     for (int64_t i3 = 0; i3 < ne03; i3++) {
         for (int64_t i2 = 0; i2 < ne02; i2++) {
             for (int64_t i1 = 0; i1 < ne01; i1++) {
-                ggml_fp16_t * src_row = (ggml_fp16_t *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                const ggml_fp16_t * src_row = (const ggml_fp16_t *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
                 ggml_fp16_t * dst_row = (ggml_fp16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
                 ggml_vec_cumsum_f16(ne00, dst_row, src_row);
             }

@@ -1482,7 +1526,7 @@ static void ggml_compute_forward_cumsum_bf16(
     for (int64_t i3 = 0; i3 < ne03; i3++) {
         for (int64_t i2 = 0; i2 < ne02; i2++) {
            for (int64_t i1 = 0; i1 < ne01; i1++) {
-                ggml_bf16_t * src_row = (ggml_bf16_t *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                const ggml_bf16_t * src_row = (const ggml_bf16_t *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
                 ggml_bf16_t * dst_row = (ggml_bf16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
                 ggml_vec_cumsum_bf16(ne00, dst_row, src_row);
             }

@@ -1496,18 +1540,33 @@ void ggml_compute_forward_cumsum(

     const ggml_tensor * src0 = dst->src[0];

+    const int  dim         = ggml_get_op_params_i32(dst, 0);
+    const bool use_general = dim != 0 || !ggml_is_contiguous_rows(src0);
+
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_cumsum_f32(params, dst);
+                if (use_general) {
+                    ggml_compute_forward_cumsum_general<float>(params, dst, dim);
+                } else {
+                    ggml_compute_forward_cumsum_f32(params, dst);
+                }
             } break;
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_cumsum_f16(params, dst);
+                if (use_general) {
+                    ggml_compute_forward_cumsum_general<ggml_fp16_t>(params, dst, dim);
+                } else {
+                    ggml_compute_forward_cumsum_f16(params, dst);
+                }
             } break;
         case GGML_TYPE_BF16:
             {
-                ggml_compute_forward_cumsum_bf16(params, dst);
+                if (use_general) {
+                    ggml_compute_forward_cumsum_general<ggml_bf16_t>(params, dst, dim);
+                } else {
+                    ggml_compute_forward_cumsum_bf16(params, dst);
+                }
             } break;
         default:
             {
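
Two things are worth noting in the dispatch above: the vectorized per-row kernels (ggml_vec_cumsum_*) are kept for the common contiguous dim-0 case, while the general path handles any dim or non-contiguous rows and runs on thread 0 only (the params->ith != 0 early return). The general path also relies on visiting indices in increasing order along every dimension, so each element past the first along dim can read the dst element one stride back as the running total. A standalone f32 sketch of that idea (illustrative only, independent of the ggml type-conversion helpers):

#include <stddef.h>
#include <stdint.h>

// Sketch: cumulative sum over dimension `dim` of a 4-D f32 tensor described
// by element counts ne[4] and byte strides nb_src[4]/nb_dst[4]. Mirrors the
// loop structure of ggml_compute_forward_cumsum_general.
static void cumsum_f32_general(const char * src, char * dst,
                               const int64_t ne[4], const size_t nb_src[4],
                               const size_t nb_dst[4], int dim) {
    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
                    const int64_t i_vals[4] = {i0, i1, i2, i3};
                    const float * s = (const float *)(src + i3*nb_src[3] + i2*nb_src[2] + i1*nb_src[1] + i0*nb_src[0]);
                    float       * d = (float       *)(dst + i3*nb_dst[3] + i2*nb_dst[2] + i1*nb_dst[1] + i0*nb_dst[0]);
                    if (i_vals[dim] == 0) {
                        *d = *s;                        // first element along dim: copy
                    } else {
                        // previously written dst element, one stride back along dim
                        const float * prev = (const float *)((const char *) d - nb_dst[dim]);
                        *d = *prev + *s;                // running total
                    }
                }
            }
        }
    }
}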

ggml/src/ggml.c

Lines changed: 12 additions & 1 deletion

@@ -2346,16 +2346,27 @@ struct ggml_tensor * ggml_sum_rows(

 struct ggml_tensor * ggml_cumsum(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
+        struct ggml_tensor  * a,
+        int                   dim) {
+
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);

     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, a->ne);

+    ggml_set_op_params_i32(result, 0, dim);
+
     result->op     = GGML_OP_CUMSUM;
     result->src[0] = a;

     return result;
 }

+struct ggml_tensor * ggml_cumsum_0(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_cumsum(ctx, a, 0);
+}
+
 // ggml_mean

 struct ggml_tensor * ggml_mean(
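
The dim argument reaches the CPU backend through the node's op_params: ggml_cumsum stores it with ggml_set_op_params_i32(result, 0, dim), and ggml_compute_forward_cumsum reads it back with ggml_get_op_params_i32(dst, 0). A minimal sketch of that round trip, assuming the internal helpers from ggml-impl.h are in scope (they are not part of the public ggml.h API):

#include "ggml.h"
#include "ggml-impl.h" // internal header: the op-params helpers live here

// Sketch: the scan dimension rides on the op node itself, not on the tensor data.
static int cumsum_dim_roundtrip(struct ggml_context * ctx, struct ggml_tensor * a) {
    struct ggml_tensor * node = ggml_cumsum(ctx, a, 2); // stores 2 at op_params slot 0
    return ggml_get_op_params_i32(node, 0);             // -> 2, as the backend will see it
}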

tests/test-backend-ops.cpp

Lines changed: 19 additions & 3 deletions

@@ -4861,14 +4861,18 @@ struct test_sum_rows : public test_case {
 struct test_cumsum : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
+    const int64_t dim;
+    const std::array<int64_t, 4> permute;

     std::string vars() override {
         return VARS_TO_STR2(type, ne);
     }

     test_cumsum(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 5, 4, 3})
-        : type(type), ne(ne) {}
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
+            int64_t dim = 0,
+            std::array<int64_t, 4> permute = {-1, -1, -1, -1})
+        : type(type), ne(ne), dim(dim), permute(permute) {}


     double max_nmse_err() override {

@@ -4884,7 +4888,11 @@ struct test_cumsum : public test_case {
         ggml_set_param(a);
         ggml_set_name(a, "a");

-        ggml_tensor * out = ggml_cumsum(ctx, a);
+        if (permute[0] != -1) {
+            a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]);
+        }
+
+        ggml_tensor * out = ggml_cumsum(ctx, a, dim);
         ggml_set_name(out, "out");

         return out;

@@ -7056,6 +7064,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(int verbose
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 }));
+    // non-contiguous
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 0, {1, 0, 2, 3}));
+    // alternate dim
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 1));

     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER));
     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG));

@@ -7233,6 +7245,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 }));
+    // non-contiguous
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 0, {1, 0, 2, 3}));
+    // alternate dim
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2, 4, 2, 1 }, 1));

     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER));
     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG));
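
The new cases cover a permuted (non-contiguous) input, which forces the general path, and a dim = 1 scan, in both the eval and perf suites. Assuming the harness's usual op filter flag, they can be exercised in isolation with something like:

  ./build/bin/test-backend-ops test -o CUMSUM
  ./build/bin/test-backend-ops perf -o CUMSUM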
