@@ -86,30 +86,6 @@ static __device__ __forceinline__ float op_elu(float x) {
8686 return (x > 0 .f ) ? x : expm1f (x);
8787}
8888
89- static __device__ __forceinline__ float op_reglu (float x) {
90- return fmaxf (x, 0 );
91- }
92-
93- static __device__ __forceinline__ float op_geglu (float x) {
94- const float GELU_COEF_A = 0 .044715f ;
95- const float SQRT_2_OVER_PI = 0 .79788456080286535587989211986876f ;
96- return 0 .5f *x*(1 .0f + tanhf (SQRT_2_OVER_PI*x*(1 .0f + GELU_COEF_A*x*x)));
97- }
98-
99- static __device__ __forceinline__ float op_swiglu (float x) {
100- return x / (1 .0f + expf (-x));
101- }
102-
103- static __device__ __forceinline__ float op_geglu_erf (float x) {
104- const float SQRT_2_INV = 0 .70710678118654752440084436210484f ;
105- return 0 .5f *x*(1 .0f + erff (x*SQRT_2_INV));
106- }
107-
108- static __device__ __forceinline__ float op_geglu_quick (float x) {
109- const float GELU_QUICK_COEF = -1 .702f ;
110- return x * (1 .0f / (1 .0f + expf (GELU_QUICK_COEF * x)));
111- }
112-
11389// Special operation functions (stay in original locations)
11490static __device__ __forceinline__ float op_silu_back (float grad, float x) {
11591 const float s = 1 .0f / (1 .0f + expf (-x));
@@ -407,26 +383,28 @@ void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
407383 ggml_cuda_op_unary (ctx, dst, op_elu);
408384}
409385
386+ // GLU
410387void ggml_cuda_op_reglu (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
411- ggml_cuda_op_unary_gated (ctx, dst, op_reglu );
388+ ggml_cuda_op_unary_gated (ctx, dst, op_relu );
412389}
413390
414391void ggml_cuda_op_geglu (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
415- ggml_cuda_op_unary_gated (ctx, dst, op_geglu );
392+ ggml_cuda_op_unary_gated (ctx, dst, op_gelu );
416393}
417394
418395void ggml_cuda_op_swiglu (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
419- ggml_cuda_op_unary_gated (ctx, dst, op_swiglu );
396+ ggml_cuda_op_unary_gated (ctx, dst, op_silu );
420397}
421398
422399void ggml_cuda_op_geglu_erf (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
423- ggml_cuda_op_unary_gated (ctx, dst, op_geglu_erf );
400+ ggml_cuda_op_unary_gated (ctx, dst, op_gelu_erf );
424401}
425402
426403void ggml_cuda_op_geglu_quick (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
427- ggml_cuda_op_unary_gated (ctx, dst, op_geglu_quick );
404+ ggml_cuda_op_unary_gated (ctx, dst, op_gelu_quick );
428405}
429406
407+ // xIELU
430408void ggml_cuda_op_xielu (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
431409 // Get the XIELU parameters from the operation
432410 const float * op_params = (const float *)dst->op_params ;
0 commit comments