Skip to content

Commit 9f04163

Browse files
committed
fix: NNlib activations handling made generic
1 parent c9ca274 commit 9f04163

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@ for (jlop, hloop) in (
1919
end
2020
end
2121

22-
NNlib.relu(x::TracedRArray{T,0}) where {T} = max(x, zero(T))
23-
24-
function NNlib.gelu(x::TracedRArray{T,0}) where {T}
25-
α = T(0.044715)
26-
λλ = T((8 / π))
27-
return x * sigmoid(λλ * x * muladd(x^2, α, one(T)))
22+
# Don't confuse our poor scalar arrays, we no like numbers we like 0D arrays
23+
for nnlib_op in setdiff(Tuple(NNlib.ACTIVATIONS), (:tanh_fast, :sigmoid_fast, :sigmoid, ))
24+
@eval function NNlib.$(nnlib_op)(x::TracedRArray{T,0}) where {T}
25+
return invoke(NNlib.$(nnlib_op), Tuple{Any}, x)
26+
end
2827
end
2928

3029
# TODO handle non finite cases

0 commit comments

Comments
 (0)