Commit 614dee0

Introduction of CUDA Programmatic Dependent Launch to Llama.cpp
See #15479
1 parent 5aa1105 commit 614dee0
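
Programmatic Dependent Launch (PDL) lets a dependent kernel begin executing before the kernel it depends on has fully completed: the secondary grid launches early, does work that does not touch the primary's output, and calls `cudaGridDependencySynchronize()` before reading it, while the primary calls `cudaTriggerProgrammaticLaunchCompletion()` to release the dependent launch as soon as its results are ready. The device-side changes in the diffs below only take effect when the host launches the secondary kernel with the programmatic stream serialization attribute. A minimal sketch of such a launch, assuming CUDA 11.8+ and an NVIDIA sm_90 target (the kernel, launch dimensions, and helper name are illustrative, not part of this commit):

```cuda
#include <cuda_runtime.h>

// Hypothetical secondary kernel; the real kernels in this commit guard their
// reads with cudaGridDependencySynchronize() as shown in the diffs below.
__global__ void secondary_kernel(float * dst) {
    cudaGridDependencySynchronize();   // wait until the primary's writes are visible
    dst[threadIdx.x] += 1.0f;
}

// Launch the secondary with PDL enabled so it can overlap the tail of the
// primary kernel running earlier on the same stream.
static void launch_secondary_with_pdl(float * dst, cudaStream_t stream) {
    cudaLaunchConfig_t cfg = {};
    cfg.gridDim  = dim3(1);
    cfg.blockDim = dim3(32);
    cfg.stream   = stream;

    cudaLaunchAttribute attr = {};
    attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr.val.programmaticStreamSerializationAllowed = 1;
    cfg.attrs    = &attr;
    cfg.numAttrs = 1;

    cudaLaunchKernelEx(&cfg, secondary_kernel, dst);
}
```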


47 files changed: +668 / -0 lines

ggml/src/ggml-cuda/acc.cu

Lines changed: 6 additions & 0 deletions
```diff
@@ -3,6 +3,9 @@
 static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
     const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
     const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
@@ -25,6 +28,9 @@ static __global__ void acc_f32(const float * x, const float * y, float * dst, co
         val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
     }
     dst[i] = val;
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements,
```
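
The same two-line prologue and epilogue recur in every kernel this commit touches; distilled into a schematic (the element-wise body is illustrative, not from the commit):

```cuda
// Schematic of the per-kernel change. The guard keeps HIP builds and
// pre-Hopper architectures compiling the kernel unchanged.
__global__ void pdl_wrapped_kernel(const float * in, float * out, int n) {
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
    // Entry: block until the preceding (primary) grid's writes are visible,
    // so reads of `in` are safe even though this grid may have launched early.
    cudaGridDependencySynchronize();
#endif
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < n) {
        out[i] = in[i] * 2.0f;
    }
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
    // Exit: release a PDL-launched successor; the launch-completion event
    // fires once every block has called this or exited.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}
```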

ggml/src/ggml-cuda/arange.cu

Lines changed: 6 additions & 0 deletions
```diff
@@ -1,12 +1,18 @@
 #include "arange.cuh"
 
 static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     // blockIDx.x: idx of ne0 / BLOCK_SIZE
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
     }
     dst[nidx] = start + step * nidx;
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
```

ggml/src/ggml-cuda/argmax.cu

Lines changed: 6 additions & 0 deletions
```diff
@@ -6,6 +6,9 @@
 #include "sum.cuh"
 
 static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     const int64_t row = blockIdx.x;
 
     float maxval = -FLT_MAX;
@@ -64,6 +67,9 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
     if (warp_id == 0 && lane_id == 0) {
         dst[row] = argmax;
     }
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
```

ggml/src/ggml-cuda/argsort.cu

Lines changed: 6 additions & 0 deletions
```diff
@@ -9,6 +9,9 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
 
 template<ggml_sort_order order>
 static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     // bitonic sort
     int col = threadIdx.x;
     int row = blockIdx.y;
@@ -55,6 +58,9 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
     if (col < ncols) {
         dst[row * ncols + col] = dst_row[col];
     }
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 static int next_power_of_2(int x) {
```

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 18 additions & 0 deletions
```diff
@@ -29,6 +29,9 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
         /*int s0, */ int s1, int s2, int s3,
         /*int s00,*/ int s01, int s02, int s03,
         /*int s10,*/ int s11, int s12, int s13) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
     const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
     const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
@@ -54,6 +57,9 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
         const int i10 = i0 % ne10;
         dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
     }
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
@@ -63,6 +69,9 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
         /*int s0, */ int s1, int s2, int s3,
         /*int s00,*/ int s01, int s02, int s03,
         /*int s10,*/ int s11, int s12, int s13) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
 
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -89,13 +98,19 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
 
     const int i10 = i0 % ne10;
     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 template <typename T>
 static __global__ void k_repeat_back(
     const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
     const size_t s00, const size_t s01, const size_t s02, const size_t s03,
     const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
 
     const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
     const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
@@ -118,6 +133,9 @@ static __global__ void k_repeat_back(
         }
     }
     dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 template<float (*bin_op)(const float, const float)>
```

ggml/src/ggml-cuda/clamp.cu

Lines changed: 6 additions & 0 deletions
```diff
@@ -6,13 +6,19 @@ static __device__ __forceinline__ float op_clamp(float x, float min, float max)
 
 template <class T>
 static __global__ void op_clamp_kernel(const T * x, T * dst, const T min, const T max, const int k) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
     dst[i] = (T)op_clamp((float)x[i], (float)min, (float)max);
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 template <class T>
```

ggml/src/ggml-cuda/common.cuh

Lines changed: 13 additions & 0 deletions
```diff
@@ -47,6 +47,7 @@
 #define GGML_CUDA_CC_TURING 750
 #define GGML_CUDA_CC_AMPERE 800
 #define GGML_CUDA_CC_ADA_LOVELACE 890
+#define GGML_CUDA_CC_HOPPER 900
 #define GGML_CUDA_CC_OFFSET_AMD 0x1000000
 #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
 #define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
@@ -414,6 +415,9 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 // Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
 template<bool norm>
 static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     const int row = blockIdx.x;
     const int col = threadIdx.x;
 
@@ -425,10 +429,16 @@ static __global__ void reduce_rows_f32(const float * x, float * dst, const int n
     sum = warp_reduce_sum(sum);
 
     if (col != 0) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+        cudaTriggerProgrammaticLaunchCompletion();
+#endif
         return;
     }
 
     dst[row] = norm ? sum / ncols : sum;
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 template<int width = WARP_SIZE>
@@ -832,6 +842,9 @@ struct ggml_cuda_graph {
     // Index to allow each cpy kernel to be aware of it's position within the graph
     // relative to other cpy nodes.
     int graph_cpynode_index = -1;
+    std::vector<cudaGraphNode_t> graph_nodes;
+    std::vector<cudaGraphNode_t> graph_dependencies;
+    bool allow_pdl = true; // whether Programmatic Dependent Launch can be used within CUDA graph
 #endif
 };
```
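
The new graph_nodes and graph_dependencies vectors and the allow_pdl flag suggest PDL is also threaded through CUDA graph capture. The graph-side wiring itself is not shown in this excerpt; one plausible mechanism, available since CUDA 12.3, is to connect consecutive kernel nodes with programmatic edges rather than full completion dependencies. A hedged sketch (the helper name is hypothetical):

```cuda
#include <cuda_runtime.h>

// Hedged sketch: replace a full completion dependency between two kernel
// nodes with a programmatic (PDL) edge. Assumes CUDA 12.3+; this is an
// illustration of the API, not the commit's actual graph code.
static cudaError_t add_pdl_edge(cudaGraph_t graph, cudaGraphNode_t from, cudaGraphNode_t to) {
    cudaGraphEdgeData edge = {};
    edge.type      = cudaGraphDependencyTypeProgrammatic;       // release `to` at the trigger
    edge.from_port = cudaGraphKernelNodePortProgrammatic;       // i.e. cudaTriggerProgrammaticLaunchCompletion()
    return cudaGraphAddDependencies_v2(graph, &from, &to, &edge, 1);
}
```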

ggml/src/ggml-cuda/concat.cu

Lines changed: 24 additions & 0 deletions
```diff
@@ -2,6 +2,9 @@
 
 // contiguous kernels
 static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -25,9 +28,15 @@ static __global__ void concat_f32_dim0(const float * x, const float * y, float *
             blockIdx.z * (ne0 - ne00) * gridDim.y;
         dst[offset_dst] = y[offset_src];
     }
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -51,9 +60,15 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float *
             blockIdx.z * ne0 * (gridDim.y - ne01);
         dst[offset_dst] = y[offset_src];
     }
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -77,6 +92,9 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float *
             (blockIdx.z - ne02) * ne0 * gridDim.y;
         dst[offset_dst] = y[offset_src];
     }
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
@@ -124,6 +142,9 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
         uint64_t nb1,
         uint64_t nb2,
         uint64_t nb3){
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]");
 
     const int64_t i3 = blockIdx.z;
@@ -151,6 +172,9 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
 
         *y = *x;
     }
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
```

ggml/src/ggml-cuda/conv-transpose-1d.cu

Lines changed: 6 additions & 0 deletions
```diff
@@ -6,6 +6,9 @@ static __global__ void conv_transpose_1d_kernel(
         const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
         const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
         const float * src0, const float * src1, float * dst) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     int global_index = threadIdx.x + blockIdx.x * blockDim.x;
     if (global_index >= output_size) {
         return;
@@ -38,6 +41,9 @@ static __global__ void conv_transpose_1d_kernel(
     GGML_UNUSED(src1_ne3); GGML_UNUSED(dst_ne3);
     GGML_UNUSED(src1_ne1); GGML_UNUSED(dst_ne1);
     GGML_UNUSED(src1_ne2); GGML_UNUSED(dst_ne2);
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 static void conv_transpose_1d_f32_f32_cuda(
```

ggml/src/ggml-cuda/conv2d-dw.cu

Lines changed: 9 additions & 0 deletions
```diff
@@ -84,10 +84,16 @@ __global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restr
         const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
         const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
         const int channels, const int batches) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaGridDependencySynchronize();
+#endif
     const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     const int total_elements = batches * channels * out_h * out_w;
 
     if (global_idx >= total_elements) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+        cudaTriggerProgrammaticLaunchCompletion();
+#endif
         return;
     }
 
@@ -114,6 +120,9 @@ __global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restr
     }
 
     output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
 
 void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
```
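
Note where conv2d_dw_kernel differs from the simpler kernels: out-of-range threads call cudaTriggerProgrammaticLaunchCompletion() before their early return, as reduce_rows_f32 in common.cuh also does. The launch-completion event is released only once every block has either triggered or exited, so signalling on the early-exit path lets a PDL-launched successor start as soon as possible instead of waiting for the implicit trigger at kernel exit.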
