Skip to content

Commit 18a0b43

Browse files
authored
Reuse intermediate result of primal in derivative (#52)
1 parent e663fe4 commit 18a0b43

File tree

2 files changed

+76
-21
lines changed

2 files changed

+76
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.3.14"
4+
version = "0.3.15"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/chainrules.jl

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,97 @@
1-
# Scalar differentiation rule for `xlogx`, declared through `@scalar_rule`:
# the tuple lists the partial derivative with respect to `x`, here `1 + log(x)`.
# NOTE(review): explicit `frule`/`rrule` methods for `xlogx` also appear later
# in this file — confirm that only one of the two definitions is kept.
ChainRulesCore.@scalar_rule(xlogx(x::Real), (1 + log(x),))
2-
# Scalar differentiation rule for `xlogy`, declared through `@scalar_rule`:
# the tuple lists the partials with respect to `x` and `y`, here `log(y)` and `x / y`.
# NOTE(review): explicit `frule`/`rrule` methods for `xlogy` also appear later
# in this file — confirm that only one of the two definitions is kept.
ChainRulesCore.@scalar_rule(xlogy(x::Real, y::Real), (log(y), x / y,))
3-
# Scalar differentiation rule for `xlog1py`, declared through `@scalar_rule`:
# the tuple lists the partials with respect to `x` and `y`, here `log1p(y)`
# and `x / (1 + y)`.
# NOTE(review): explicit `frule`/`rrule` methods for `xlog1py` also appear later
# in this file — confirm that only one of the two definitions is kept.
ChainRulesCore.@scalar_rule(xlog1py(x::Real, y::Real), (log1p(y), x / (1 + y),))
1+
"""
    _Ω_∂_xlogx(x::Real)

Return the tuple `(Ω, ∂x)` where `Ω` is `x * log(x)` and `∂x` is `1 + log(x)`.

`log(x)` is evaluated once and shared between the value and the derivative.
When `x` is zero, `Ω` is a zero of the product's type instead of the `NaN`
that `0 * log(0)` would give; `∂x` is left as `1 + log(x)` in that case.
"""
function _Ω_∂_xlogx(x::Real)
    logx = log(x)
    y = x * logx
    # `zero(y)` keeps the return type identical in both branches.
    Ω = iszero(x) ? zero(y) : y
    ∂x = 1 + logx
    return Ω, ∂x
end
8+
# Forward-mode rule for `xlogx`: returns the primal `Ω` together with the
# tangent `∂x * Δx`. Both `Ω` and `∂x` come from a single call to
# `_Ω_∂_xlogx`. The tangent belonging to the function itself is ignored.
function ChainRulesCore.frule((_, Δx), ::typeof(xlogx), x::Real)
    Ω, ∂x = _Ω_∂_xlogx(x)
    ΔΩ = ∂x * Δx
    return Ω, ΔΩ
end
13+
# Reverse-mode rule for `xlogx`: returns the primal `Ω` and a pullback that
# maps the cotangent `ΔΩ` to `(NoTangent(), ∂x * ΔΩ)`. `Ω` and `∂x` are
# computed once by `_Ω_∂_xlogx` and `∂x` is captured by the pullback closure.
function ChainRulesCore.rrule(::typeof(xlogx), x::Real)
    Ω, ∂x = _Ω_∂_xlogx(x)
    xlogx_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ∂x * ΔΩ)
    return Ω, xlogx_pullback
end
418

5-
function ChainRulesCore.frule((_, Δx), ::typeof(xexpx), x::Real)
19+
"""
    _Ω_∂_xlogy(x::Real, y::Real)

Return the tuple `(Ω, ∂x, ∂y)` where `Ω` is `x * log(y)`, `∂x` is `log(y)` and
`∂y` is `x / y`.

`log(y)` is evaluated once and shared between the value and `∂x`. When `x` is
zero and `y` is not `NaN`, `Ω` is a zero of the product's type rather than the
raw product; a `NaN` in `y` still propagates into `Ω`.
"""
function _Ω_∂_xlogy(x::Real, y::Real)
    logy = log(y)
    z = x * logy
    # `zero(z)` keeps the return type identical in both branches.
    Ω = iszero(x) && !isnan(y) ? zero(z) : z
    ∂x = logy
    ∂y = x / y
    return Ω, ∂x, ∂y
