Commit 40f2f80

Refactor: remove unused methods, inline and factorize softplus, add const modifiers
1 parent 13cc3be commit 40f2f80

6 files changed (+16, -36 lines)

ggml/src/ggml-cpu/ops.cpp

Lines changed: 2 additions & 2 deletions

@@ -9407,7 +9407,7 @@ static void ggml_compute_forward_ssm_scan_f32(
         // n_head
         for (int h = ih0; h < ih1; ++h) {
             // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-            const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+            const float dt_soft_plus = softplus(dt[h]);
             const float dA = expf(dt_soft_plus * A[h]);
             const int g = h / (nh / ng); // repeat_interleave

@@ -9504,7 +9504,7 @@ static void ggml_compute_forward_ssm_scan_f32(
         // n_head
         for (int h = ih0; h < ih1; ++h) {
             // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-            const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+            const float dt_soft_plus = softplus(dt[h]);
             const int g = h / (nh / ng); // repeat_interleave

             // dim
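
A quick standalone sketch (not part of the commit) of why both the removed inline expression and the shared helper fall back to the identity above 20: for x >= 20 the exact softplus log(1 + e^x) differs from x by log1p(e^-x), about 2e-9, which is below single-precision resolution at that magnitude, while a naive expf(x) overflows float to +inf once x exceeds roughly 88.7.

// Illustration only, compiled as plain C; not code from this commit.
#include <math.h>
#include <stdio.h>

int main(void) {
    const float x = 50.0f;
    const float naive   = logf(1.0f + expf(x));                   // still representable at x = 50
    const float guarded = (x > 20.0f) ? x : logf(1.0f + expf(x)); // identity branch taken
    printf("naive = %f, guarded = %f\n", naive, guarded);         // both print ~50.0
    printf("expf(100.0f) = %f\n", expf(100.0f));                  // inf: the overflow the guard avoids
    return 0;
}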

ggml/src/ggml-cpu/unary-ops.cpp

Lines changed: 0 additions & 5 deletions

@@ -194,11 +194,6 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
     unary_op(op_log, params, dst);
 }

-static float softplus(float input, float beta=1.0f, float threshold=20.0f) {
-    if (input * beta > threshold) return input;
-    return (1/beta) * logf(1 + expf(beta * input));
-}
-
 void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
     // Get the XIELU parameters from the operation
     float alpha_n = ggml_get_op_params_f32(dst, 1);

ggml/src/ggml-cuda/unary.cu

Lines changed: 4 additions & 6 deletions

@@ -211,30 +211,28 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 }

-// Functor for XIELU operation with parameters
 struct op_xielu_functor {
     float alpha_n, alpha_p, beta, eps;

     __host__ __device__ __forceinline__ op_xielu_functor(float a_n, float a_p, float b, float e)
         : alpha_n(a_n), alpha_p(a_p), beta(b), eps(e) {}

     __device__ __forceinline__ float operator()(float x) const {
-        float gate_pos = (x > 0.0f); // positive branch gate
+        const float gate_pos = (x > 0.0f); // positive branch gate

         // Positive branch: alpha_p * v^2 + beta * v
-        float y_pos = alpha_p * x * x + beta * x;
+        const float y_pos = alpha_p * x * x + beta * x;

         // Negative branch:
-        float min_v_eps = fminf(x, eps); // works fine even if eps < 0
-        float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x;
+        const float min_v_eps = fminf(x, eps); // works fine even if eps < 0
+        const float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x;

         // Select the appropriate branch based on the gate
         return gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
     }
 };

 // swiglu_oai
-
 template <typename T>
 static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
     const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
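
The functor above evaluates both branches and blends them with a 0/1 gate rather than taking a data-dependent branch on x. A minimal host-side sketch of that select pattern (plain C stand-in, not the CUDA code from this commit):

// Illustration only: the gate-select idiom used in op_xielu_functor.
#include <stdio.h>

static float select_branchless(float x, float y_pos, float y_neg) {
    const float gate_pos = (x > 0.0f);                     // comparison yields 0 or 1, converted to float
    return gate_pos * y_pos + (1.0f - gate_pos) * y_neg;   // same value as (x > 0.0f ? y_pos : y_neg)
}

int main(void) {
    printf("%f\n", select_branchless( 2.0f, 1.0f, -1.0f)); // prints 1.000000
    printf("%f\n", select_branchless(-2.0f, 1.0f, -1.0f)); // prints -1.000000
    return 0;
}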

ggml/src/ggml-impl.h

Lines changed: 4 additions & 0 deletions

@@ -89,6 +89,10 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
     return true;
 }

+static inline float softplus(float input) {
+    return (input > 20.0f) ? input : logf(1 + expf(input));
+}
+
 //
 // logging
 //
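
With softplus now defined once here, the per-file copies (removed above in ggml-cpu/unary-ops.cpp and below in ggml.c) become redundant. A standalone check, with both definitions copied from this diff (the C++ default arguments beta=1.0f, threshold=20.0f are passed explicitly in this C sketch), that the removed beta/threshold variant reduces to the new single-argument helper:

// Illustration only; not code from this commit.
#include <math.h>
#include <stdio.h>

// removed variant (formerly in ggml-cpu/unary-ops.cpp), default arguments dropped for C
static float softplus_old(float input, float beta, float threshold) {
    if (input * beta > threshold) return input;
    return (1/beta) * logf(1 + expf(beta * input));
}

// new shared helper (as added to ggml-impl.h)
static inline float softplus(float input) {
    return (input > 20.0f) ? input : logf(1 + expf(input));
}

int main(void) {
    const float xs[] = { -30.0f, -1.0f, 0.0f, 1.0f, 19.9f, 20.1f, 100.0f };
    for (int i = 0; i < 7; ++i) {
        printf("x = %8.2f  old = %.7g  new = %.7g\n",
               xs[i], softplus_old(xs[i], 1.0f, 20.0f), softplus(xs[i]));
    }
    return 0;
}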

ggml/src/ggml.c

Lines changed: 6 additions & 10 deletions

@@ -2646,12 +2646,14 @@ struct ggml_tensor * ggml_silu(
     return ggml_unary(ctx, a, GGML_UNARY_OP_SILU);
 }

-// ggml_xielu
-static float softplus(float input) {
-    if (input > 20.0f) return input;
-    return logf(1 + expf(input));
+struct ggml_tensor * ggml_silu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
 }

+// ggml_xielu
+
 struct ggml_tensor * ggml_xielu(
         struct ggml_context * ctx,
         struct ggml_tensor * a,

@@ -2673,12 +2675,6 @@ struct ggml_tensor * ggml_xielu(
     return result;
 }

-struct ggml_tensor * ggml_silu_inplace(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
-}
-
 // ggml_silu_back

 struct ggml_tensor * ggml_silu_back(

src/llama-model.cpp

Lines changed: 0 additions & 13 deletions

@@ -18852,19 +18852,6 @@ struct llm_build_smallthinker : public llm_graph_context{
     }
 };

-// TODO: maybe put this as a general helper in ggml.c?
-static float get_scalar_f32_val(const ggml_tensor *t) {
-    float onef;
-    if (t->buffer) {
-        ggml_backend_tensor_get(t, &onef, 0, sizeof(float));
-    } else {
-        GGML_ASSERT(t->data);
-        onef = *((float *) t->data);
-    }
-    return onef;
-}
-
-// Apertus model graph builder with xIELU activation
 struct llm_build_apertus : public llm_graph_context {
     llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
