Skip to content

Commit 68e0c2e

Browse files
authored
fix: provide initial value to iterative linear solver (#182)
* fix: provide initial value to iterative linear solver * Fix * Fix
1 parent f065a08 commit 68e0c2e

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ImplicitDifferentiation"
22
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
33
authors = ["Guillaume Dalle", "Mohamed Tarek"]
4-
version = "0.9.0"
4+
version = "0.9.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/ImplicitDifferentiationChainRulesCoreExt.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@ using ImplicitDifferentiation:
1414
# not covered by Codecov for now
1515
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)
1616

17-
struct ImplicitPullback{TA,TB,TL,TP,Nargs}
17+
struct ImplicitPullback{TA,TB,TL,TC,TP,Nargs}
1818
Aᵀ::TA
1919
Bᵀ::TB
2020
linear_solver::TL
21+
c0::TC
2122
project_x::TP
2223
_Nargs::Val{Nargs}
2324
end
2425

25-
function (pb::ImplicitPullback{TA,TB,TL,TP,Nargs})((dy, dz)) where {TA,TB,TL,TP,Nargs}
26-
(; Aᵀ, Bᵀ, linear_solver, project_x) = pb
27-
dc = linear_solver(Aᵀ, -unthunk(dy))
26+
function (pb::ImplicitPullback{TA,TB,TL,TC,TP,Nargs})((dy, dz)) where {TA,TB,TL,TP,TC,Nargs}
27+
(; Aᵀ, Bᵀ, linear_solver, c0, project_x) = pb
28+
dc = linear_solver(Aᵀ, -unthunk(dy), c0)
2829
dx = Bᵀ(dc)
2930
df = NoTangent()
3031
dargs = ntuple(unimplemented_tangent, Val(Nargs))
@@ -37,14 +38,15 @@ function ChainRulesCore.rrule(
3738
(; conditions, linear_solver) = implicit
3839
y, z = implicit(x, args...)
3940
c = conditions(x, y, z, args...)
41+
c0 = zero(c)
4042

4143
suggested_backend = chainrules_suggested_backend(rc)
4244
prep = ImplicitFunctionPreparation(eltype(x))
4345
Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
4446
Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
4547
project_x = ProjectTo(x)
4648

47-
implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, project_x, Val(N))
49+
implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, c0, project_x, Val(N))
4850
return (y, z), implicit_pullback
4951
end
5052

ext/ImplicitDifferentiationForwardDiffExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ function (implicit::ImplicitFunction)(
1111
x = value.(x_and_dx)
1212
y, z = implicit(x, args...)
1313
c = implicit.conditions(x, y, z, args...)
14+
y0 = zero(y)
1415

1516
suggested_backend = AutoForwardDiff()
1617
A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend)
@@ -21,7 +22,7 @@ function (implicit::ImplicitFunction)(
2122
end
2223
dC = map(B, dX)
2324
dY = map(dC) do dₖc
24-
dₖy = implicit.linear_solver(A, -dₖc)
25+
dₖy = implicit.linear_solver(A, -dₖc, y0)
2526
return dₖy
2627
end
2728

src/settings.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Specify that linear systems `Ax = b` should be solved with a direct method.
1212
"""
1313
struct DirectLinearSolver end
1414

15-
function (solver::DirectLinearSolver)(A, b::AbstractVector)
15+
function (solver::DirectLinearSolver)(A, b::AbstractVector, x0::AbstractVector)
1616
return A \ b
1717
end
1818

@@ -33,8 +33,8 @@ struct IterativeLinearSolver{K}
3333
end
3434
end
3535

36-
function (solver::IterativeLinearSolver)(A, b)
37-
sol, info = linsolve(A, b; solver.kwargs...)
36+
function (solver::IterativeLinearSolver)(A, b, x0)
37+
sol, info = linsolve(A, b, x0; solver.kwargs...)
3838
@assert info.converged == 1
3939
return sol
4040
end

0 commit comments

Comments
 (0)