Skip to content

Commit 17cea1c

Browse files
authored
WIP: Issue 408 (#416)
* WIP * WIP * WIP; work to close #408 * forward diff extension * oops * adjustment for testing * version bump
1 parent 83476df commit 17cea1c

File tree

7 files changed

+187
-28
lines changed

7 files changed

+187
-28
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ docs/site
66
test/benchmarks.json
77
Manifest.toml
88
TODO.md
9+
default.profraw

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Roots"
22
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
3-
version = "2.1.1"
3+
version = "2.1.2"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

ext/RootsForwardDiffExt.jl

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,60 @@
1-
21
module RootsForwardDiffExt
32

43
using Roots
54
using ForwardDiff
6-
import ForwardDiff: Dual, value, partials
5+
import ForwardDiff: Dual, value, partials, Partials, derivative, gradient!
76

8-
# For ForwardDiff we add a `solve` method for Dual types
9-
# TODO (Issue #384) ForwardDiff.hessian fails, but this works:
10-
#function hess(f, p)
11-
# ∇(p) = ForwardDiff.gradient(f, p)
12-
# ForwardDiff.jacobian(∇, p)
13-
#end
7+
# What works
8+
# F(p) = find_zero(f, x0, M, p)
9+
# G(p) = find_zero(𝐺(p), x0, M)
10+
# F G
11+
# ForwardDiff.derivative ✓ x (wrong answer, 0.0)
12+
# ForwardDiff.gradient ✓ x (wrong answer, 0.0)
13+
# ForwardDiff.hessian ✓ x (wrong answer, 0.0)
14+
# Zygote.gradient ✓ ✓
15+
# Zygote.hessian ✓ x (wrong answer!)
16+
# Zygote.hessian_reverse ✓ x (MethodError)
1417

1518
function Roots.solve(ZP::ZeroProblem,
1619
M::Roots.AbstractUnivariateZeroMethod,
17-
𝐩::Union{Dual{T},
18-
AbstractArray{<:Dual{T,<:Real}}
19-
};
20+
𝐩::Dual{T};
2021
kwargs...) where {T}
2122

23+
24+
# p_and_dp = 𝐩
25+
p, dp = value.(𝐩), partials.(𝐩)
26+
27+
xᵅ = solve(ZP, M, p; kwargs...)
28+
2229
f = ZP.F
23-
pᵥ = value.(𝐩)
24-
xᵅ = solve(ZP, M, pᵥ; kwargs...)
25-
𝐱ᵅ = Dual{T}(xᵅ, one(xᵅ))
30+
fₓ = derivative(_x -> f(_x, p), xᵅ)
31+
fₚ = derivative(_p -> f(xᵅ, _p), p)
2632

27-
fₓ = partials(f(𝐱ᵅ, pᵥ), 1)
28-
fₚ = partials(f(xᵅ, 𝐩))
29-
Dual{T}(xᵅ, - fₚ / fₓ)
33+
# x and dx
34+
dx = - (fₚ * dp) / fₓ
35+
36+
Dual{T}(xᵅ, dx)
3037
end
3138

39+
# cf https://discourse.julialang.org/t/custom-rule-for-differentiating-through-newton-solver-using-forwarddiff-works-for-gradient-fails-for-hessian/93002/22
40+
function Roots.solve(ZP::ZeroProblem,
41+
M::Roots.AbstractUnivariateZeroMethod,
42+
𝐩::AbstractArray{<:Dual{T,R,N}};
43+
kwargs...) where {T,R,N}
44+
45+
46+
# p_and_dp = 𝐩
47+
p, dp = value.(𝐩), partials.(𝐩)
48+
xᵅ = solve(ZP, M, p; kwargs...)
49+
50+
f = ZP.F
51+
fₓ = derivative(_x -> f(_x, p), xᵅ)
52+
fₚ = similar(𝐩) # <-- need this, not output of gradient(p->f(x,p), p)
53+
gradient!(fₚ, _p -> f(xᵅ, _p), p)
54+
55+
# x_and_dx
56+
dx = - (fₚ' * dp) / fₓ
57+
58+
Dual{T}(xᵅ, Partials(ntuple(k -> dx[k], Val(N))))
59+
end
3260
end

src/chain_rules.jl

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,33 @@
33
# ∇f = 0 => ∂/∂ₓ f(xᵅ, p) ⋅ ∂xᵅ/∂ₚ + ∂/∂ₚ f(xᵅ, p) ⋅ I = 0
44
# or ∂xᵅ/∂ₚ = - ∂/∂ₚ f(xᵅ, p) / ∂/∂ₓ f(xᵅ, p)
55

6+
# There are two cases considered
7+
# F(p) = find_zero(f(x,p), x₀, M, p) # f a function
8+
# G(p) = find_zero(𝐺(p), x₀, M) # 𝐺 a functor
9+
# For G(p) first order derivatives are working
10+
# **but** the Hessian does not work with Zygote. *MOREOVER*, it fails
11+
# with the **wrong answer**, not an error.
12+
#
13+
# (`Zygote.hessian` calls `ForwardDiff` and that isn't working with a functor;
14+
# `Zygote.hessian_reverse` doesn't seem to work here, though perhaps
15+
# that is fixable.)
16+
17+
18+
# this assumes a function and a parameter `p` passed in
19+
import ChainRulesCore: Tangent, NoTangent, frule, rrule
620
function ChainRulesCore.frule(
721
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
822
(_, _, _, Δp),
923
::typeof(solve),
1024
ZP::ZeroProblem,
11-
M::AbstractUnivariateZeroMethod,
25+
M::Roots.AbstractUnivariateZeroMethod,
1226
p;
1327
kwargs...,
1428
)
1529
xᵅ = solve(ZP, M, p; kwargs...)
1630

1731
# Use a single reverse-mode AD call with `rrule_via_ad` if `config` supports it?
18-
F = p -> Callable_Function(M, ZP.F, p)
32+
F = p -> Roots.Callable_Function(M, ZP.F, p)
1933
fₓ(x) = first(F(p)(x))
2034
fₚ(p) = first(F(p)(xᵅ))
2135
fx = ChainRulesCore.frule_via_ad(config, (ChainRulesCore.NoTangent(), true), fₓ, xᵅ)[2]
@@ -24,23 +38,59 @@ function ChainRulesCore.frule(
2438
xᵅ, -fp / fx
2539
end
2640

41+
# Case of Functor carrying parameters
42+
ChainRulesCore.frule(
43+
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
44+
xdots,
45+
::typeof(solve),
46+
ZP::Roots.ZeroProblem,
47+
M::Roots.AbstractUnivariateZeroMethod,
48+
::Nothing;
49+
kwargs...,
50+
) =
51+
frule(config, xdots, solve, ZP, M; kwargs...)
52+
53+
function ChainRulesCore.frule(
54+
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
55+
(_, Δq, _),
56+
::typeof(solve),
57+
ZP::Roots.ZeroProblem,
58+
M::Roots.AbstractUnivariateZeroMethod;
59+
kwargs...,
60+
)
61+
# no `p`; make ZP.F the parameter (issue 408)
62+
foo = ZP.F
63+
zprob2 = ZeroProblem(|>, ZP.x₀)
64+
nms = fieldnames(typeof(foo))
65+
nt = NamedTuple{nms}(getfield(foo, n) for n ∈ nms)
66+
dfoo = Tangent{typeof(foo)}(;nt...)
67+
68+
return frule(config,
69+
(NoTangent(), NoTangent(), NoTangent(), dfoo),
70+
Roots.solve, zprob2, M, foo)
71+
end
72+
73+
74+
##
75+
2776
## modified from
2877
## https://github.com/gdalle/ImplicitDifferentiation.jl/blob/main/src/implicit_function.jl
78+
# this is for passing a parameter `p`
2979
function ChainRulesCore.rrule(
3080
rc::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
3181
::typeof(solve),
3282
ZP::ZeroProblem,
33-
M::AbstractUnivariateZeroMethod,
83+
M::Roots.AbstractUnivariateZeroMethod,
3484
p;
3585
kwargs...,
3686
)
3787
xᵅ = solve(ZP, M, p; kwargs...)
3888

39-
f(x, p) = first(Callable_Function(M, ZP.F, p)(x))
89+
f(x, p) = first(Roots.Callable_Function(M, ZP.F, p)(x))
4090
_, pullback_f = ChainRulesCore.rrule_via_ad(rc, f, xᵅ, p)
4191
_, fx, fp = pullback_f(true)
42-
yp = -fp / fx
4392

93+
yp = -fp / fx
4494
function pullback_solve_ZeroProblem(dy)
4595
dp = yp * dy
4696
return (
@@ -53,3 +103,51 @@ function ChainRulesCore.rrule(
53103

54104
return xᵅ, pullback_solve_ZeroProblem
55105
end
106+
107+
# this assumes a functor 𝐺(p) for the function *and* no parameter
108+
ChainRulesCore.rrule(
109+
rc::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
110+
::typeof(solve),
111+
ZP::ZeroProblem,
112+
M::Roots.AbstractUnivariateZeroMethod,
113+
::Nothing;
114+
kwargs...,
115+
) =
116+
ChainRulesCore.rrule(rc, solve, ZP, M; kwargs...)
117+
118+
119+
function ChainRulesCore.rrule(
120+
rc::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
121+
::typeof(solve),
122+
ZP::ZeroProblem,
123+
M::Roots.AbstractUnivariateZeroMethod;
124+
kwargs...,
125+
)
126+
127+
128+
𝑍𝑃 = ZeroProblem(|>, ZP.x₀)
129+
xᵅ = solve(ZP, M; kwargs...)
130+
f(x, p) = first(Roots.Callable_Function(M, 𝑍𝑃.F, p)(x))
131+
132+
_, pullback_f = ChainRulesCore.rrule_via_ad(rc, f, xᵅ, ZP.F)
133+
_, fx, fp = pullback_f(true)
134+
135+
yp = NamedTuple{keys(fp)}(-fₚ/fx for fₚ ∈ values(fp))
136+
137+
function pullback_solve_ZeroProblem(dy)
138+
dF = ChainRulesCore.Tangent{typeof(ZP.F)}(; yp...)
139+
140+
dZP = ChainRulesCore.Tangent{typeof(ZP)}(;
141+
F = dF,
142+
x₀ = ChainRulesCore.NoTangent()
143+
)
144+
145+
dsolve = ChainRulesCore.NoTangent()
146+
dM = ChainRulesCore.NoTangent()
147+
dp = ChainRulesCore.NoTangent()
148+
149+
return dsolve, dZP, dM, dp
150+
end
151+
152+
return xᵅ, pullback_solve_ZeroProblem
153+
end

test/test_chain_rules.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@ using Zygote
44
using Test
55

66
# issue #325 add frule, rrule
7+
struct 𝐺
8+
p
9+
end
10+
(g::𝐺)(x) = cos(x) - g.p * x
11+
G₃(p) = find_zero(𝐺(p), (0, pi/2), Bisection())
12+
F₃(p) = find_zero((x,p) -> cos(x) - p*x, (0, pi/2), Bisection(), p)
13+
714

815
@testset "Test frule and rrule" begin
916
# Type inference tests of `test_frule` and `test_rrule` with the default
@@ -33,7 +40,7 @@ using Test
3340
G(p) = find_zero(g, 1, Order1(), p)
3441
@test first(Zygote.gradient(G, [0, 4])) ≈ [1 / 2, 1 / 4]
3542

36-
# a tuple of functions
43+
# a tuple of functions
3744
fx(x, p) = 1 / x
3845
test_frule(solve, ZeroProblem((f, fx), 1), Roots.Newton(), 1.0; check_inferred=false)
3946
test_rrule(solve, ZeroProblem((f, fx), 1), Roots.Newton(), 1.0; check_inferred=false)
@@ -67,4 +74,27 @@ using Test
6774
)
6875
G2(p) = find_zero((g, gx), 1, Roots.Newton(), p)
6976
@test first(Zygote.gradient(G2, [0, 4])) ≈ [1 / 2, 1 / 4]
77+
78+
# test Functor; issue #408
79+
x = rand()
80+
@test first(Zygote.gradient(F₃, x)) ≈ first(Zygote.gradient(G₃, x))
81+
# ForwardDiff extension makes this fail.
82+
VERSION >= v"1.9.0" && @test_broken first(Zygote.hessian(F₃, x)) ≈ first(Zygote.hessian(G₃, x))
83+
# test_frule, test_rrule aren't successful
84+
#=
85+
# DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 3 and 2
86+
test_frule(
87+
solve,
88+
ZeroProblem(𝐺(2), (0.0, pi/2)),
89+
Roots.Bisection();
90+
check_inferred=false,
91+
)
92+
# MethodError: no method matching keys(::NoTangent)
93+
test_rrule(
94+
solve,
95+
ZeroProblem(𝐺(2), (0.0, pi/2)),
96+
Roots.Bisection();
97+
check_inferred=false,
98+
)
99+
=#
70100
end

test/test_extensions.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,17 @@ using ForwardDiff
5656
@test ForwardDiff.derivative(F, p) ≈ 1 / (2sqrt(p))
5757
end
5858

59-
# Hessian is *broken*
59+
# Hessian is *fixed* for F(p) = find_zero(f, x₀, M, p)
6060
f = (x, p) -> x^2 - sum(p .^ 2)
6161
Z = ZeroProblem(f, (0, 1000))
6262
F = p -> solve(Z, Roots.Bisection(), p)
6363
Z = ZeroProblem(f, (0, 1000))
6464
F = p -> solve(Z, Roots.Bisection(), p)
65-
hess(f, p) = ForwardDiff.jacobian(p -> ForwardDiff.gradient(F, p), p)
6665
for p ∈ ([1,2], [1,3], [1,4])
6766
@test F(p) ≈ sqrt(sum(p .^ 2))
68-
@test_throws DimensionMismatch ForwardDiff.hessian(F, p)
6967
a, b = p
7068
n = sqrt(a^2 + b^2)^3
71-
@test hess(F, p) [b^2 -a*b; -a*b a^2] / n
69+
@test ForwardDiff.hessian(F, p) ≈ [b^2 -a*b; -a*b a^2] / n
7270
end
7371
end
7472

tmp/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[deps]
2+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3+
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
4+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

0 commit comments

Comments
 (0)