Skip to content

Commit a6158f9

Browse files
authored
Add ChainRules definitions (#25)
1 parent be0bae2 commit a6158f9

File tree

5 files changed

+152
-2
lines changed

5 files changed

+152
-2
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.3.0"
4+
version = "0.3.1"
55

66
[deps]
7+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
89
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011

1112
[compat]
13+
ChainRulesCore = "1"
1214
DocStringExtensions = "0.8"
1315
IrrationalConstants = "0.1"
1416
julia = "1"
1517

1618
[extras]
19+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
1720
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
21+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1822
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1923

2024
[targets]
21-
test = ["OffsetArrays", "Test"]
25+
test = ["ChainRulesTestUtils", "OffsetArrays", "Random", "Test"]

src/LogExpFunctions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module LogExpFunctions
33
using DocStringExtensions: SIGNATURES
44
using Base: Math.@horner
55

6+
import ChainRulesCore
67
import IrrationalConstants
78
import LinearAlgebra
89

@@ -12,5 +13,6 @@ export xlogx, xlogy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, log
1213

1314
include("basicfuns.jl")
1415
include("logsumexp.jl")
16+
include("chainrules.jl")
1517

1618
end # module

src/chainrules.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Scalar differentiation rules for the exported functions.
# Inside `@scalar_rule`, `Ω` refers to the primal result of the call being differentiated.

# d/dx [x * log(x)] = 1 + log(x)
ChainRulesCore.@scalar_rule(xlogx(x::Real), 1 + log(x))

# ∂/∂x [x * log(y)] = log(y) and ∂/∂y [x * log(y)] = x / y
ChainRulesCore.@scalar_rule(
    xlogy(x::Real, y::Real),
    (log(y), x / y),
)

# Rules that reuse the primal output `Ω` or the input directly.
ChainRulesCore.@scalar_rule(logistic(x::Real), Ω * (1 - Ω))
ChainRulesCore.@scalar_rule(logit(x::Real), inv(x * (1 - x)))
ChainRulesCore.@scalar_rule(log1psq(x::Real), 2 * x / (1 + x^2))
ChainRulesCore.@scalar_rule(log1pexp(x::Real), logistic(x))

# For Ω = log(c ∓ exp(x)) the derivative is ∓exp(x) / exp(Ω) = ∓exp(x - Ω).
ChainRulesCore.@scalar_rule(log1mexp(x::Real), -exp(x - Ω))
ChainRulesCore.@scalar_rule(log2mexp(x::Real), -exp(x - Ω))
ChainRulesCore.@scalar_rule(logexpm1(x::Real), exp(x - Ω))

# Ω = log(exp(x) + exp(y)): each partial is the corresponding normalized weight.
ChainRulesCore.@scalar_rule(
    logaddexp(x::Real, y::Real),
    (exp(x - Ω), exp(y - Ω)),
)

# The signs of both partials flip depending on which argument is larger.
ChainRulesCore.@scalar_rule(
    logsubexp(x::Real, y::Real),
    (
        x > y ? exp(x - Ω) : -exp(x - Ω),
        x > y ? -exp(y - Ω) : exp(y - Ω),
    ),
)
17+
18+
# Forward-mode rule for `logsumexp` over an array, with optional reduction `dims`.
#
# Returns the primal `logsumexp(x; dims)` together with its tangent, which is the sum
# (over the same `dims`) of `exp.(x .- Ω) .* Δx`.
function ChainRulesCore.frule(
    (_, Δx), ::typeof(logsumexp), x::AbstractArray{<:Real}; dims=:
)
    lse = logsumexp(x; dims=dims)
    # Input tangent weighted elementwise by exp(x - logsumexp(x)).
    weighted = exp.(x .- lse) .* Δx
    tangent = sum(weighted; dims=dims)
    return lse, tangent
end
23+
# Reverse-mode rule for `logsumexp` over an array, with optional reduction `dims`.
#
# Returns the primal `Ω = logsumexp(x; dims)` and a pullback. The pullback maps the output
# cotangent `Ω̄` to `(NoTangent(), x̄)`, where `x̄` is `Ω̄ .* exp.(x .- Ω)` wrapped in an
# `InplaceableThunk` so it can either be accumulated in place or computed lazily.
function ChainRulesCore.rrule(::typeof(logsumexp), x::AbstractArray{<:Real}; dims=:)
    Ω = logsumexp(x; dims=dims)
    project_x = ChainRulesCore.ProjectTo(x)
    function logsumexp_pullback(Ω̄)
        # The binding `x̄` is required: it is the value returned below.
        x̄ = ChainRulesCore.InplaceableThunk(
            # In-place accumulation into an existing cotangent buffer `Δ`.
            Δ -> Δ .+= Ω̄ .* exp.(x .- Ω),
            # Lazy out-of-place variant, passed through the projector built from `x`.
            ChainRulesCore.@thunk(project_x(Ω̄ .* exp.(x .- Ω))),
        )
        return ChainRulesCore.NoTangent(), x̄
    end
    return Ω, logsumexp_pullback
end
35+
36+
# Forward-mode rule for the in-place `softmax!(r, x)`.
#
# Runs `softmax!(r, x)` (overwriting `r`) and returns `r` with its tangent
# `r .* (Δx .- dot(r, Δx))`. The tangent belonging to `r` itself is ignored.
function ChainRulesCore.frule(
    (_, _, Δx),
    ::typeof(softmax!),
    r::AbstractArray{<:Real},
    x::AbstractArray{<:Real},
)
    softmax!(r, x)
    # Bring the input tangent into the shape of the output buffer before combining them.
    Δx_shaped = reshape(Δx, size(r))
    inner = LinearAlgebra.dot(r, Δx_shaped)
    Δr = r .* (Δx_shaped .- inner)
    return r, Δr
end
44+
# Reverse-mode rule for the in-place `softmax!(r, x)`.
#
# Runs `softmax!(r, x)` (overwriting `r`) and returns `r` together with a pullback. The
# pullback maps the output cotangent `r̄` to
# `(NoTangent(), ZeroTangent(), x̄)`, where `x̄ = r .* (r̄ .- dot(r, r̄))` is wrapped in an
# `InplaceableThunk` so it can either be accumulated in place or computed lazily.
function ChainRulesCore.rrule(
    ::typeof(softmax!), r::AbstractArray{<:Real}, x::AbstractArray{<:Real},
)
    softmax!(r, x)
    project_x = ChainRulesCore.ProjectTo(x)
    # Keep a private copy of the result in the shape of `x`; presumably this protects the
    # pullback from later mutation of the caller's buffer `r`.
    rcopy = copy(reshape(r, size(x)))
    function softmax!_pullback(r̄)
        _r̄ = reshape(r̄, size(rcopy))
        # The binding `x̄` is required: it is the value returned below.
        x̄ = ChainRulesCore.InplaceableThunk(
            # In-place accumulation into an existing cotangent buffer `Δ`.
            Δ -> Δ .+= rcopy .* (_r̄ .- LinearAlgebra.dot(rcopy, _r̄)),
            # Lazy out-of-place variant, passed through the projector built from `x`.
            ChainRulesCore.@thunk(
                project_x(rcopy .* (_r̄ .- LinearAlgebra.dot(rcopy, _r̄)))
            ),
        )
        return ChainRulesCore.NoTangent(), ChainRulesCore.ZeroTangent(), x̄
    end
    return r, softmax!_pullback
end

test/chainrules.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
@testset "chainrules.jl" begin
    # Check the forward-mode rule and then the reverse-mode rule with the same arguments.
    function test_rules(f, args...; kwargs...)
        test_frule(f, args...; kwargs...)
        test_rrule(f, args...; kwargs...)
        return nothing
    end

    x = rand()
    test_rules(xlogx, x)
    for a in (-x, 0.0, x)
        b = rand()
        test_rules(xlogy, a, b)
    end

    test_rules(logit, x)

    for a in (-randexp(), randexp())
        test_rules(log1psq, a)
    end

    # Fixed evaluation points, chosen to hit every branch of the implementations
    # (`Float64` and `Float32` branches of `logistic`, all branches of `log1pexp` and
    # `logexpm1`). `Float32` inputs are checked with looser tolerances.
    default_tol = NamedTuple()
    loose_tol = (rtol=1f-3, atol=1f-3)
    fixed_cases = (
        (logistic, (-821.4, -23.5, 12.3, 41.2), default_tol),
        (logistic, (-123.2f0, -21.4f0, 8.3f0, 21.5f0), loose_tol),
        (log1pexp, (-20.9, 15.4, 41.5), default_tol),
        (log1pexp, (8.3f0, 12.5f0, 21.2f0), loose_tol),
        (log1mexp, (-10.2, -3.3, -0.3), default_tol),
        (log2mexp, (-10.2, -3.3, -0.3, 0.5), default_tol),
        (logexpm1, (5.2, 21.4, 41.5), default_tol),
        (logexpm1, (4.3f0, 12.5f0, 21.2f0), loose_tol),
    )
    for (f, points, tol) in fixed_cases, point in points
        test_rules(f, point; tol...)
    end

    for a in (-randexp(), randexp()), b in (-randexp(), randexp())
        test_rules(logaddexp, a, b)
        test_rules(logsubexp, a, b)
    end

    for arr in (randn(10), randn(10, 8)), dims in (:, 1, 1:2, 2)
        # Skip reductions over dimensions the array does not have.
        applicable = dims isa Colon || all(d -> d <= ndims(arr), dims)
        applicable || continue
        test_rules(logsumexp, arr; fkwargs=(dims=dims,))
    end

    # Output buffers both with the shape of the input and with an extra leading dimension.
    for arr in (randn(10), randn(10, 8)),
        out in (similar(arr), similar(arr, 1, size(arr)...))

        test_rules(softmax!, out, arr)
    end
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
using LogExpFunctions
2+
using ChainRulesTestUtils
23
using OffsetArrays
4+
5+
using Random
36
using Test
47

8+
Random.seed!(1234)
9+
510
include("basicfuns.jl")
11+
include("chainrules.jl")

0 commit comments

Comments
 (0)