end
27+
# Forward-mode rule for `xlogy`: returns the primal `Ω` together with the
# tangent `∂x * Δx + ∂y * Δy`, evaluated with `muladd`. `Ω`, `∂x` and `∂y`
# come from a single call to `_Ω_∂_xlogy`. The tangent belonging to the
# function itself is ignored.
function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xlogy), x::Real, y::Real)
    Ω, ∂x, ∂y = _Ω_∂_xlogy(x, y)
    ΔΩ = muladd(∂x, Δx, ∂y * Δy)
    return Ω, ΔΩ
end
32+
# Reverse-mode rule for `xlogy`: returns the primal `Ω` and a pullback that
# maps the cotangent `ΔΩ` to `(NoTangent(), ∂x * ΔΩ, ∂y * ΔΩ)`. `Ω`, `∂x` and
# `∂y` are computed once by `_Ω_∂_xlogy` and the partials are captured by the
# pullback closure.
function ChainRulesCore.rrule(::typeof(xlogy), x::Real, y::Real)
    Ω, ∂x, ∂y = _Ω_∂_xlogy(x, y)
    xlogy_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ∂x * ΔΩ, ∂y * ΔΩ)
    return Ω, xlogy_pullback
end
37+
38+
"""
    _Ω_∂_xlog1py(x::Real, y::Real)

Return the tuple `(Ω, ∂x, ∂y)` where `Ω` is `x * log1p(y)`, `∂x` is `log1p(y)`
and `∂y` is `x / (1 + y)`.

`log1p(y)` is evaluated once and shared between the value and `∂x`. When `x`
is zero and `y` is not `NaN`, `Ω` is a zero of the product's type rather than
the raw product; a `NaN` in `y` still propagates into `Ω`.
"""
function _Ω_∂_xlog1py(x::Real, y::Real)
    log1py = log1p(y)
    z = x * log1py
    # `zero(z)` keeps the return type identical in both branches.
    Ω = iszero(x) && !isnan(y) ? zero(z) : z
    ∂x = log1py
    ∂y = x / (1 + y)
    return Ω, ∂x, ∂y
end
46+
# Forward-mode rule for `xlog1py`: returns the primal `Ω` together with the
# tangent `∂x * Δx + ∂y * Δy`, evaluated with `muladd`. `Ω`, `∂x` and `∂y`
# come from a single call to `_Ω_∂_xlog1py`. The tangent belonging to the
# function itself is ignored.
function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xlog1py), x::Real, y::Real)
    Ω, ∂x, ∂y = _Ω_∂_xlog1py(x, y)
    ΔΩ = muladd(∂x, Δx, ∂y * Δy)
    return Ω, ΔΩ
end
51+
# Reverse-mode rule for `xlog1py`: returns the primal `Ω` and a pullback that
# maps the cotangent `ΔΩ` to `(NoTangent(), ∂x * ΔΩ, ∂y * ΔΩ)`. `Ω`, `∂x` and
# `∂y` are computed once by `_Ω_∂_xlog1py` and the partials are captured by
# the pullback closure.
function ChainRulesCore.rrule(::typeof(xlog1py), x::Real, y::Real)
    Ω, ∂x, ∂y = _Ω_∂_xlog1py(x, y)
    xlog1py_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ∂x * ΔΩ, ∂y * ΔΩ)
    return Ω, xlog1py_pullback
