@@ -375,28 +375,6 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
 }
 
-/* xIELU */
-struct op_xielu_functor {
-    float alpha_n, alpha_p, beta, eps;
-
-    __host__ __device__ __forceinline__ op_xielu_functor(float a_n, float a_p, float b, float e)
-        : alpha_n(a_n), alpha_p(a_p), beta(b), eps(e) {}
-
-    __device__ __forceinline__ float operator()(float x) const {
-        const float gate_pos = (x > 0.0f); // positive branch gate
-
-        // Positive branch: alpha_p * v^2 + beta * v
-        const float y_pos = alpha_p * x * x + beta * x;
-
-        // Negative branch:
-        const float min_v_eps = fminf(x, eps); // works fine even if eps < 0
-        const float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x;
-
-        // Select the appropriate branch based on the gate
-        return gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
-    }
-};
-
 /* CUDA kernel + launcher for xIELU */
 
 template <typename T>
@@ -407,7 +385,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp
         return;
     }
 
-    const float xi = x->type == GGML_TYPE_F32 ? (float) x[i] : __half2float(x[i]);
+    const float xi = sizeof(x[i]) == sizeof(half) ? __half2float(x[i]) : (float) x[i];
     const float gate_pos = (xi > 0.0f);
 
     const float y_pos = alpha_p * xi * xi + beta * xi;
@@ -417,7 +395,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp
 
     const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
 
-    dst[i] = (T)(dst->type == GGML_TYPE_F32 ? out : __float2half(out));
+    dst[i] = (T)(sizeof(dst[i]) == sizeof(float) ? out : __float2half(out));
 }
 
 template <typename T>
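
For reference (not part of this patch), the per-element xIELU math the kernel evaluates can be sketched as a plain host-side C++ helper; the name xielu_ref below is illustrative only and does not exist in the codebase:

#include <algorithm>
#include <cmath>

// Reference xIELU, mirroring the kernel's two branches:
//   x > 0  : alpha_p * x^2 + beta * x
//   x <= 0 : alpha_n * (expm1(min(x, eps)) - x) + beta * x
static float xielu_ref(float x, float alpha_n, float alpha_p, float beta, float eps) {
    if (x > 0.0f) {
        return alpha_p * x * x + beta * x;
    }
    const float min_x_eps = std::min(x, eps); // clamp before expm1, as in the kernel
    return alpha_n * (std::expm1(min_x_eps) - x) + beta * x;
}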