Commit d138a03

Add support for CUMSUM and TRI for CUDA.
1 parent d82b7a7 commit d138a03

File tree

7 files changed: +306 −0 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 50 additions & 0 deletions
@@ -461,6 +461,56 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+template<typename T, int width = WARP_SIZE>
+static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
+    const int lane_id = threadIdx.x % width;
+    const auto mask = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const T t = __shfl_up_sync(mask, x, offset, width);
+        if (lane_id >= offset) {
+            x += t;
+        }
+    }
+    return x;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
+    const int lane_id = threadIdx.x % width;
+    const auto mask = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const float t_x = __shfl_up_sync(mask, a.x, offset, width);
+        const float t_y = __shfl_up_sync(mask, a.y, offset, width);
+        if (lane_id >= offset) {
+            a.x += t_x;
+            a.y += t_y;
+        }
+    }
+    return a;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+    const int lane_id = threadIdx.x % width;
+    const auto mask = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const half2 t = __shfl_up_sync(mask, a, offset, width);
+        if (lane_id >= offset) {
+            a = __hadd2(a, t);
+        }
+    }
+    return a;
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
+}
+
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
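For reference, the loop in these helpers is the classic Hillis-Steele inclusive scan: after the step with a given offset, each lane's running value covers up to twice as many trailing elements as before, so after log2(width) steps lane i holds the sum of lanes 0..i. Below is a minimal standalone harness exercising the same pattern; scan_demo and main are hypothetical illustrations for this note, not part of the commit.

// Standalone sketch of the Hillis-Steele warp scan used above (illustration only).
#include <cstdio>

#define WARP_SIZE 32

static __global__ void scan_demo(const float * in, float * out) {
    const unsigned mask = __activemask();
    const int lane_id = threadIdx.x % WARP_SIZE;
    float x = in[threadIdx.x];
#pragma unroll
    for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
        const float t = __shfl_up_sync(mask, x, offset, WARP_SIZE);
        if (lane_id >= offset) {
            x += t; // lane i now covers up to 2*offset trailing elements ending at i
        }
    }
    out[threadIdx.x] = x; // inclusive prefix sum within the warp
}

int main() {
    float h_in[WARP_SIZE], h_out[WARP_SIZE];
    for (int i = 0; i < WARP_SIZE; ++i) {
        h_in[i] = 1.0f; // all ones -> out[i] should be i + 1
    }
    float * d_in; float * d_out;
    cudaMalloc(&d_in,  sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    scan_demo<<<1, WARP_SIZE>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("%.0f %.0f %.0f\n", h_out[0], h_out[15], h_out[31]); // expect: 1 16 32
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}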

ggml/src/ggml-cuda/cumsum.cu

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+#include "cumsum.cuh"
+
+// Kernel to compute the cumulative sum along the innermost dimension (ne[0]).
+// Each block processes one row (ne00 elements).
+// The algorithm matches the Metal implementation:
+// 1. Each warp computes a prefix sum within itself
+// 2. The last thread of each warp stores its total in shared memory
+// 3. All warps sync
+// 4. Each element adds the totals of all preceding warp-sized chunks
+
+template<typename T>
+static __global__ void cumsum_kernel(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
+
+    // Shared memory to store warp sums (always use float for accumulation)
+    extern __shared__ float shmem[];
+
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
+    T * dst_row = (T *) ((char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+    const int tid     = threadIdx.x;
+    const int lane_id = tid % WARP_SIZE;
+
+    // Phase 1: each thread processes elements at stride blockDim.x,
+    // computing warp-level prefix sums
+    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
+        // Load the value and compute the prefix sum within the warp
+        float val = static_cast<float>(src_row[i0]);
+        val = warp_prefix_inclusive_sum(val);
+        dst_row[i0] = static_cast<T>(val);
+
+        // The last thread of the warp stores its chunk total in shared memory,
+        // at a position derived from the data index
+        if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
+            const int shmem_idx = i0 / WARP_SIZE;
+            shmem[shmem_idx] = val;
+        }
+    }
+
+    // Sync once after all warp prefix sums are computed
+    __syncthreads();
+
+    // Phase 2: add the totals of all preceding warp-sized chunks to each element
+    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
+        const int shmem_idx = i0 / WARP_SIZE;
+        float sum = 0.0f;
+        for (int j = 0; j < shmem_idx; ++j) {
+            sum += shmem[j];
+        }
+        dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum);
+    }
+}
+
+template<typename T>
+static void cumsum_cuda(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+        cudaStream_t stream) {
+
+    dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1);
+    dim3 grid_dims(ne01, ne02, ne03);
+
+    // Shared memory size: one float per warp-sized chunk of the row
+    const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
+    const size_t shmem_size = num_warps * sizeof(float);
+
+    cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
+        src, dst,
+        ne00, ne01, ne02, ne03,
+        nb00, nb01, nb02, nb03,
+        nb0, nb1, nb2, nb3
+    );
+}
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == dst->type);
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                cumsum_cuda(
+                    (const float *) src0->data, (float *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;
+        case GGML_TYPE_F16:
+            {
+                cumsum_cuda(
+                    (const half *) src0->data, (half *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                cumsum_cuda(
+                    (const nv_bfloat16 *) src0->data, (nv_bfloat16 *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
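The two phases compose as follows: after phase 1, dst_row[i0] holds the inclusive scan of i0's own WARP_SIZE-wide chunk, and shmem[c] holds the total of chunk c; phase 2 adds the totals of chunks 0..c-1 to each element of chunk c, which yields the global prefix sum. A plain CPU reference for the same semantics, useful for spot-checking, is sketched below; cumsum_ref is a hypothetical helper that assumes contiguous F32 rows, not part of the commit.

// CPU reference for the row-wise inclusive cumulative sum over the innermost
// dimension (hypothetical helper; assumes contiguous rows).
#include <cstdint>

static void cumsum_ref(const float * src, float * dst, int64_t ne00, int64_t nrows) {
    for (int64_t r = 0; r < nrows; ++r) {
        float acc = 0.0f;
        for (int64_t i0 = 0; i0 < ne00; ++i0) {
            acc += src[r*ne00 + i0];
            dst[r*ne00 + i0] = acc; // dst[i0] = src[0] + ... + src[i0] within the row
        }
    }
}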

ggml/src/ggml-cuda/cumsum.cuh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CUMSUM_BLOCK_SIZE 256
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 10 additions & 0 deletions
@@ -54,6 +54,8 @@
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml-cuda/solve_tri.cuh"
+#include "ggml-cuda/tri.cuh"
+#include "ggml-cuda/cumsum.cuh"
 #include "ggml.h"
 
 #include <algorithm>
@@ -2700,6 +2702,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_CUMSUM:
+            ggml_cuda_op_cumsum(ctx, dst);
+            break;
+        case GGML_OP_TRI:
+            ggml_cuda_op_tri(ctx, dst);
+            break;
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
@@ -4262,6 +4270,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
+        case GGML_OP_CUMSUM:
+        case GGML_OP_TRI:
             return true;
         case GGML_OP_SOLVE_TRI:
             return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
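With the dispatch and supports_op entries in place, the new ops are reachable through the normal backend path. The following is a hypothetical end-to-end sketch, assuming the public ggml/ggml-backend/ggml-cuda headers and that ggml_cumsum builds a GGML_OP_CUMSUM node; memory sizing is approximate and error handling is omitted.

// Hypothetical sketch: run GGML_OP_CUMSUM on the CUDA backend via a 1-op graph.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"
#include <cstdio>
#include <vector>

int main() {
    ggml_backend_t backend = ggml_backend_cuda_init(0 /*device*/);

    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // tensor data lives in the backend buffer
    };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_tensor * c = ggml_cumsum(ctx, a);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);

    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(galloc, gf);

    std::vector<float> in(8, 1.0f), out(8);
    ggml_backend_tensor_set(a, in.data(), 0, in.size()*sizeof(float));
    ggml_backend_graph_compute(backend, gf);
    ggml_backend_tensor_get(c, out.data(), 0, out.size()*sizeof(float));

    for (float v : out) {
        printf("%.0f ", v); // expect: 1 2 3 4 5 6 7 8
    }
    printf("\n");

    ggml_gallocr_free(galloc);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}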

ggml/src/ggml-cuda/tri.cu

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+#include "tri.cuh"
+#include "ggml.h"
+
+// Triangle type comparison - determines which elements to keep
+__device__ static inline bool tri_compare(const int i, const int r, const ggml_tri_type type) {
+    switch (type) {
+        case GGML_TRI_TYPE_LOWER:      return i <  r;
+        case GGML_TRI_TYPE_LOWER_DIAG: return i <= r;
+        case GGML_TRI_TYPE_UPPER:      return i >  r;
+        case GGML_TRI_TYPE_UPPER_DIAG: return i >= r;
+        default:                       return false;
+    }
+}
+
+template<typename T>
+static __global__ void tri_kernel(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+        const ggml_tri_type ttype) {
+
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
+    T * dst_row = (T *) ((char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+    // Each thread processes elements at stride blockDim.x
+    for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
+        dst_row[i0] = tri_compare(i0, i1, ttype)
+            ? src_row[i0] : static_cast<T>(0.f);
+    }
+}
+
+template<typename T>
+static void tri_cuda(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+        const ggml_tri_type ttype,
+        cudaStream_t stream) {
+
+    dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
+    dim3 grid_dims(ne01, ne02, ne03);
+
+    tri_kernel<<<grid_dims, block_dims, 0, stream>>>(
+        src, dst,
+        ne00, ne01, ne02, ne03,
+        nb00, nb01, nb02, nb03,
+        nb0, nb1, nb2, nb3,
+        ttype
+    );
+}
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
+
+    GGML_ASSERT(src0->type == dst->type);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                tri_cuda(
+                    (const float *) src0->data, (float *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        case GGML_TYPE_F16:
+            {
+                tri_cuda(
+                    (const half *) src0->data, (half *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                tri_cuda(
+                    (const nv_bfloat16 *) src0->data, (nv_bfloat16 *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
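The masking rule in tri_compare is easiest to read against a plain CPU reference: element (row r, column i) is kept when the type-dependent predicate holds, otherwise it is zeroed. The sketch below is a hypothetical helper for illustration; the local enum merely mirrors the four ggml_tri_type cases and assumes contiguous rows.

// Hypothetical CPU reference for the triangular masking applied by tri_kernel.
#include <cstdint>

enum tri_type { TRI_LOWER, TRI_LOWER_DIAG, TRI_UPPER, TRI_UPPER_DIAG }; // mirrors ggml_tri_type

static void tri_ref(const float * src, float * dst, int64_t ne00, int64_t ne01, tri_type t) {
    for (int64_t r = 0; r < ne01; ++r) {
        for (int64_t i = 0; i < ne00; ++i) {
            bool keep = false;
            switch (t) {
                case TRI_LOWER:      keep = i <  r; break; // strictly below the diagonal
                case TRI_LOWER_DIAG: keep = i <= r; break; // below or on the diagonal
                case TRI_UPPER:      keep = i >  r; break; // strictly above the diagonal
                case TRI_UPPER_DIAG: keep = i >= r; break; // above or on the diagonal
            }
            dst[r*ne00 + i] = keep ? src[r*ne00 + i] : 0.0f;
        }
    }
}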

ggml/src/ggml-cuda/tri.cuh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_TRI_BLOCK_SIZE 256
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

tests/test-backend-ops.cpp

Lines changed: 6 additions & 0 deletions
@@ -7938,6 +7938,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
 
+    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
+    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
+
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));
+
     for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
