Skip to content

Commit db9eb29

Browse files
committed
Add float2half / half2float for F16 inputs/outputs
1 parent fc58d3b commit db9eb29

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ggml/src/ggml-cuda/unary.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp
407407
return;
408408
}
409409

410-
const float xi = (float)x[i];
410+
const float xi = x->type == GGML_TYPE_F32 ? (float) x[i] : __half2float(x[i]);
411411
const float gate_pos = (xi > 0.0f);
412412

413413
const float y_pos = alpha_p * xi * xi + beta * xi;
@@ -417,7 +417,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp
417417

418418
const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
419419

420-
dst[i] = (T)out;
420+
dst[i] = (T) (dst->type == GGML_TYPE_F32 ? out : __float2half(out));
421421
}
422422

423423
template <typename T>

0 commit comments

Comments
 (0)