Commit 656584f

add softcap fusion
1 parent bf78f54 commit 656584f

File tree

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/softcap.cu
ggml/src/ggml-cuda/softcap.cuh
tests/test-backend-ops.cpp

4 files changed: +140 -23 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 66 additions & 23 deletions
@@ -32,6 +32,7 @@
 #include "ggml-cuda/quantize.cuh"
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
 #include "ggml-cuda/ssm-scan.cuh"
@@ -2766,34 +2767,59 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
 
-    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
-        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
-        const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+    switch (ops.size()) {
+        case 2:
+            if (ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+                const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+                const ggml_tensor *mul = cgraph->nodes[node_idx+1];
 
-        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
-        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+                GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+                GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
 
-        //rms norm only supports F32
-        if (mul->src[0]->type != GGML_TYPE_F32 ||
-            mul->src[1]->type != GGML_TYPE_F32 ||
-            mul->type != GGML_TYPE_F32) {
-            return false;
-        }
+                //rms norm only supports F32
+                if (mul->src[0]->type != GGML_TYPE_F32 ||
+                    mul->src[1]->type != GGML_TYPE_F32 ||
+                    mul->type != GGML_TYPE_F32) {
+                    return false;
+                }
 
-        //if rms norm is the B operand, then we don't handle broadcast
-        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
-            return false;
-        }
+                //if rms norm is the B operand, then we don't handle broadcast
+                if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+                    return false;
+                }
+
+                //rms_norm kernel assumes contigous rows
+                if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+                    return false;
+                }
+            }
+            break;
+        case 3:
+            if (ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+                && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+                const ggml_tensor *scale  = cgraph->nodes[node_idx];
+                const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
+                const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+                GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+                GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+                if (tanh->src[0] != scale || scale2->src[0] != tanh) {
+                    return false;
+                }
 
-        //rms_norm kernel assumes contigous rows
-        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+                if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+                    return false;
+                }
+            }
+            break;
+        default:
             return false;
-        }
     }
 
     return true;
@@ -2817,10 +2843,27 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }
 
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
-        if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-            i++;
-            continue;
+        if (!disable_fusion) {
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+                ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                i++;
+                continue;
+            }
+
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+                ggml_tensor * src0 = node->src[0];
+                float scale = ggml_get_op_params_f32(node, 0);
+
+                i += 2; node = cgraph->nodes[i];
+                float softcap = ggml_get_op_params_f32(node, 0);
+
+                ggml_set_op_params_f32(node, 0, scale);
+                ggml_set_op_params_f32(node, 1, softcap);
+                node->src[0] = src0;
+
+                ggml_cuda_op_softcap(*cuda_ctx, node);
+                continue;
+            }
         }
 #ifndef NDEBUG
         assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
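
In short, the hunks above extend ggml_cuda_can_fuse() to also recognize a three-node SCALE -> UNARY(TANH) -> SCALE chain (with op_params[1] of both scale nodes equal to 0, i.e. no scale bias), and evaluate_and_capture_cuda_graph() then rewrites the final SCALE node so its op_params carry both the pre-scale and the cap before calling ggml_cuda_op_softcap() once instead of launching three separate kernels. A minimal sketch of the kind of ggml subgraph this is meant to catch; the build_softcap helper name and the cap value are illustrative only, the actual test graph is in tests/test-backend-ops.cpp below:

    // sketch only: scale -> tanh -> scale, both scales with zero bias,
    // which is the shape the new fusion check looks for
    static ggml_tensor * build_softcap(ggml_context * ctx, ggml_tensor * logits, float cap) {
        ggml_tensor * t = ggml_scale(ctx, logits, 1.0f / cap); // GGML_OP_SCALE
        t = ggml_tanh(ctx, t);                                 // GGML_OP_UNARY (GGML_UNARY_OP_TANH)
        return ggml_scale(ctx, t, cap);                        // GGML_OP_SCALE
    }

The fused node ends up computing dst[i] = softcap * tanhf(scale * x[i]) in a single launch.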

ggml/src/ggml-cuda/softcap.cu

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+#include "softcap.cuh"
+
+static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = tanhf(scale * x[i]) * softcap;
+}
+
+static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
+    softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
+}
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float scale;
+    float softcap;
+    memcpy(&scale,   (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&softcap, (float *) dst->op_params + 1, sizeof(float));
+
+    softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);
+}
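
As a self-contained check of the kernel's math, here is a standalone CUDA sketch (not part of the commit, no ggml headers; softcap_check.cu, softcap_kernel and BLOCK_SIZE are illustrative names) that runs the same elementwise expression and compares it against a double-precision CPU reference:

    // softcap_check.cu -- build with: nvcc -o softcap_check softcap_check.cu
    #include <cmath>
    #include <cstdio>
    #include <vector>

    #define BLOCK_SIZE 256

    // same elementwise math as softcap_f32 above
    static __global__ void softcap_kernel(const float * x, float * dst, const float scale, const float softcap, const int k) {
        const int i = blockDim.x*blockIdx.x + threadIdx.x;
        if (i >= k) {
            return;
        }
        dst[i] = tanhf(scale * x[i]) * softcap;
    }

    int main() {
        const int   k     = 1024;
        const float cap   = 50.0f;      // same cap as the test-backend-ops case below
        const float scale = 1.0f / cap; // factor applied by the first GGML_OP_SCALE

        std::vector<float> x(k), y(k);
        for (int i = 0; i < k; ++i) {
            x[i] = 0.25f * (i - k/2);   // inputs spanning well past +/- cap
        }

        float * dx = nullptr;
        float * dy = nullptr;
        cudaMalloc(&dx, k*sizeof(float));
        cudaMalloc(&dy, k*sizeof(float));
        cudaMemcpy(dx, x.data(), k*sizeof(float), cudaMemcpyHostToDevice);

        const int num_blocks = (k + BLOCK_SIZE - 1) / BLOCK_SIZE;
        softcap_kernel<<<num_blocks, BLOCK_SIZE>>>(dx, dy, scale, cap, k);
        cudaMemcpy(y.data(), dy, k*sizeof(float), cudaMemcpyDeviceToHost);

        // CPU reference: out = cap * tanh(x / cap); large inputs saturate near +/- cap
        double max_err = 0.0;
        for (int i = 0; i < k; ++i) {
            const double ref = (double) cap * std::tanh((double) x[i] / cap);
            max_err = std::fmax(max_err, std::fabs((double) y[i] - ref));
        }
        std::printf("max abs error vs CPU reference: %g\n", max_err);

        cudaFree(dx);
        cudaFree(dy);
        return 0;
    }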

ggml/src/ggml-cuda/softcap.cuh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_SOFTCAP_BLOCK_SIZE 256
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

tests/test-backend-ops.cpp

Lines changed: 36 additions & 0 deletions
@@ -2514,6 +2514,41 @@ struct test_scale : public test_case {
     }
 };
 
+// GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
+struct test_softcap : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    float softcap;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "SOFTCAP";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, softcap);
+    }
+
+    test_softcap(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            float softcap = 30.0f)
+        : type(type), ne(ne), softcap(softcap) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_scale(ctx, ggml_tanh(ctx, ggml_scale(ctx, a, 1.0f / softcap)), softcap);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_SILU_BACK
 struct test_silu_back : public test_case {
     const ggml_type type;
@@ -5390,6 +5425,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
+    test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
     test_cases.emplace_back(new test_silu_back());
 
     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
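
test_softcap builds the unfused scale/tanh/scale graph with plain ggml_scale and ggml_tanh calls; run_whole_graph() returning true hands the whole graph to the backend, so the CUDA backend is free to apply the new fusion while the results are still compared against the reference backend. As a rough numeric sanity check (not from the commit): with softcap = 50 the chain computes 50 * tanh(x / 50), so x = 10 passes through almost unchanged (about 9.87), x = 100 saturates to 50 * tanh(2), about 48.2, and no output can exceed +/- 50.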
