
Commit 30e3080

Reapply "cuda : add softcap fusion (ggml-org#14907)"
This reverts commit 9e3cec2.
1 parent ceeddd4 commit 30e3080

File tree: 3 files changed, +82 −6 lines


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 43 additions & 6 deletions
@@ -37,6 +37,7 @@ bool g_mul_mat_q = true;
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/roll.cuh"
 #include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
 #include "ggml-cuda/ssm-scan.cuh"
@@ -3029,7 +3030,12 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
+#ifndef NDEBUG
+    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+    GGML_ASSERT(unary_ops.size() == num_unary);
+#endif
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -3057,9 +3063,32 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
             return false;
         }
+
+        return true;
     }
 
-    return true;
+    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+        && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+        const ggml_tensor *scale  = cgraph->nodes[node_idx];
+        const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
+        const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+        if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
+            return false;
+        }
+
+        // Check for bias
+        if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+            return false;
+        }
+
+        return true;
+    }
+
+    return false;
 }
 
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
@@ -3080,10 +3109,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }
 
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
-        if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-            i++;
-            continue;
+        if (!disable_fusion) {
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+                ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                i++;
+                continue;
+            }
+
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+                i += 2;
+                ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
+                continue;
+            }
         }
 #ifndef NDEBUG
         assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
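
For context (not part of this commit): the three fused ops implement the softcap transform y = softcap * tanh(scale * x). Below is a minimal sketch of the subgraph the new check matches, assuming the usual public ggml graph-building calls (ggml_scale, ggml_tanh) with their standard signatures and zero scale bias:

// sketch only: builds the SCALE -> UNARY(TANH) -> SCALE chain that
// ggml_cuda_can_fuse() recognizes above; both scales must keep a zero bias.
static struct ggml_tensor * build_softcap(struct ggml_context * ctx, struct ggml_tensor * x,
                                          float scale, float softcap) {
    struct ggml_tensor * t = ggml_scale(ctx, x, scale); // GGML_OP_SCALE
    t = ggml_tanh(ctx, t);                              // GGML_OP_UNARY (GGML_UNARY_OP_TANH)
    return ggml_scale(ctx, t, softcap);                 // GGML_OP_SCALE
}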

ggml/src/ggml-cuda/softcap.cu

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+#include "softcap.cuh"
+
+static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = tanhf(scale * x[i]) * softcap;
+}
+
+static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
+    softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
+}
+
+// fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src) {
+    const ggml_tensor * src0 = src->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float scale;
+    float softcap;
+    memcpy(&scale,   (float *) src->op_params + 0, sizeof(float));
+    memcpy(&softcap, (float *) dst->op_params + 0, sizeof(float));
+
+    softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);
+}
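
For reference only (not part of the commit), a plain host-side sketch of the same math; it mirrors softcap_f32 above, so results copied back from the device can be compared element-wise against it:

#include <math.h>

// CPU reference: dst[i] = tanhf(scale * x[i]) * softcap, matching the kernel above
static void softcap_f32_ref(const float * x, float * dst, const float scale, const float softcap, const int k) {
    for (int i = 0; i < k; ++i) {
        dst[i] = tanhf(scale * x[i]) * softcap;
    }
}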

ggml/src/ggml-cuda/softcap.cuh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_SOFTCAP_BLOCK_SIZE 256
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src);
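
The block size defined here drives the grid-size arithmetic in softcap_f32_cuda; a worked example (numbers purely illustrative, not from the commit):

// with CUDA_SOFTCAP_BLOCK_SIZE = 256 and k = 1000 elements:
//   num_blocks = (1000 + 256 - 1) / 256 = 4
// so 4 * 256 = 1024 threads are launched and the last 24 exit early
// through the `i >= k` guard in softcap_f32.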
