@@ -211,30 +211,28 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 }
 
-// Functor for XIELU operation with parameters
 struct op_xielu_functor {
     float alpha_n, alpha_p, beta, eps;
 
     __host__ __device__ __forceinline__ op_xielu_functor(float a_n, float a_p, float b, float e)
         : alpha_n(a_n), alpha_p(a_p), beta(b), eps(e) {}
 
     __device__ __forceinline__ float operator()(float x) const {
-        float gate_pos = (x > 0.0f); // positive branch gate
+        const float gate_pos = (x > 0.0f); // positive branch gate
 
         // Positive branch: alpha_p * v^2 + beta * v
-        float y_pos = alpha_p * x * x + beta * x;
+        const float y_pos = alpha_p * x * x + beta * x;
 
         // Negative branch:
-        float min_v_eps = fminf(x, eps); // works fine even if eps < 0
-        float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x;
+        const float min_v_eps = fminf(x, eps); // works fine even if eps < 0
+        const float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x;
 
         // Select the appropriate branch based on the gate
         return gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
     }
 };
 
 // swiglu_oai
-
 template <typename T>
 static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
     const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
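For a quick sanity check of the functor's math, here is a minimal host-side sketch (not part of this patch; all names and parameter values are illustrative). It repeats the same XIELU formula with an explicit branch instead of the branchless gate, so the CUDA output can be compared against a CPU reference on a few values.

```cpp
// Hypothetical host-side reference -- not part of this patch.
// Mirrors op_xielu_functor: alpha_p*x^2 + beta*x for x > 0,
// alpha_n*(expm1(min(x, eps)) - x) + beta*x otherwise.
#include <math.h>
#include <stdio.h>

static float xielu_ref(float x, float alpha_n, float alpha_p, float beta, float eps) {
    if (x > 0.0f) {
        // positive branch
        return alpha_p * x * x + beta * x;
    }
    // negative branch (min with eps works even if eps < 0)
    const float min_v_eps = fminf(x, eps);
    return (expm1f(min_v_eps) - x) * alpha_n + beta * x;
}

int main() {
    // example parameter values only -- the real ones come from the model
    const float alpha_n = 0.8f, alpha_p = 0.8f, beta = 0.5f, eps = -1e-6f;
    const float xs[] = { -2.0f, -0.5f, 0.0f, 0.5f, 2.0f };
    for (float x : xs) {
        printf("xielu(%6.2f) = %.6f\n", x, xielu_ref(x, alpha_n, alpha_p, beta, eps));
    }
    return 0;
}
```

Since gate_pos evaluates to exactly 0.0f or 1.0f, the branchless blend in the functor and the explicit if/else above give the same result for finite inputs.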