Commit 65765bb

fix: match gelu with paper implementation
1 parent 871b1ed commit 65765bb
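
For reference, the "paper implementation" the commit message refers to appears to be the tanh approximation of GELU from Hendrycks & Gimpel (2016), which is what the added code computes:

$$\mathrm{GELU}(x) \approx \tfrac{1}{2}\, x \left(1 + \tanh\!\left(\sqrt{2/\pi}\,\left(x + 0.044715\, x^{3}\right)\right)\right)$$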

File tree

1 file changed: +19, -0 lines changed


ext/ReactantNNlibExt/Implementations.jl

Lines changed: 19 additions & 0 deletions
@@ -6,6 +6,25 @@ for (jlop, hloop) in (
     @eval $(jlop)(x::TracedRNumber) = Ops.$(hloop)(x)
 end
 
+# See https://github.com/EnzymeAD/Reactant.jl/issues/1420
+# Without this we will never fuse the gelu into gemm
+if isdefined(NNlib, :gelu_tanh)
+    function NNlib.gelu_tanh(x::TracedRNumber)
+        α = NNlib.oftf(x, 0.044715)
+        half = NNlib.oftf(x, 0.5)
+        λ = sqrt(NNlib.oftf(x, 2 / pi))
+        return x * half * (1 + tanh(λ * (x + α * x^3)))
+    end
+else
+    # Older versions of NNlib do not have gelu_tanh (gelu refers to the tanh version)
+    function NNlib.gelu(x::TracedRNumber)
+        α = NNlib.oftf(x, 0.044715)
+        half = NNlib.oftf(x, 0.5)
+        λ = sqrt(NNlib.oftf(x, 2 / pi))
+        return x * half * (1 + tanh(λ * (x + α * x^3)))
+    end
+end
+
 function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
     x = T.(Reactant.materialize_traced_array(x))
     max_ = maximum(x; dims)
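
For context, a minimal usage sketch (not part of the commit) of how these overloads come into play: broadcasting NNlib.gelu_tanh over Reactant traced arrays hits the scalar TracedRNumber method above, so the compiler sees the tanh formula directly instead of an opaque activation call. The dense-layer function and array shapes below are illustrative assumptions, and the example assumes an NNlib recent enough to define gelu_tanh (on older versions use NNlib.gelu, matching the else branch).

using Reactant, NNlib

# Illustrative dense layer: matmul + bias + gelu_tanh activation.
W = Reactant.to_rarray(rand(Float32, 32, 64))
b = Reactant.to_rarray(rand(Float32, 32))
x = Reactant.to_rarray(rand(Float32, 64, 16))

# gelu_tanh is applied elementwise; on traced inputs it now lowers to the
# tanh formula above, which the compiler can fuse with the matmul.
dense(W, b, x) = NNlib.gelu_tanh.(W * x .+ b)

compiled = Reactant.@compile dense(W, b, x)
compiled(W, b, x)

Whether the fusion actually happens is up to the XLA pipeline; per the in-code comment, the point of the overload is that without it the gelu would never fuse into the gemm.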
