Skip to content

Commit dc1e4d5

Browse files
committed
CUDA variants, attempt 2
1 parent db9eb29 commit dc1e4d5

File tree

1 file changed

+2
-24
lines changed

1 file changed

+2
-24
lines changed

ggml/src/ggml-cuda/unary.cu

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -375,28 +375,6 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
375375
swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
376376
}
377377

378-
/* xIELU */
379-
struct op_xielu_functor {
380-
float alpha_n, alpha_p, beta, eps;
381-
382-
__host__ __device__ __forceinline__ op_xielu_functor(float a_n, float a_p, float b, float e)
383-
: alpha_n(a_n), alpha_p(a_p), beta(b), eps(e) {}
384-
385-
__device__ __forceinline__ float operator()(float x) const {
386-
const float gate_pos = (x > 0.0f); // positive branch gate
387-
388-
// Positive branch: alpha_p * v^2 + beta * v
389-
const float y_pos = alpha_p * x * x + beta * x;
390-
391-
// Negative branch:
392-
const float min_v_eps = fminf(x, eps); // works fine even if eps < 0
393-
const float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x;
394-
395-
// Select the appropriate branch based on the gate
396-
return gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
397-
}
398-
};
399-
400378
/* CUDA kernel + launcher for xIELU */
401379

402380
template <typename T>
@@ -407,7 +385,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp
407385
return;
408386
}
409387

410-
const float xi = x->type == GGML_TYPE_F32 ? (float) x[i] : __half2float(x[i]);
388+
const float xi = sizeof(x[i]) == sizeof(half) ? __half2float(x[i]) : (float) x[i];
411389
const float gate_pos = (xi > 0.0f);
412390

413391
const float y_pos = alpha_p * xi * xi + beta * xi;
@@ -417,7 +395,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp
417395

418396
const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
419397

420-
dst[i] = (T) (dst->type == GGML_TYPE_F32 ? out : __float2half(out));
398+
dst[i] = (T) (sizeof(dst[i]) == sizeof(float)) ? out : __float2half(out));
421399
}
422400

423401
template <typename T>

0 commit comments

Comments
 (0)