@@ -10,11 +10,13 @@ function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) whe
10
10
x = T .(Reactant. materialize_traced_array (x))
11
11
max_ = maximum (x; dims)
12
12
diff = exp .(x .- max_)
13
- @trace if all (isfinite, max_)
14
- @. out = diff
15
- else
16
- @. out = ifelse (isinf (max_), ifelse (isinf (x), T (1 ), T (0 )), diff)
17
- end
13
+ # TOOD: re-enable conditional once https://github.com/EnzymeAD/Reactant.jl/issues/1581
14
+ # fixed
15
+ # @trace if all(isfinite, max_)
16
+ @. out = diff
17
+ # else
18
+ # @. out = ifelse(isinf(max_), ifelse(isinf(x), T(1), T(0)), diff)
19
+ # end
18
20
out ./= sum (out; dims)
19
21
return out
20
22
end
@@ -23,11 +25,13 @@ function NNlib.logsoftmax!(out::AnyTracedRArray{T}, x::AbstractArray; dims=1) wh
23
25
x = T .(Reactant. materialize_traced_array (x))
24
26
max_ = maximum (x; dims)
25
27
diff = x .- max_
26
- @trace if all (isfinite, max_)
27
- @. out = diff
28
- else
29
- @. out = ifelse (isinf (max_), ifelse (isinf (x), T (0 ), - T (Inf )), diff)
30
- end
28
+ # TOOD: re-enable conditional once https://github.com/EnzymeAD/Reactant.jl/issues/1581
29
+ # fixed
30
+ # @trace if all(isfinite, max_)
31
+ @. out = diff
32
+ # else
33
+ # @. out = ifelse(isinf(max_), ifelse(isinf(x), T(0), -T(Inf)), diff)
34
+ # end
31
35
out .- = log .(sum (exp, out; dims))
32
36
return out
33
37
end
0 commit comments