+
 """
     softmax(x; dims = 1)

@@ -33,45 +34,63 @@ julia> softmax([1 2 3; 2 2 2]; dims=2)
  0.0900306  0.244728  0.665241
  0.333333   0.333333  0.333333
 ```
+
+Note that, when used with Flux.jl, `softmax` must not be passed to layers like `Dense`
+which accept an activation function. The activation is broadcasted over the result,
+thus applies to individual numbers. But `softmax` always needs to see the whole column.
+
+```julia
+julia> using Flux
+
+julia> x = randn(Float32, 4, 4, 3, 13);
+
+julia> model = Chain(Conv((4, 4), 3 => 8, tanh), Flux.flatten, Dense(8 => 7), softmax);
+
+julia> model(x) |> size
+(7, 13)
+
+julia> Dense(4 => 7, softmax)(x)
+ERROR: `softmax(x)` called with a number, but it expects an array.
+```
 """
-softmax(x; dims = 1) = softmax!(similar(x, (float ∘ eltype)(x)), x; dims = dims)
+softmax(x::AbstractArray{T}; dims = 1) where {T} = softmax!(similar(x, float(T)), x; dims)

-softmax!(x; dims = 1) = softmax!(x, x; dims = dims)
+softmax!(x::AbstractArray; dims = 1) = softmax!(x, x; dims)

 function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
-    max_ = maximum(x; dims = dims)
+    max_ = maximum(x; dims)
     if all(isfinite, max_)
-        out .= exp.(x .- max_)
+        @fastmath out .= exp.(x .- max_)
     else
-        @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
+        @fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
     end
-    out ./= sum(out; dims = dims)  # could re-use max_ when dims != (:) and eltype(x) == T.
+    out ./= sum(out; dims)
 end

-∇softmax(Δ::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S} =
-    ∇softmax!(similar(y, promote_type(T, S)), Δ, x, y; dims = dims)
-∇softmax(Δ, x, y; dims = 1) = ∇softmax(unthunk(Δ), x, y, dims = dims)
-
-# Can introduce at the end of deprecation cycle of ∇softmax!(out, Δ, x; dims = 1)
-# ∇softmax!(Δ, x, y; dims = 1) = ∇softmax!(Δ, Δ, x, y; dims = dims)
-
-function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
-                   x::AbstractArray, y::AbstractArray; dims = 1)
-    out .= Δ .* y
-    out .= out .- y .* sum(out; dims = dims)
+function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
+    dx = if within_grad()
+        tmp = dy .* y
+        tmp .- y .* sum(tmp; dims)
+    else
+        # This path is faster, only safe for 1st derivatives though.
+        # Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads,
+        # but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30
+        out = similar(y, promote_type(T, S))
+        out .= dy .* y
+        out .= out .- y .* sum(out; dims)
+    end
 end

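A side note for review: both branches above compute the same vector-Jacobian product, `dx = y .* (dy .- sum(dy .* y; dims))`. Below is a minimal finite-difference sanity check of that formula, not part of the PR; it assumes these NNlib definitions are loaded (`∇softmax_data` is internal, hence the qualified name).

```julia
using NNlib, LinearAlgebra

x, dy = randn(5, 3), randn(5, 3)
y  = softmax(x; dims = 1)
dx = NNlib.∇softmax_data(dy, y; dims = 1)     # the pullback defined above

# Directional check: dot(dx, v) should match dy' * (J * v) from a central difference.
v, h = randn(5, 3), 1e-6
jvp = (softmax(x .+ h .* v; dims = 1) .- softmax(x .- h .* v; dims = 1)) ./ (2h)
isapprox(dot(dx, v), dot(dy, jvp); rtol = 1e-6)   # true, up to finite-difference error
```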
-# Old 2-arg version recomputing forward
-∇softmax(Δ, x; dims = 1) = ∇softmax(Δ, x, softmax(x, dims = dims); dims = dims)
-∇softmax!(Δ, x; dims = 1) = ∇softmax!(Δ, Δ, x, softmax(x, dims = dims); dims = dims)
-∇softmax!(out, Δ, x; dims = 1) = ∇softmax!(out, Δ, x, softmax(x, dims = dims); dims = dims)
-
-function rrule(::typeof(softmax), xs; dims = 1)
-    y = softmax(xs; dims = dims)
-    softmax_pullback(Δ) = (NoTangent(), ∇softmax(unthunk(Δ), xs, y, dims = dims))
+function rrule(::typeof(softmax), x; dims = 1)
+    y = softmax(x; dims)
+    softmax_pullback(dy) = (NoTangent(), ∇softmax_data(unthunk(dy), y; dims))
     return y, softmax_pullback
 end

+within_grad() = false
+rrule(::typeof(within_grad)) = true, _ -> (NoTangent(),)
+
+
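The `within_grad` pair above is the switch that decides which branch `∇softmax_data` takes: a plain call returns `false` (fast, mutating path), while a ChainRules-aware AD such as Zygote replaces the call with its `rrule` and sees `true`, selecting the non-mutating path that stays differentiable for second derivatives. A minimal sketch of that behaviour, assuming Zygote picks up the rrule as usual (not part of the PR):

```julia
using NNlib, Zygote

NNlib.within_grad()    # false when called directly: ∇softmax_data uses the fast branch

# Inside differentiated code the rrule applies, so within_grad() evaluates to true
# and this toy closure is differentiated along its first branch:
Zygote.gradient(x -> NNlib.within_grad() ? x^2 : x^3, 1.0)   # (2.0,) rather than (3.0,)
```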
 """
     logsoftmax(x; dims = 1)

@@ -85,52 +104,52 @@ It is semantically equivalent to the following:

 See also [`softmax`](@ref).
 """
-logsoftmax(x; dims = 1) = logsoftmax!(similar(x, (float ∘ eltype)(x)), x; dims = dims)
+logsoftmax(x::AbstractArray{T}; dims = 1) where {T} = logsoftmax!(similar(x, float(T)), x; dims)

-logsoftmax!(x; dims = 1) = logsoftmax!(x, x; dims = dims)
+logsoftmax!(x::AbstractArray; dims = 1) = logsoftmax!(x, x; dims)

 function logsoftmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
-    max_ = maximum(x; dims = dims)
+    max_ = maximum(x; dims)
     if all(isfinite, max_)
         out .= x .- max_
     else
         @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 0, -Inf), x - max_)
     end
-    log_ = log.(sum(exp, out; dims = dims))
+    @fastmath log_ = log.(sum(exp, out; dims))
     out .-= log_
 end

-∇logsoftmax(Δ::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S} =
-    ∇logsoftmax!(similar(y, promote_type(T, S)), Δ, x, y; dims = dims)
-∇logsoftmax(Δ, x, y; dims = 1) = ∇logsoftmax(unthunk(Δ), x, y, dims = dims)
-
-# Old 2-arg version recomputing forward
-∇logsoftmax(Δ, x; dims = 1) = ∇logsoftmax(Δ, x, logsoftmax(x, dims = dims); dims = dims)
-∇logsoftmax!(Δ, x; dims = 1) = ∇logsoftmax!(Δ, Δ, x, logsoftmax(x, dims = dims); dims = dims)
-∇logsoftmax!(out, Δ, x; dims = 1) = ∇logsoftmax!(out, Δ, x, logsoftmax(x, dims = dims); dims = dims)
-
-function ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,
-                      x::AbstractArray, y::AbstractArray; dims = 1)
-    out .= Δ .- sum(Δ, dims = dims) .* exp.(y)
+function ∇logsoftmax_data(dy::AbstractArray, y::AbstractArray; dims = 1)
+    # This was previously `∇logsoftmax!(dx, dy, x, y; dims)` to allow CUDA overloads, but that was slow.
+    dx = dy .- sum(dy; dims) .* exp.(y)
 end
-
-function rrule(::typeof(logsoftmax), xs; dims = 1)
-    y = logsoftmax(xs; dims = dims)
-    logsoftmax_pullback(Δ) = (NoTangent(), ∇logsoftmax(unthunk(Δ), xs, y, dims = dims))
+
+function rrule(::typeof(logsoftmax), x; dims = 1)
+    y = logsoftmax(x; dims)
+    logsoftmax_pullback(dy) = (NoTangent(), ∇logsoftmax_data(unthunk(dy), y; dims))
     return y, logsoftmax_pullback
 end

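One remark on the simplification above: since `y = logsoftmax(x)`, the term `exp.(y)` is exactly `softmax(x)`, which is why the new pullback needs only `y` and can drop the `x` argument of the old `∇logsoftmax!`. A quick illustration, assuming the surrounding definitions (`∇logsoftmax_data` is internal), not part of the PR:

```julia
using NNlib

x, dy = randn(4, 2), randn(4, 2)
y = logsoftmax(x; dims = 1)

exp.(y) ≈ softmax(x; dims = 1)                       # true
NNlib.∇logsoftmax_data(dy, y; dims = 1) ≈
    dy .- sum(dy; dims = 1) .* softmax(x; dims = 1)  # true: same pullback, no x needed
```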
 """
     logsumexp(x; dims = :)

-Computes `log.(sum(exp.(x); dims = dims))` in a numerically stable
-way.
+Computes `log.(sum(exp.(x); dims))` in a numerically stable way.
+Without `dims` keyword this returns a scalar.

 See also [`logsoftmax`](@ref).
 """
 function logsumexp(x::AbstractArray; dims = :)
-    max_ = maximum(x; dims = dims)
-    max_ .+ log.(sum(exp.(x .- max_); dims = dims))
+    max_ = maximum(x; dims)
+    @fastmath max_ .+ log.(sum(exp.(x .- max_); dims))
+end
+
+function rrule(::typeof(logsumexp), x; dims = :)
+    # The gradient is `softmax`, but both compute `tmp` so it's worth saving.
+    max_ = maximum(x; dims)
+    @fastmath tmp = exp.(x .- max_)
+    @fastmath y = max_ .+ log.(sum(tmp; dims))
+    logsumexp_pullback(dy) = (NoTangent(), unthunk(dy) .* tmp ./ sum(tmp; dims))
+    return y, logsumexp_pullback
 end

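Since `tmp ./ sum(tmp; dims)` in the pullback above is just `softmax(x; dims)`, this rrule encodes that the gradient of `logsumexp` is `softmax`. A short check of that identity, assuming Zygote and not part of the PR:

```julia
using NNlib, Zygote

x = randn(3, 5)
g = Zygote.gradient(x -> sum(logsumexp(x; dims = 1)), x)[1]
g ≈ softmax(x; dims = 1)   # true: each column's gradient is that column's softmax
```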
# Informative error message if any of the softmax variants is called with a number