
Commit 96fe9ba

Add support for CUMSUM and TRI for CUDA. (#17584)
* Add support for CUMSUM and TRI for CUDA.
* Minor optimizations.
* Correct warp_prefix_inclusive_sum in float2 variant to return float2
* Optimize TRI
* Whitespace
* Fix strides.
* Implement double loop
* Whitespace
* Fix HIP compilation bugs
* Optimizations + big case performance tests
* Implement using CUB with fallback to custom kernel
* Remove error message.
* Fixes from code review
* Comment out CPU-unsupported F16/BF16 cases to fix CI
* Fine, you win :P
* Fix last cast, use NO_DEVICE_CODE and GGML_UNUSED_VARS
* Vary warp-size based on physical warp size
* Add GGML_UNUSED_VARS in tri as well
* Use constexpr and call prefix_inclusive with warp_size template param
* Update ggml/src/ggml-cuda/cumsum.cu
  Co-authored-by: Johannes Gäßler <[email protected]>
* Apply suggestions from code review
  Co-authored-by: Johannes Gäßler <[email protected]>
* Change to tid % warp_size
* Fix strides; hardcode mask; add ggml_lane_mask_t
* Missing renames, remove unused get_warp_mask(), explicit calls to ggml_cuda_info()
* Too hasty...

---------

Co-authored-by: Johannes Gäßler <[email protected]>
1 parent bde188d commit 96fe9ba

File tree: 7 files changed, +448 −0 lines

ggml/src/ggml-cuda/common.cuh

Lines changed: 47 additions & 0 deletions
@@ -463,6 +463,53 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+template<typename T, int width = WARP_SIZE>
+static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
+    const int lane_id = threadIdx.x % width;
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const T t = __shfl_up_sync(0xffffffff, x, offset, width);
+        if (lane_id >= offset) {
+            x += t;
+        }
+    }
+    return x;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
+    const int lane_id = threadIdx.x % width;
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width);
+        const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width);
+        if (lane_id >= offset) {
+            a.x += t_x;
+            a.y += t_y;
+        }
+    }
+    return a;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+    const int lane_id = threadIdx.x % width;
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const half2 t = __shfl_up_sync(0xffffffff, a, offset, width);
+        if (lane_id >= offset) {
+            a = __hadd2(a, t);
+        }
+    }
+    return a;
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
+}
+
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
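All three overloads above implement the same Hillis-Steele warp scan: at step offset, each lane pulls the partial sum from the lane offset positions below it with __shfl_up_sync and adds it in, so after log2(width) steps lane i holds the sum of lanes 0..i. The standalone sketch below is not part of the commit (the demo kernel, the host code, and the hard-coded width of 32 are illustrative assumptions); it shows the same pattern in isolation, where lane i feeds in i + 1 and should read back the triangular number (i + 1) * (i + 2) / 2.

#include <cstdio>
#include <cuda_runtime.h>

// Demo only: a single warp performs an inclusive prefix sum over the values 1..32.
__global__ void demo_warp_prefix_inclusive_sum(float * out) {
    const int lane = threadIdx.x % 32;
    float x = float(lane + 1);
#pragma unroll
    for (int offset = 1; offset < 32; offset <<= 1) {
        const float t = __shfl_up_sync(0xffffffff, x, offset, 32);
        if (lane >= offset) { // lanes below 'offset' have no source lane, so they keep their value
            x += t;
        }
    }
    out[lane] = x;
}

int main() {
    float * d_out = nullptr;
    cudaMalloc(&d_out, 32 * sizeof(float));
    demo_warp_prefix_inclusive_sum<<<1, 32>>>(d_out);
    float h_out[32];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 32; ++i) {
        printf("lane %2d: %6.1f (expected %6.1f)\n", i, h_out[i], 0.5f * (i + 1) * (i + 2));
    }
    cudaFree(d_out);
    return 0;
}

The commit's versions differ only in making the width and element type template parameters, with the half2 overload guarded by FP16_AVAILABLE.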

ggml/src/ggml-cuda/cumsum.cu

Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
#include <algorithm>
#include "cumsum.cuh"
#include "convert.cuh"
#include "ggml-cuda/common.cuh"
#include "ggml.h"

#ifdef GGML_CUDA_USE_CUB
#   include <cub/device/device_scan.cuh>
#endif // GGML_CUDA_USE_CUB

template<typename T, int BLOCK_SIZE>
static __global__ void cumsum_cub_kernel(
        const T * __restrict__ src,
        T * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s1, const int64_t s2, const int64_t s3) {
#ifdef GGML_CUDA_USE_CUB
    using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;

    __shared__ typename BlockScan::TempStorage temp_storage;
    __shared__ T block_carry; // carry from previous tile

    const int tid = threadIdx.x;

    const int64_t i1 = blockIdx.x;
    const int64_t i2 = blockIdx.y;
    const int64_t i3 = blockIdx.z;

    if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
        return;
    }

    const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
    T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;

    if (tid == 0) {
        block_carry = 0;
    }
    __syncthreads();

    for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
        int64_t idx = start + tid;
        T x = (idx < ne00) ? src_row[idx] : T(0);

        T inclusive;
        T block_total;
        BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);

        __syncthreads();

        T final_val = inclusive + block_carry;

        // store result
        if (idx < ne00) {
            dst_row[idx] = final_val;
        }

        __syncthreads();

        if (tid == 0) {
            block_carry += block_total;
        }

        __syncthreads();
    }
#else
    NO_DEVICE_CODE;
#endif // GGML_CUDA_USE_CUB
}

// Fallback kernel implementation (original)
template<typename T>
static __global__ void cumsum_kernel(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) {

    GGML_UNUSED_VARS(s00, s0);

    const int tid = threadIdx.x;
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    const int lane = tid % warp_size;
    const int warp = tid / warp_size;
    const int warps_per_block = blockDim.x / warp_size;

    extern __shared__ float smem[];
    float * s_vals = smem;
    float * s_warp_sums = smem + blockDim.x;
    float * s_carry = smem + blockDim.x + warps_per_block;
    float * s_chunk_total = s_carry + 1;

    // Initialize carry
    if (tid == 0) {
        *s_carry = 0.0f;
    }
    __syncthreads();

    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;
    const int64_t i1 = blockIdx.x;
    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
        return;
    }

    const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
    T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;

    for (int64_t start = 0; start < ne00; start += blockDim.x) {
        int64_t idx = start + tid;
        float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;

        // 1. Warp inclusive scan
        val = warp_prefix_inclusive_sum<T, warp_size>(val);
        s_vals[tid] = val;

        // Store warp total
        if (lane == warp_size - 1) {
            s_warp_sums[warp] = val;
        }
        __syncthreads();

        // 2. Exclusive scan of warp sums (warp 0 only)
        if (warp == 0) {
            float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
            float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
            if (tid < warps_per_block) {
                s_warp_sums[tid] = inc - w; // exclusive sum
            }
            if (tid == warps_per_block - 1) {
                *s_chunk_total = inc; // total sum of this chunk
            }
        }
        __syncthreads();

        float carry = *s_carry;
        float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
        if (idx < ne00) {
            dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
        }
        __syncthreads();

        // Update carry for next chunk
        if (tid == 0) {
            *s_carry += *s_chunk_total;
        }
        __syncthreads();
    }
}

template<typename T>
static void cumsum_cuda(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
        cudaStream_t stream) {

    const size_t type_size = sizeof(T);
    bool use_cub = false;
#ifdef GGML_CUDA_USE_CUB
    // Check if we can use CUB (data must be contiguous along innermost dimension)
    const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);

    if (is_contiguous) {
        use_cub = true;
    }
#endif // GGML_CUDA_USE_CUB
    dim3 grid_dims(ne01, ne02, ne03);
    const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()];
    const int warp_size = info.warp_size;
    const int num_warps = (ne00 + warp_size - 1) / warp_size;
    int block_size = num_warps * warp_size;
    block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
    dim3 block_dims(block_size, 1, 1);
    const int warps_per_block = block_size / warp_size;
    const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);

    if (use_cub) {
        cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
            src, dst,
            ne00, ne01, ne02, ne03,
            nb01 / type_size, nb02 / type_size, nb03 / type_size,
            nb1 / type_size, nb2 / type_size, nb3 / type_size
        );
    } else {
        cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
            src, dst,
            ne00, ne01, ne02, ne03,
            nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
            nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
        );
    }
}

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == dst->type);
    switch(src0->type) {
        case GGML_TYPE_F32:
            {
                cumsum_cuda(
                    (const float *)src0->data, (float *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        // We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms
        /*case GGML_TYPE_F16:
            {
                cumsum_cuda(
                    (const half *)src0->data, (half *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        case GGML_TYPE_BF16:
            {
                cumsum_cuda(
                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;*/
        default:
            GGML_ABORT("fatal error");
    }
}
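The fallback kernel assembles each output from three pieces: the lane's warp-level inclusive sum, an exclusive scan of the per-warp totals, and a carry accumulated across chunks of blockDim.x elements. A plain CPU loop makes it easy to check that this decomposition reproduces an ordinary running sum; the sketch below is illustrative only (not part of the commit) and treats strides as element counts, the same way the kernels above receive them after the nb / type_size division.

#include <cstdint>

// Illustrative CPU reference for the row-wise cumulative sum computed above.
// Strides are in elements (like s00..s3 in the kernels), not bytes.
static void cumsum_ref_f32(
        const float * src, float * dst,
        int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
        int64_t s00, int64_t s01, int64_t s02, int64_t s03,
        int64_t s0,  int64_t s1,  int64_t s2,  int64_t s3) {
    for (int64_t i3 = 0; i3 < ne03; ++i3) {
        for (int64_t i2 = 0; i2 < ne02; ++i2) {
            for (int64_t i1 = 0; i1 < ne01; ++i1) {
                const float * src_row = src + i1*s01 + i2*s02 + i3*s03;
                float       * dst_row = dst + i1*s1  + i2*s2  + i3*s3;
                float acc = 0.0f;
                for (int64_t i0 = 0; i0 < ne00; ++i0) {
                    acc += src_row[i0*s00];   // walk the innermost dimension
                    dst_row[i0*s0] = acc;     // inclusive running sum
                }
            }
        }
    }
}

This is roughly the comparison that test-backend-ops automates by checking the CUDA backend's CUMSUM output against the CPU backend.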

ggml/src/ggml-cuda/cumsum.cuh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_CUMSUM_BLOCK_SIZE 256

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 10 additions & 0 deletions
@@ -54,6 +54,8 @@
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml-cuda/solve_tri.cuh"
+#include "ggml-cuda/tri.cuh"
+#include "ggml-cuda/cumsum.cuh"
 #include "ggml.h"
 
 #include <algorithm>
@@ -2701,6 +2703,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_CUMSUM:
+            ggml_cuda_op_cumsum(ctx, dst);
+            break;
+        case GGML_OP_TRI:
+            ggml_cuda_op_tri(ctx, dst);
+            break;
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
@@ -4609,6 +4617,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
+        case GGML_OP_CUMSUM:
+        case GGML_OP_TRI:
             return true;
         case GGML_OP_SOLVE_TRI:
             return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
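With the dispatch and supports_op entries above in place, a graph containing GGML_OP_CUMSUM (or GGML_OP_TRI) is routed to the new kernels whenever it runs on the CUDA backend. The sketch below is a hedged, illustrative example of exercising the cumsum path end to end through the public ggml/ggml-backend API; it assumes ggml_cumsum() is the graph-level constructor behind GGML_OP_CUMSUM, and the sizes and variable names are made up.

// Illustrative end-to-end use of the new CUDA cumsum path (not part of the commit).
#include <vector>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"

int main() {
    ggml_backend_t backend = ggml_backend_cuda_init(0); // CUDA device 0

    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // tensor data lives in the backend buffer allocated below
    };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * x   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    ggml_tensor * cum = ggml_cumsum(ctx, x); // assumed API; lowers to GGML_OP_CUMSUM -> ggml_cuda_op_cumsum

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, cum);

    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(galloc, gf);

    std::vector<float> ones(1024, 1.0f);
    ggml_backend_tensor_set(x, ones.data(), 0, ones.size() * sizeof(float));

    ggml_backend_graph_compute(backend, gf);

    std::vector<float> out(1024);
    ggml_backend_tensor_get(cum, out.data(), 0, out.size() * sizeof(float));
    // out[i] should now be i + 1.

    ggml_gallocr_free(galloc);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}

Whether the CUB tile scan or the custom fallback kernel runs is decided inside cumsum_cuda based on GGML_CUDA_USE_CUB and row contiguity; the caller never sees the difference.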
