@@ -1397,6 +1397,50 @@ void ggml_compute_forward_sum(
 
 // ggml_compute_forward_cumsum
 
+// General implementation for arbitrary dimensions
+template <typename T>
+static void ggml_compute_forward_cumsum_general(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        int dim) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
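+    // the general path runs single-threaded: only thread 0 does the work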
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
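+    // visit indices in increasing order so that, whichever dim is chosen,
+    // the previous element along that dimension is already written to dst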
+    for (int64_t i3 = 0; i3 < ne03; i3++) {
+        for (int64_t i2 = 0; i2 < ne02; i2++) {
+            for (int64_t i1 = 0; i1 < ne01; i1++) {
+                for (int64_t i0 = 0; i0 < ne00; i0++) {
+                    const T * src_ptr = (const T *)((const char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                    T * dst_ptr = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                    // Determine position in the cumsum dimension
+                    int64_t i_vals[4] = {i0, i1, i2, i3};
+                    int64_t i_dim = i_vals[dim];
+
+                    if (i_dim == 0) {
+                        // First element: just copy
+                        dst_ptr[0] = src_ptr[0];
+                    } else {
+                        // Accumulate: add current value to previous cumsum value
+                        const T * prev_dst_ptr = (const T *)((const char *) dst_ptr - dst->nb[dim]);
+                        dst_ptr[0] = type_conversion_table<T>::from_f32(
+                            type_conversion_table<T>::to_f32(prev_dst_ptr[0]) +
+                            type_conversion_table<T>::to_f32(src_ptr[0]));
+                    }
+                }
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_cumsum_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -1420,7 +1464,7 @@ static void ggml_compute_forward_cumsum_f32(
     for (int64_t i3 = 0; i3 < ne03; i3++) {
         for (int64_t i2 = 0; i2 < ne02; i2++) {
             for (int64_t i1 = 0; i1 < ne01; i1++) {
-                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                const float * src_row = (const float *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
                 float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
                 ggml_vec_cumsum_f32(ne00, dst_row, src_row);
             }
@@ -1451,7 +1495,7 @@ static void ggml_compute_forward_cumsum_f16(
     for (int64_t i3 = 0; i3 < ne03; i3++) {
         for (int64_t i2 = 0; i2 < ne02; i2++) {
             for (int64_t i1 = 0; i1 < ne01; i1++) {
-                ggml_fp16_t * src_row = (ggml_fp16_t *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                const ggml_fp16_t * src_row = (const ggml_fp16_t *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
                 ggml_fp16_t * dst_row = (ggml_fp16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
                 ggml_vec_cumsum_f16(ne00, dst_row, src_row);
             }
@@ -1482,7 +1526,7 @@ static void ggml_compute_forward_cumsum_bf16(
     for (int64_t i3 = 0; i3 < ne03; i3++) {
         for (int64_t i2 = 0; i2 < ne02; i2++) {
             for (int64_t i1 = 0; i1 < ne01; i1++) {
-                ggml_bf16_t * src_row = (ggml_bf16_t *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                const ggml_bf16_t * src_row = (const ggml_bf16_t *) ((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
                 ggml_bf16_t * dst_row = (ggml_bf16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
                 ggml_vec_cumsum_bf16(ne00, dst_row, src_row);
             }
@@ -1496,18 +1540,33 @@ void ggml_compute_forward_cumsum(
 
     const ggml_tensor * src0 = dst->src[0];
 
+    const int dim = ggml_get_op_params_i32(dst, 0);
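+    // the fast per-row kernels assume accumulation along dim 0 over
+    // contiguous rows; anything else takes the general strided path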
+    const bool use_general = dim != 0 || !ggml_is_contiguous_rows(src0);
+
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_cumsum_f32(params, dst);
+                if (use_general) {
+                    ggml_compute_forward_cumsum_general<float>(params, dst, dim);
+                } else {
+                    ggml_compute_forward_cumsum_f32(params, dst);
+                }
             } break;
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_cumsum_f16(params, dst);
+                if (use_general) {
+                    ggml_compute_forward_cumsum_general<ggml_fp16_t>(params, dst, dim);
+                } else {
+                    ggml_compute_forward_cumsum_f16(params, dst);
+                }
             } break;
         case GGML_TYPE_BF16:
             {
-                ggml_compute_forward_cumsum_bf16(params, dst);
+                if (use_general) {
+                    ggml_compute_forward_cumsum_general<ggml_bf16_t>(params, dst, dim);
+                } else {
+                    ggml_compute_forward_cumsum_bf16(params, dst);
+                }
             } break;
         default:
             {
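
For readers who want the pattern in isolation, here is a minimal standalone sketch (an illustration only, not part of this change and not the ggml API) of the same in-place strided accumulation the general kernel performs, reduced to two dimensions and plain floats:

#include <cstdio>

// cumulative sum over one dimension of a strided 2-D array;
// ne[] holds the extents, stride[] the element stride of each dimension
static void cumsum_2d(const float * src, float * dst,
                      const int ne[2], const int stride[2], int dim) {
    for (int i1 = 0; i1 < ne[1]; i1++) {
        for (int i0 = 0; i0 < ne[0]; i0++) {
            const int idx = i1*stride[1] + i0*stride[0];
            const int i_vals[2] = { i0, i1 };
            if (i_vals[dim] == 0) {
                // first element along dim: plain copy
                dst[idx] = src[idx];
            } else {
                // add the already-final previous cumsum value along dim
                dst[idx] = dst[idx - stride[dim]] + src[idx];
            }
        }
    }
}

int main() {
    const int   ne[2]     = { 3, 2 };             // 3 columns, 2 rows
    const int   stride[2] = { 1, 3 };             // row-major layout
    const float src[6]    = { 1, 2, 3, 4, 5, 6 };
    float       dst[6];
    cumsum_2d(src, dst, ne, stride, /*dim =*/ 1); // accumulate down the rows
    for (int i = 0; i < 6; i++) {
        printf("%g ", dst[i]);                    // prints: 1 2 3 5 7 9
    }
    printf("\n");
    return 0;
}

Because the index loops run in increasing order, the element at idx - stride[dim] is always already final when it is read back, which is the same invariant the 4-D kernel above relies on.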