#include "cumsum.cuh"

-// Kernel to compute cumulative sum along the innermost dimension (ne[0])
-// Each block processes one row (ne[0] elements)
+// Kernel to compute cumulative sum along an arbitrary dimension
+// Each block processes one position in the non-cumsum dimensions
// Algorithm matches Metal implementation:
// 1. Each warp computes prefix sum within itself
// 2. Last thread of each warp stores result in shared memory
@@ -13,36 +13,60 @@ static __global__ void cumsum_kernel(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
-        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
+        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+        const int dim) {

    // Shared memory to store warp sums (always use float for accumulation)
    extern __shared__ float shmem[];

-    const int64_t i3 = blockIdx.z;
-    const int64_t i2 = blockIdx.y;
-    const int64_t i1 = blockIdx.x;
+    // Map block indices to actual tensor dimensions
+    // blockIdx.x, blockIdx.y, blockIdx.z represent the 3 non-cumsum dimensions
+    // threadIdx.x represents position in the cumsum dimension
+    int64_t grid_indices[3] = {blockIdx.x, blockIdx.y, blockIdx.z};
+    int64_t i_vals[4];
+
+    int grid_idx = 0;
+    for (int d = 0; d < 4; ++d) {
+        if (d == dim) {
+            i_vals[d] = 0; // Will be set in the loop below
+        } else {
+            i_vals[d] = grid_indices[grid_idx++];
+        }
+    }
+
+    const int64_t i0 = i_vals[0];
+    const int64_t i1 = i_vals[1];
+    const int64_t i2 = i_vals[2];
+    const int64_t i3 = i_vals[3];

-    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01 || i0 >= ne00) {
        return;
    }

-    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
-    T * dst_row = (T *) ((char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+    const int64_t ne_dim = (dim == 0) ? ne00 : (dim == 1) ? ne01 : (dim == 2) ? ne02 : ne03;
+    const int64_t nb_dim_src = (dim == 0) ? nb00 : (dim == 1) ? nb01 : (dim == 2) ? nb02 : nb03;
+    const int64_t nb_dim_dst = (dim == 0) ? nb0 : (dim == 1) ? nb1 : (dim == 2) ? nb2 : nb3;

    const int tid = threadIdx.x;
    const int lane_id = tid % WARP_SIZE;

    // Phase 1: Each thread processes elements at stride blockDim.x
    // Compute warp-level prefix sums
-    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
+    for (int64_t i_dim = tid; i_dim < ne_dim; i_dim += blockDim.x) {
+        const int64_t offset_src = i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03 + i_dim*nb_dim_src;
+        const int64_t offset_dst = i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 + i_dim*nb_dim_dst;
+
+        const T * src_ptr = (const T *) ((const char *) src + offset_src);
+        T * dst_ptr = (T *) ((char *) dst + offset_dst);
+
        // Load value and compute prefix sum within warp
-        float val = static_cast<float>(src_row[i0]);
+        float val = static_cast<float>(src_ptr[0]);
        val = warp_prefix_inclusive_sum(val);
-        dst_row[i0] = static_cast<T>(val);
+        dst_ptr[0] = static_cast<T>(val);

        // Last thread of warp stores its sum to shared memory at position based on data index
-        if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
-            const int shmem_idx = i0 / WARP_SIZE;
+        if (lane_id == WARP_SIZE - 1 || i_dim == ne_dim - 1) {
+            const int shmem_idx = i_dim / WARP_SIZE;
            shmem[shmem_idx] = val;
        }
    }
@@ -51,13 +75,16 @@ static __global__ void cumsum_kernel(
    __syncthreads();

    // Phase 2: Add the sum of all preceding warp groups to each element
-    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
-        const int shmem_idx = i0 / WARP_SIZE;
+    for (int64_t i_dim = tid; i_dim < ne_dim; i_dim += blockDim.x) {
+        const int64_t offset_dst = i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 + i_dim*nb_dim_dst;
+        T * dst_ptr = (T *) ((char *) dst + offset_dst);
+
+        const int shmem_idx = i_dim / WARP_SIZE;
        float sum = 0.0f;
        for (int j = 0; j < shmem_idx; ++j) {
            sum += shmem[j];
        }
-        dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum);
+        dst_ptr[0] = static_cast<T>(static_cast<float>(dst_ptr[0]) + sum);
    }
}

@@ -67,27 +94,44 @@ static void cumsum_cuda(
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+        const int dim,
        cudaStream_t stream) {

+    // Dimension being accumulated
+    const int64_t ne_dims[4] = {ne00, ne01, ne02, ne03};
+    const int64_t ne_dim = ne_dims[dim];
+
+    // Grid dimensions: the GGML_MAX_DIMS-1 non-cumsum dimensions
+    int64_t grid_dims_arr[GGML_MAX_DIMS - 1];
+    int grid_idx = 0;
+    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+        if (d != dim) {
+            grid_dims_arr[grid_idx++] = ne_dims[d];
+        }
+    }
+
    dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1);
-    dim3 grid_dims(ne01, ne02, ne03);
+    dim3 grid_dims(grid_dims_arr[0], grid_dims_arr[1], grid_dims_arr[2]);

    // Shared memory size: one float per warp
-    const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
+    const int num_warps = (ne_dim + WARP_SIZE - 1) / WARP_SIZE;
    const size_t shmem_size = num_warps * sizeof(float);

    cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
        src, dst,
        ne00, ne01, ne02, ne03,
        nb00, nb01, nb02, nb03,
-        nb0, nb1, nb2, nb3
+        nb0, nb1, nb2, nb3,
+        dim
    );
}

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

+    const int dim = ggml_get_op_params_i32(dst, 0);
+
    GGML_ASSERT(src0->type == dst->type);
    switch (src0->type) {
        case GGML_TYPE_F32:
@@ -97,6 +141,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                dim,
                stream
            );
        } break;
@@ -107,6 +152,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                dim,
                stream
            );
        } break;
@@ -117,6 +163,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                dim,
                stream
            );
        } break;
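
The kernel above relies on warp_prefix_inclusive_sum(), which is not defined in this diff and is presumably provided by the shared ggml CUDA headers. A typical shuffle-based inclusive warp scan looks like the sketch below; it is illustrative only and may differ from the actual helper used here.

// Illustrative sketch of a warp-wide inclusive prefix sum (Hillis-Steele scan),
// not the exact ggml implementation.
static __device__ __forceinline__ float warp_prefix_inclusive_sum(float val) {
    const int lane = threadIdx.x % 32;               // lane index within the warp
#pragma unroll
    for (int offset = 1; offset < 32; offset <<= 1) {
        // Fetch the partial sum held by the lane `offset` positions below this one.
        const float lower = __shfl_up_sync(0xffffffff, val, offset);
        if (lane >= offset) {                        // lanes below `offset` have nothing to add
            val += lower;
        }
    }
    return val;                                      // lane i now holds the sum of lanes 0..i
}

As a concrete example of the new index mapping introduced by this change: for dim = 1 and a source tensor with ne = {8, 5, 4, 3}, cumsum_cuda launches an (8, 4, 3) grid; inside the kernel blockIdx.x becomes i0, blockIdx.y becomes i2, blockIdx.z becomes i3, and each block sweeps the 5 elements along dimension 1 via i_dim.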