end
56+
57+
"""
    _Ω_∂_xexpx(x::Real)

Return the tuple `(Ω, ∂x)` where `Ω` is `x * exp(x)` and `∂x` is
`(1 + x) * exp(x)`.

`exp(x)` is evaluated once and shared between the value and the derivative.
When `exp(x)` is zero, both `Ω` and `∂x` are set to that zero directly, so an
infinite `x` is never multiplied by zero.
"""
function _Ω_∂_xexpx(x::Real)
    expx = exp(x)
    if iszero(expx)
        # e.g. `x == -Inf`: `x * expx` would be `NaN`, so return the zero itself.
        Ω = expx
        ∂x = expx
    else
        Ω = x * expx
        ∂x = (1 + x) * expx
    end
    return Ω, ∂x
end
68+
# Forward-mode rule for `xexpx`: returns the primal `Ω` together with the
# tangent `∂x * Δx`. Both `Ω` and `∂x` come from a single call to
# `_Ω_∂_xexpx`. The tangent belonging to the function itself is ignored.
function ChainRulesCore.frule((_, Δx), ::typeof(xexpx), x::Real)
    Ω, ∂x = _Ω_∂_xexpx(x)
    ΔΩ = ∂x * Δx
    return Ω, ΔΩ
end
16-
1773
# Reverse-mode rule for `xexpx`: returns the primal `Ω` and a pullback that
# maps the cotangent `ΔΩ` to `(NoTangent(), ∂x * ΔΩ)`. `Ω` and `∂x` are
# computed once by `_Ω_∂_xexpx`, so the pullback only captures the finished
# partial derivative and does not branch on `exp(x)` again.
function ChainRulesCore.rrule(::typeof(xexpx), x::Real)
    Ω, ∂x = _Ω_∂_xexpx(x)
    xexpx_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ∂x * ΔΩ)
    return Ω, xexpx_pullback
end
2678

27-
function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xexpy), x::Real, y::Real)
79+
"""
    _Ω_∂_xexpy(x::Real, y::Real)

Return the tuple `(Ω, ∂x, ∂y)` where `Ω` is `x * exp(y)`, `∂x` is `exp(y)` and
`∂y` is `Ω` itself.

`exp(y)` is evaluated once and shared between the value and `∂x`. When
`exp(y)` is zero and `x` is not `NaN`, `Ω` is a zero of the product's type
rather than the raw product; a `NaN` in `x` still propagates into `Ω`.
"""
function _Ω_∂_xexpy(x::Real, y::Real)
    expy = exp(y)
    result = x * expy
    # `zero(result)` keeps the return type identical in both branches.
    Ω = iszero(expy) && !isnan(x) ? zero(result) : result
    ∂x = expy
    # The derivative of `x * exp(y)` with respect to `y` is the value itself.
    ∂y = Ω
    return Ω, ∂x, ∂y
end
87+
# Forward-mode rule for `xexpy`: returns the primal `Ω` together with the
# tangent `∂x * Δx + ∂y * Δy`, evaluated with `muladd`. `Ω`, `∂x` and `∂y`
# come from a single call to `_Ω_∂_xexpy`. The tangent belonging to the
# function itself is ignored.
function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xexpy), x::Real, y::Real)
    Ω, ∂x, ∂y = _Ω_∂_xexpy(x, y)
    ΔΩ = muladd(∂x, Δx, ∂y * Δy)
    return Ω, ΔΩ
end
34-
3592
# Reverse-mode rule for `xexpy`: returns the primal `Ω` and a pullback that
# maps the cotangent `ΔΩ` to `(NoTangent(), ∂x * ΔΩ, ∂y * ΔΩ)`. `Ω`, `∂x` and
# `∂y` are computed once by `_Ω_∂_xexpy` and the partials are captured by the
# pullback closure.
function ChainRulesCore.rrule(::typeof(xexpy), x::Real, y::Real)
    Ω, ∂x, ∂y = _Ω_∂_xexpy(x, y)
    xexpy_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ∂x * ΔΩ, ∂y * ΔΩ)
    return Ω, xexpy_pullback
end
4297

0 commit comments

Comments
 (0)