Commit 0c8396e

Simplify softmax, test second derivatives (#393)
* simplify softmax, test second derivatives
* add a note about Flux to docstring
* add Tracker to downstream tests
* missing semicolon
* remove x arguments, rename
* move exports
* change the notation
* tidy, add x::AbstractArray
* add fastmath
* also logsumexp
* version, and trigger CI
* upgrade ci
1 parent 172549c commit 0c8396e

7 files changed (+163, -75 lines)

.buildkite/pipeline.yml

Lines changed: 13 additions & 14 deletions

@@ -24,20 +24,19 @@ steps:
       NNLIB_TEST_CUDA: true
     timeout_in_minutes: 60

-  ## Add these when julia 1.7 is out
-  # - label: "GPU julia v1"
-  #   plugins:
-  #     - JuliaCI/julia#v1:
-  #         version: "1"
-  #     - JuliaCI/julia-test#v1: ~
-  #     - JuliaCI/julia-coverage#v1:
-  #         codecov: true
-  #         dirs:
-  #           - src
-  #   agents:
-  #     queue: "juliagpu"
-  #     cuda: "*"
-  #   timeout_in_minutes: 60
+  - label: "GPU julia v1"
+    plugins:
+      - JuliaCI/julia#v1:
+          version: "1"
+      - JuliaCI/julia-test#v1: ~
+      - JuliaCI/julia-coverage#v1:
+          codecov: true
+          dirs:
+            - src
+    agents:
+      queue: "juliagpu"
+      cuda: "*"
+    timeout_in_minutes: 60

   # - label: "GPU julia nightly"
   #   plugins:

.github/workflows/Downstream.yml

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@ jobs:
         os: [ubuntu-latest]
         package:
           - {user: FluxML, repo: Flux.jl, group: All}
+          - {user: FluxML, repo: Tracker.jl, group: All}
           - {user: denizyuret, repo: Knet.jl, group: All}
           - {user: dfdx, repo: Avalon.jl, group: All}
           - {user: JuliaOptimalTransport, repo: OptimalTransport.jl, group: All}

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@ jobs:
       fail-fast: false
       matrix:
         version:
+          - '1.6'
          - '1' # automatically expands to the latest stable 1.x release of Julia
          - 'nightly'
        os:

Project.toml

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 name = "NNlib"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.8.2"
+version = "0.8.3"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

 [compat]
 Adapt = "2, 3.2"
-ChainRulesCore = "0.9.45, 0.10, 1"
+ChainRulesCore = "1.13"
 Compat = "3.14"
 Requires = "0.5, 1.0"
 julia = "1.6"

src/deprecations.jl

Lines changed: 56 additions & 1 deletion

@@ -1 +1,56 @@
-### v0.8 Deprecations
+
+### Deprecated while v0.7 was latest
+
+function ∇softmax(Δ, x; dims = 1)
+    # This 2-arg version recomputes the forward pass, which is slow.
+    # Removed from use in 0.7, but only prints a warning during 0.8:
+    Base.depwarn("`∇softmax(Δ, x)` without `y = softmax(x)` argument is deprecated, as this is inefficient, please use `∇softmax_data(dy, y)`", :∇softmax)
+    ∇softmax(Δ, x, softmax(x; dims); dims)
+end
+∇softmax!(Δ, x; dims = 1) = Δ .= ∇softmax(Δ, x; dims)
+∇softmax!(out, Δ, x; dims = 1) = out .= ∇softmax(Δ, x; dims)
+
+function ∇logsoftmax(Δ, x; dims = 1)
+    Base.depwarn("`∇logsoftmax(Δ, x)` without `y = logsoftmax(x)` argument is deprecated, please use `∇logsoftmax_data(dy, y)`", :∇logsoftmax)
+    ∇logsoftmax(Δ, x, logsoftmax(x; dims); dims)
+end
+∇logsoftmax!(Δ, x; dims = 1) = Δ .= ∇logsoftmax(Δ, x; dims)
+∇logsoftmax!(out, Δ, x; dims = 1) = out .= ∇logsoftmax(Δ, x; dims)
+
+
+### Deprecated while v0.8 was latest
+
+export ∇softmax,
+    ∇softmax!,
+    logsoftmax,
+    logsoftmax!,
+    ∇logsoftmax,
+    ∇logsoftmax!
+
+function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
+                   x::AbstractArray, y::AbstractArray; dims = 1)
+    Base.depwarn("`∇softmax!(dx, dy, x, y)` is deprecated, just use `∇softmax_data(dy, y)`", :∇softmax!)
+    # Removed because using a mutating function blocks 2nd derivatives, and
+    # the CUDA overload was slow anyway, https://github.com/FluxML/NNlibCUDA.jl/issues/30
+    out .= Δ .* y
+    out .= out .- y .* sum(out; dims)
+end
+
+function ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,
+                      x::AbstractArray, y::AbstractArray; dims = 1)
+    Base.depwarn("`∇logsoftmax!(dx, dy, x, y)` is deprecated, just use `∇logsoftmax_data(dy, y)`", :∇softmax!)
+    out .= Δ .- sum(Δ; dims) .* exp.(y)
+end
+
+function ∇softmax(dy::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S}
+    # Removed because there's no need to close over `x` here, that was done only to distinguish
+    # this from `∇softmax(Δ, x; dims = 1)` which re-computed `y = softmax(x)`, which is slow.
+    Base.depwarn("`∇softmax(dy, x, y)` should be replaced with `∇softmax_data(dy, y)`", :∇softmax)
+    ∇softmax_data(dy, y)
+end
+
+function ∇logsoftmax(dy::AbstractArray, x::AbstractArray, y::AbstractArray; dims = 1)
+    Base.depwarn("`∇logsoftmax(dy, x, y)` should be replaced with `∇logsoftmax_data(dy, y)`", :∇softmax)
+    ∇logsoftmax_data(dy, y)
+end
+

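All of the old gradient entry points above now funnel into the new helpers in src/softmax.jl. As a minimal migration sketch (the array shapes and variable names here are chosen for illustration, not taken from the commit):

```julia
using NNlib: softmax, ∇softmax_data   # the new helper is not exported

x  = randn(Float32, 5, 10)
y  = softmax(x; dims = 1)             # keep the forward result
dy = ones(Float32, size(y))           # some incoming gradient

# was: ∇softmax(dy, x, y; dims = 1)   (still works, but prints a depwarn)
dx = ∇softmax_data(dy, y; dims = 1)   # new form only needs y, never x
```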
src/softmax.jl

Lines changed: 68 additions & 49 deletions

@@ -1,3 +1,4 @@
+
 """
     softmax(x; dims = 1)

@@ -33,45 +34,63 @@ julia> softmax([1 2 3; 2 2 2]; dims=2)
  0.0900306  0.244728  0.665241
  0.333333   0.333333  0.333333
 ```
+
+Note that, when used with Flux.jl, `softmax` must not be passed to layers like `Dense`
+which accept an activation function. The activation is broadcasted over the result,
+thus applies to individual numbers. But `softmax` always needs to see the whole column.
+
+```julia
+julia> using Flux
+
+julia> x = randn(Float32, 4, 4, 3, 13);
+
+julia> model = Chain(Conv((4, 4), 3 => 8, tanh), Flux.flatten, Dense(8 => 7), softmax);
+
+julia> model(x) |> size
+(7, 13)
+
+julia> Dense(4 => 7, softmax)(x)
+ERROR: `softmax(x)` called with a number, but it expects an array.
+```
 """
-softmax(x; dims = 1) = softmax!(similar(x, (float ∘ eltype)(x)), x; dims = dims)
+softmax(x::AbstractArray{T}; dims = 1) where {T} = softmax!(similar(x, float(T)), x; dims)

-softmax!(x; dims = 1) = softmax!(x, x; dims = dims)
+softmax!(x::AbstractArray; dims = 1) = softmax!(x, x; dims)

 function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
-    max_ = maximum(x; dims = dims)
+    max_ = maximum(x; dims)
     if all(isfinite, max_)
-        out .= exp.(x .- max_)
+        @fastmath out .= exp.(x .- max_)
     else
-        @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
+        @fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
     end
-    out ./= sum(out; dims = dims) # could re-use max_ when dims != (:) and eltype(x) == T.
+    out ./= sum(out; dims)
 end

-∇softmax(Δ::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S} =
-    ∇softmax!(similar(y, promote_type(T, S)), Δ, x, y; dims = dims)
-∇softmax(Δ, x, y; dims = 1) = ∇softmax(unthunk(Δ), x, y, dims = dims)
-
-# Can introduce at the end of deprecation cycle of ∇softmax!(out, Δ, x; dims = 1)
-# ∇softmax!(Δ, x, y; dims = 1) = ∇softmax!(Δ, Δ, x, y; dims = dims)
-
-function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
-                   x::AbstractArray, y::AbstractArray; dims = 1)
-    out .= Δ .* y
-    out .= out .- y .* sum(out; dims = dims)
+function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
+    dx = if within_grad()
+        tmp = dy .* y
+        tmp .- y .* sum(tmp; dims)
+    else
+        # This path is faster, only safe for 1st derivatives though.
+        # Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads,
+        # but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30
+        out = similar(y, promote_type(T,S))
+        out .= dy .* y
+        out .= out .- y .* sum(out; dims)
+    end
 end

-# Old 2-arg version recomputing forward
-∇softmax(Δ, x; dims = 1) = ∇softmax(Δ, x, softmax(x, dims = dims); dims = dims)
-∇softmax!(Δ, x; dims = 1) = ∇softmax!(Δ, Δ, x, softmax(x, dims = dims); dims = dims)
-∇softmax!(out, Δ, x; dims = 1) = ∇softmax!(out, Δ, x, softmax(x, dims = dims); dims = dims)
-
-function rrule(::typeof(softmax), xs; dims=1)
-    y = softmax(xs; dims=dims)
-    softmax_pullback(Δ) = (NoTangent(), ∇softmax(unthunk(Δ), xs, y, dims = dims))
+function rrule(::typeof(softmax), x; dims = 1)
+    y = softmax(x; dims)
+    softmax_pullback(dy) = (NoTangent(), ∇softmax_data(unthunk(dy), y; dims))
     return y, softmax_pullback
 end

+within_grad() = false
+rrule(::typeof(within_grad)) = true, _ -> (NoTangent(),)
+
+
 """
     logsoftmax(x; dims = 1)

@@ -85,52 +104,52 @@ It is semantically equivalent to the following:

 See also [`softmax`](@ref).
 """
-logsoftmax(x; dims = 1) = logsoftmax!(similar(x, (float ∘ eltype)(x)), x; dims = dims)
+logsoftmax(x::AbstractArray{T}; dims = 1) where {T} = logsoftmax!(similar(x, float(T)), x; dims)

-logsoftmax!(x; dims = 1) = logsoftmax!(x, x; dims = dims)
+logsoftmax!(x::AbstractArray; dims = 1) = logsoftmax!(x, x; dims)

 function logsoftmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
-    max_ = maximum(x; dims = dims)
+    max_ = maximum(x; dims)
     if all(isfinite, max_)
         out .= x .- max_
     else
         @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 0, -Inf), x - max_)
     end
-    log_ = log.(sum(exp, out; dims = dims))
+    @fastmath log_ = log.(sum(exp, out; dims))
     out .-= log_
 end

-∇logsoftmax(Δ::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S} =
-    ∇logsoftmax!(similar(y, promote_type(T, S)), Δ, x, y; dims = dims)
-∇logsoftmax(Δ, x, y; dims = 1) = ∇logsoftmax(unthunk(Δ), x, y, dims = dims)
-
-# Old 2-arg version recomputing forward
-∇logsoftmax(Δ, x; dims = 1) = ∇logsoftmax(Δ, x, logsoftmax(x, dims = dims); dims = dims)
-∇logsoftmax!(Δ, x; dims = 1) = ∇logsoftmax!(Δ, Δ, x, logsoftmax(x, dims = dims); dims = dims)
-∇logsoftmax!(out, Δ, x; dims = 1) = ∇logsoftmax!(out, Δ, x, logsoftmax(x, dims = dims); dims = dims)
-
-function ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,
-                      x::AbstractArray, y::AbstractArray; dims = 1)
-    out .= Δ .- sum(Δ, dims = dims) .* exp.(y)
+function ∇logsoftmax_data(dy::AbstractArray, y::AbstractArray; dims = 1)
+    # This was previously `∇logsoftmax!(dx, dy, x, y; dims)` to allow CUDA overloads, but that was slow.
+    dx = dy .- sum(dy; dims) .* exp.(y)
 end
-
-function rrule(::typeof(logsoftmax), xs; dims=1)
-    y = logsoftmax(xs; dims=dims)
-    logsoftmax_pullback(Δ) = (NoTangent(), ∇logsoftmax(unthunk(Δ), xs, y, dims = dims))
+
+function rrule(::typeof(logsoftmax), x; dims = 1)
+    y = logsoftmax(x; dims)
+    logsoftmax_pullback(dy) = (NoTangent(), ∇logsoftmax_data(unthunk(dy), y; dims))
     return y, logsoftmax_pullback
 end

 """
     logsumexp(x; dims = :)

-Computes `log.(sum(exp.(x); dims = dims))` in a numerically stable
-way.
+Computes `log.(sum(exp.(x); dims))` in a numerically stable way.
+Without `dims` keyword this returns a scalar.

 See also [`logsoftmax`](@ref).
 """
 function logsumexp(x::AbstractArray; dims = :)
-    max_ = maximum(x; dims = dims)
-    max_ .+ log.(sum(exp.(x .- max_); dims = dims))
+    max_ = maximum(x; dims)
+    @fastmath max_ .+ log.(sum(exp.(x .- max_); dims))
+end
+
+function rrule(::typeof(logsumexp), x; dims = :)
+    # The gradient is `softmax`, but both compute `tmp` so it's worth saving.
+    max_ = maximum(x; dims)
+    @fastmath tmp = exp.(x .- max_)
+    @fastmath y = max_ .+ log.(sum(tmp; dims))
+    logsumexp_pullback(dy) = (NoTangent(), unthunk(dy) .* tmp ./ sum(tmp; dims))
+    return y, logsumexp_pullback
 end

 # Informative error message if any of the softmax variants is called with a number

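The piece that enables second derivatives is `within_grad`: its `rrule` reports `true` as the primal, so a ChainRules-aware reverse-mode AD that differentiates through `∇softmax_data` takes the purely broadcast branch instead of the mutating one. A rough sketch of the effect, assuming Zygote is loaded (Zygote is not a dependency of this commit):

```julia
using NNlib, Zygote

x = randn(Float32, 3, 4)

NNlib.within_grad()  # false: a plain call uses the fast, mutating branch

# One level of differentiation just runs the pullback, still on the fast branch:
g = Zygote.gradient(x -> sum(sin, softmax(x)), x)[1]

# Reverse-over-reverse: the outer pass traces the inner pullback, the rrule for
# within_grad evaluates to true there, and the non-mutating branch is what gets
# differentiated a second time:
H = Zygote.hessian_reverse(x -> sum(sin, softmax(x)), x)  # 12×12 matrix
```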
test/softmax.jl

Lines changed: 22 additions & 9 deletions

@@ -1,4 +1,5 @@
 using Statistics: mean
+using NNlib: ∇softmax_data, ∇logsoftmax_data

 @testset "softmax integer input" begin
     @test softmax(Int[0, 0]) == [0.5, 0.5]
@@ -34,10 +35,10 @@ end
     @test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.0]

     y = logsoftmax(xs)
-    @test ∇logsoftmax(ones(Float32, size(xs)), xs, y) ≈ Float32[1 1 1; -1 -1 -1]
+    @test ∇logsoftmax_data(ones(Float32, size(xs)), y) ≈ Float32[1 1 1; -1 -1 -1]

     y = softmax(xs)
-    @test ∇softmax(ones(Float32, size(xs)), xs, y) ≈ zeros(Float32, size(xs))
+    @test ∇softmax_data(ones(Float32, size(xs)), y) ≈ zeros(Float32, size(xs))

     # These values precalculated using PyTorch's nn.LogSoftmax
     xs = [
@@ -52,10 +53,10 @@ end
     ]

     y = logsoftmax(xs)
-    @test ∇logsoftmax(ones(size(xs)), xs, y) ≈ ys rtol = 1e-6
+    @test ∇logsoftmax_data(ones(size(xs)), y) ≈ ys rtol = 1e-6

     y = softmax(xs)
-    @test ∇softmax(ones(size(xs)), xs, y) ≈ zeros(size(xs)) atol = 1e-6
+    @test ∇softmax_data(ones(size(xs)), y) ≈ zeros(size(xs)) atol = 1e-6
 end

 @testset "softmax with Inf, NaN" begin
@@ -91,12 +92,12 @@ end
     @testset "$fn(Float64, $(size(xs)))" for fn in [zeros, ones, rand]
         Δ = fn(Float64, size(xs))
         y = softmax(xs)
-        ∇softmax!(out, Δ, xs, y)
-        @test out ≈ ∇softmax(Δ, xs, y) rtol = 1e-6
+        ∇softmax!(out, Δ, xs, y) # deprecated
+        @test out ≈ ∇softmax_data(Δ, y) rtol = 1e-6

         y = logsoftmax(xs)
-        ∇logsoftmax!(out, Δ, xs, y)
-        @test out ≈ ∇logsoftmax(Δ, xs, y) rtol = 1e-6
+        ∇logsoftmax!(out, Δ, xs, y) # deprecated
+        @test out ≈ ∇logsoftmax_data(Δ, y) rtol = 1e-6
     end
 end
 end
@@ -109,14 +110,14 @@ end
     @test logsumexp(x; dims = 1) ≈ flogsoft(x, dims = 1)
 end

-
 @testset "AutoDiff" begin
     for f in (softmax, logsoftmax), d in (:, 1, 2)
         gradtest(f, (3,4); fkwargs = (dims = d,), check_rrule = true)
     end
     gradtest(x -> softmax(x) .* (1:3), 3)
     gradtest(x -> softmax(x) .* (1:3), (3,5), atol = 1e-4)
     gradtest(x -> softmax(x, dims = 2) .* (1:3), (3,5), atol = 1e-4)
+
     gradtest(x -> logsoftmax(x) .* (1:3), 3)
     gradtest(x -> logsoftmax(x) .* (1:3), (3,5))
     gradtest(x -> logsoftmax(x, dims = 2) .* (1:3), (3,5))
@@ -125,3 +126,15 @@ end
         gradtest(logsumexp, (3,4), fkwargs = (dims = d,))
     end
 end
+
+@testset "Second derivatives" begin
+    x = [1 2 3; 6 5 4]
+    H = Zygote.hessian_dual(x -> sum(sin, softmax(x)), x)
+    @test H ≈ Zygote.hessian_reverse(x -> sum(sin, softmax(x)), x)
+
+    H2 = Zygote.hessian_dual(x -> sum(sin, logsoftmax(x)), x)
+    @test H2 ≈ Zygote.hessian_reverse(x -> sum(sin, logsoftmax(x)), x)
+
+    H3 = Zygote.hessian_dual(x -> sum(sin, logsumexp(x)), x)
+    @test H3 ≈ Zygote.hessian_reverse(x -> sum(sin, logsumexp(x)), x)
+end

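One more sanity check implied by the comment in the new `logsumexp` rrule, namely that its gradient is the softmax of the input. Again assuming Zygote; this snippet is illustrative and not part of the test file:

```julia
using NNlib, Zygote

x = randn(3, 4)
g = Zygote.gradient(z -> logsumexp(z), x)[1]   # uses the new rrule

g ≈ softmax(x; dims = :)          # gradient of logsumexp over all entries
g ≈ exp.(x .- logsumexp(x))       # the same thing written out by hand
```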