Commit 3b4055e

feat(ggml-cuda): Support arbitrary dims and non-cont in cumsum
Branch: Mamba2SSD
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent ee13af1 commit 3b4055e

File tree: 1 file changed, +67 −20 lines

ggml/src/ggml-cuda/cumsum.cu

Lines changed: 67 additions & 20 deletions
@@ -1,7 +1,7 @@
 #include "cumsum.cuh"
 
-// Kernel to compute cumulative sum along the innermost dimension (ne[0])
-// Each block processes one row (ne[0] elements)
+// Kernel to compute cumulative sum along an arbitrary dimension
+// Each block processes one position in the non-cumsum dimensions
 // Algorithm matches Metal implementation:
 // 1. Each warp computes prefix sum within itself
 // 2. Last thread of each warp stores result in shared memory
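For context (not part of this commit): step 1 relies on warp_prefix_inclusive_sum, which is defined elsewhere in the CUDA backend and not shown in this diff. A minimal sketch of what a shuffle-based intra-warp inclusive scan of that kind looks like, assuming full-warp participation and WARP_SIZE == 32 (illustrative only, not the actual helper):

// Illustrative sketch, not the real helper: inclusive prefix sum across one warp
// using shuffle intrinsics, assuming all 32 lanes are active.
static __device__ __forceinline__ float warp_prefix_inclusive_sum_sketch(float val) {
    const int lane = threadIdx.x % 32;
#pragma unroll
    for (int offset = 1; offset < 32; offset <<= 1) {
        const float up = __shfl_up_sync(0xffffffff, val, offset);
        if (lane >= offset) {
            val += up;
        }
    }
    return val;
}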
@@ -13,36 +13,60 @@ static __global__ void cumsum_kernel(
     const T * src, T * dst,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
     const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
-    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
+    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+    const int dim) {
 
     // Shared memory to store warp sums (always use float for accumulation)
     extern __shared__ float shmem[];
 
-    const int64_t i3 = blockIdx.z;
-    const int64_t i2 = blockIdx.y;
-    const int64_t i1 = blockIdx.x;
+    // Map block indices to actual tensor dimensions
+    // blockIdx.x, blockIdx.y, blockIdx.z represent the 3 non-cumsum dimensions
+    // threadIdx.x represents position in the cumsum dimension
+    int64_t grid_indices[3] = {blockIdx.x, blockIdx.y, blockIdx.z};
+    int64_t i_vals[4];
+
+    int grid_idx = 0;
+    for (int d = 0; d < 4; ++d) {
+        if (d == dim) {
+            i_vals[d] = 0; // Will be set in the loop below
+        } else {
+            i_vals[d] = grid_indices[grid_idx++];
+        }
+    }
+
+    const int64_t i0 = i_vals[0];
+    const int64_t i1 = i_vals[1];
+    const int64_t i2 = i_vals[2];
+    const int64_t i3 = i_vals[3];
 
-    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01 || i0 >= ne00) {
         return;
     }
 
-    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
-    T       * dst_row = (T *)       ((      char *) dst + i1*nb1  + i2*nb2  + i3*nb3);
+    const int64_t ne_dim     = (dim == 0) ? ne00 : (dim == 1) ? ne01 : (dim == 2) ? ne02 : ne03;
+    const int64_t nb_dim_src = (dim == 0) ? nb00 : (dim == 1) ? nb01 : (dim == 2) ? nb02 : nb03;
+    const int64_t nb_dim_dst = (dim == 0) ? nb0  : (dim == 1) ? nb1  : (dim == 2) ? nb2  : nb3;
 
     const int tid     = threadIdx.x;
     const int lane_id = tid % WARP_SIZE;
 
     // Phase 1: Each thread processes elements at stride blockDim.x
     // Compute warp-level prefix sums
-    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
+    for (int64_t i_dim = tid; i_dim < ne_dim; i_dim += blockDim.x) {
+        const int64_t offset_src = i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03 + i_dim*nb_dim_src;
+        const int64_t offset_dst = i0*nb0  + i1*nb1  + i2*nb2  + i3*nb3  + i_dim*nb_dim_dst;
+
+        const T * src_ptr = (const T *) ((const char *) src + offset_src);
+        T       * dst_ptr = (T *)       ((      char *) dst + offset_dst);
+
         // Load value and compute prefix sum within warp
-        float val = static_cast<float>(src_row[i0]);
+        float val = static_cast<float>(src_ptr[0]);
         val = warp_prefix_inclusive_sum(val);
-        dst_row[i0] = static_cast<T>(val);
+        dst_ptr[0] = static_cast<T>(val);
 
         // Last thread of warp stores its sum to shared memory at position based on data index
-        if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
-            const int shmem_idx = i0 / WARP_SIZE;
+        if (lane_id == WARP_SIZE - 1 || i_dim == ne_dim - 1) {
+            const int shmem_idx = i_dim / WARP_SIZE;
             shmem[shmem_idx] = val;
         }
     }
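To make the block-to-index remapping concrete, here is a host-side mirror of the same logic (hypothetical helper, for illustration only; the kernel does this inline as shown above): the three grid indices fill every dimension except dim, which each block scans with threadIdx.x.

// Hypothetical host-side illustration of the kernel's remapping of
// blockIdx.{x,y,z} onto the three non-cumsum dimensions.
static void map_block_to_tensor_indices(const int dim, const int64_t grid_indices[3], int64_t i_vals[4]) {
    int grid_idx = 0;
    for (int d = 0; d < 4; ++d) {
        i_vals[d] = (d == dim) ? 0 : grid_indices[grid_idx++];
    }
}
// Example: dim == 1 and block (3, 2, 1) yields i0 = 3, i1 = 0 (scanned), i2 = 2, i3 = 1.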
@@ -51,13 +75,16 @@ static __global__ void cumsum_kernel(
     __syncthreads();
 
     // Phase 2: Add the sum of all preceding warp groups to each element
-    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
-        const int shmem_idx = i0 / WARP_SIZE;
+    for (int64_t i_dim = tid; i_dim < ne_dim; i_dim += blockDim.x) {
+        const int64_t offset_dst = i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 + i_dim*nb_dim_dst;
+        T * dst_ptr = (T *) ((char *) dst + offset_dst);
+
+        const int shmem_idx = i_dim / WARP_SIZE;
         float sum = 0.0f;
         for (int j = 0; j < shmem_idx; ++j) {
             sum += shmem[j];
         }
-        dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum);
+        dst_ptr[0] = static_cast<T>(static_cast<float>(dst_ptr[0]) + sum);
     }
 }
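A short walk-through of the two phases, with illustrative numbers not taken from the commit:

// Illustrative walk-through, assuming WARP_SIZE == 32, a block size that is a
// multiple of WARP_SIZE, and ne_dim == 96:
//   phase 1: each aligned 32-element chunk of the scanned dimension is handled by
//            one warp in one loop iteration, so shmem[0], shmem[1], shmem[2] end up
//            holding the totals of elements 0..31, 32..63 and 64..95 respectively.
//   phase 2: the element at i_dim == 70 has shmem_idx == 2 and therefore adds
//            shmem[0] + shmem[1] on top of its intra-warp inclusive prefix.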

@@ -67,27 +94,44 @@ static void cumsum_cuda(
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
     const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
     const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+    const int dim,
     cudaStream_t stream) {
 
+    // Dimension being accumulated
+    const int64_t ne_dims[4] = {ne00, ne01, ne02, ne03};
+    const int64_t ne_dim = ne_dims[dim];
+
+    // Grid dimensions: the GGML_MAX_DIMS-1 non-cumsum dimensions
+    int64_t grid_dims_arr[GGML_MAX_DIMS - 1];
+    int grid_idx = 0;
+    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+        if (d != dim) {
+            grid_dims_arr[grid_idx++] = ne_dims[d];
+        }
+    }
+
     dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1);
-    dim3 grid_dims(ne01, ne02, ne03);
+    dim3 grid_dims(grid_dims_arr[0], grid_dims_arr[1], grid_dims_arr[2]);
 
     // Shared memory size: one float per warp
-    const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
+    const int num_warps = (ne_dim + WARP_SIZE - 1) / WARP_SIZE;
     const size_t shmem_size = num_warps * sizeof(float);
 
     cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
         src, dst,
         ne00, ne01, ne02, ne03,
         nb00, nb01, nb02, nb03,
-        nb0, nb1, nb2, nb3
+        nb0, nb1, nb2, nb3,
+        dim
     );
 }
 
 void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     cudaStream_t stream = ctx.stream();
 
+    const int dim = ggml_get_op_params_i32(dst, 0);
+
     GGML_ASSERT(src0->type == dst->type);
     switch(src0->type) {
         case GGML_TYPE_F32:
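As a concrete example of the resulting launch configuration (illustrative shape, not taken from the commit):

// Illustrative example, assuming WARP_SIZE == 32: cumsum along dim == 2 of a
// tensor with ne = {128, 16, 32, 2}:
//   grid_dims  = (128, 16, 2)                          // the non-cumsum dims 0, 1, 3
//   ne_dim     = 32                                    // extent of the scanned dim
//   num_warps  = (32 + WARP_SIZE - 1) / WARP_SIZE = 1
//   shmem_size = num_warps * sizeof(float) = 4 bytes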
@@ -97,6 +141,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                 src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                 src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                 dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                dim,
                 stream
             );
         } break;
@@ -107,6 +152,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                 src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                 src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                 dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                dim,
                 stream
             );
         } break;
@@ -117,6 +163,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                 src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                 src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                 dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                dim,
                 stream
             );
         } break;
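For checking the kernel's output against a reference, an equivalent strided CPU cumsum over an arbitrary dimension can be written as below (sketch only; the name, the float-only restriction, and the parameter layout are illustrative and not part of the commit):

// Illustrative float-only CPU reference: inclusive cumulative sum along `dim`
// of a 4-D tensor addressed with byte strides, matching the kernel's layout.
static void cumsum_ref_f32(const char * src, char * dst, const int dim,
                           const int64_t ne[4], const int64_t nb_src[4], const int64_t nb_dst[4]) {
    int64_t outer[4] = { ne[0], ne[1], ne[2], ne[3] };
    outer[dim] = 1; // the scanned dimension is walked by the innermost loop below
    for (int64_t i3 = 0; i3 < outer[3]; ++i3) {
        for (int64_t i2 = 0; i2 < outer[2]; ++i2) {
            for (int64_t i1 = 0; i1 < outer[1]; ++i1) {
                for (int64_t i0 = 0; i0 < outer[0]; ++i0) {
                    int64_t i[4] = { i0, i1, i2, i3 };
                    float acc = 0.0f;
                    for (int64_t k = 0; k < ne[dim]; ++k) {
                        i[dim] = k;
                        const int64_t off_s = i[0]*nb_src[0] + i[1]*nb_src[1] + i[2]*nb_src[2] + i[3]*nb_src[3];
                        const int64_t off_d = i[0]*nb_dst[0] + i[1]*nb_dst[1] + i[2]*nb_dst[2] + i[3]*nb_dst[3];
                        acc += *(const float *) (src + off_s);
                        *(float *) (dst + off_d) = acc;
                    }
                }
            }
        }
    }
}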
