Skip to content

Commit aad1cfe

Browse files
committed
fix: temporarily remove conditional in softmax
1 parent 137fb2d commit aad1cfe

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

ext/ReactantNNlibExt/Implementations.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) whe
1010
x = T.(Reactant.materialize_traced_array(x))
1111
max_ = maximum(x; dims)
1212
diff = exp.(x .- max_)
13-
@trace if all(isfinite, max_)
14-
@. out = diff
15-
else
16-
@. out = ifelse(isinf(max_), ifelse(isinf(x), T(1), T(0)), diff)
17-
end
13+
# TOOD: re-enable conditional once https://github.com/EnzymeAD/Reactant.jl/issues/1581
14+
# fixed
15+
# @trace if all(isfinite, max_)
16+
@. out = diff
17+
# else
18+
# @. out = ifelse(isinf(max_), ifelse(isinf(x), T(1), T(0)), diff)
19+
# end
1820
out ./= sum(out; dims)
1921
return out
2022
end
@@ -23,11 +25,13 @@ function NNlib.logsoftmax!(out::AnyTracedRArray{T}, x::AbstractArray; dims=1) wh
2325
x = T.(Reactant.materialize_traced_array(x))
2426
max_ = maximum(x; dims)
2527
diff = x .- max_
26-
@trace if all(isfinite, max_)
27-
@. out = diff
28-
else
29-
@. out = ifelse(isinf(max_), ifelse(isinf(x), T(0), -T(Inf)), diff)
30-
end
28+
# TOOD: re-enable conditional once https://github.com/EnzymeAD/Reactant.jl/issues/1581
29+
# fixed
30+
# @trace if all(isfinite, max_)
31+
@. out = diff
32+
# else
33+
# @. out = ifelse(isinf(max_), ifelse(isinf(x), T(0), -T(Inf)), diff)
34+
# end
3135
out .-= log.(sum(exp, out; dims))
3236
return out
3337
end

0 commit comments

Comments
 (0)