Skip to content

Commit 78ac06b

Browse files
committed
Change softplus to ggml_softplus
1 parent 66f37c0 commit 78ac06b

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8637,7 +8637,7 @@ static void ggml_compute_forward_ssm_scan_f32(
86378637
// n_head
86388638
for (int h = ih0; h < ih1; ++h) {
86398639
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8640-
const float dt_soft_plus = softplus(dt[h]);
8640+
const float dt_soft_plus = ggml_softplus(dt[h]);
86418641
const float dA = expf(dt_soft_plus * A[h]);
86428642
const int g = h / (nh / ng); // repeat_interleave
86438643

@@ -8734,7 +8734,7 @@ static void ggml_compute_forward_ssm_scan_f32(
87348734
// n_head
87358735
for (int h = ih0; h < ih1; ++h) {
87368736
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8737-
const float dt_soft_plus = softplus(dt[h]);
8737+
const float dt_soft_plus = ggml_softplus(dt[h]);
87388738
const int g = h / (nh / ng); // repeat_interleave
87398739

87408740
// dim

ggml/src/ggml-impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ static bool ggml_op_is_empty(enum ggml_op op) {
102102
}
103103
}
104104

105-
static inline float softplus(float input) {
105+
static inline float ggml_softplus(float input) {
106106
return (input > 20.0f) ? input : logf(1 + expf(input));
107107
}
108108
//

ggml/src/ggml.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2664,8 +2664,8 @@ struct ggml_tensor * ggml_xielu(
26642664
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
26652665

26662666
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
2667-
ggml_set_op_params_f32(result, 1, beta + softplus(alpha_n));
2668-
ggml_set_op_params_f32(result, 2, softplus(alpha_p));
2667+
ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n));
2668+
ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p));
26692669
ggml_set_op_params_f32(result, 3, beta);
26702670
ggml_set_op_params_f32(result, 4, eps);
26712671

0 commit comments

Comments
 (0)