Skip to content

Commit c153903

Browse files
author
Avik Pal
committed
Allow special solver for adjoint
1 parent 7671369 commit c153903

File tree

5 files changed

+142
-39
lines changed

5 files changed

+142
-39
lines changed

src/LinearSolve.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ PrecompileTools.@recompile_invalidations begin
2323
using FastLapackInterface
2424
using DocStringExtensions
2525
using EnumX
26-
using Requires
2726
using Markdown
2827
using ChainRulesCore
2928
import InteractiveUtils

src/adjoint.jl

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.
22

33
@doc doc"""
4-
LinearSolveAdjoint(; linsolve = nothing)
4+
LinearSolveAdjoint(; linsolve = missing)
55
66
Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`` as:
77
@@ -18,53 +18,49 @@ For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoi
1818
## Choice of Linear Solver
1919
2020
Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
21-
forward solve (this is done by keeping the linsolve as `nothing`). For example, if the
21+
forward solve (this is done by keeping the linsolve as `missing`). For example, if the
2222
forward solve was performed via a Factorization, then we can reuse the factorization for the
2323
adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
2424
specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
2525
"""
2626
@kwdef struct LinearSolveAdjoint{L} <:
2727
SciMLBase.AbstractSensitivityAlgorithm{0, false, :central}
28-
linsolve::L = nothing
28+
linsolve::L = missing
2929
end
3030

31-
function CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem,
32-
alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
31+
function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
32+
alg::SciMLLinearSolveAlgorithm, args...; alias_A = default_alias_A(
33+
alg, prob.A, prob.b), kwargs...)
34+
# sol = solve(prob, alg, args...; kwargs...)
3335
cache = init(prob, alg, args...; kwargs...)
34-
function ∇init(∂cache)
35-
∂∅ = NoTangent()
36-
∂p = prob.p isa SciMLBase.NullParameters ? prob.p : ProjectTo(prob.p)(∂cache.p)
37-
∂prob = LinearProblem(∂cache.A, ∂cache.b, ∂p)
38-
return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
39-
end
40-
return cache, ∇init
41-
end
36+
(; A, sensealg) = cache
4237

43-
function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
44-
kwargs...)
45-
(; A, b, sensealg) = cache
38+
@assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
4639

4740
# Decide if we need to cache `A` and `b` for the reverse pass
48-
if sensealg.linsolve === nothing
41+
if sensealg.linsolve === missing
4942
# We can reuse the factorization so no copy is needed
5043
# Krylov Methods don't modify `A`, so it's safe to just reuse it
5144
# No Copy is needed even for the default case
5245
if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
5346
alg isa DefaultLinearSolver)
54-
A_ = cache.alias_A ? deepcopy(A) : A
47+
A_ = alias_A ? deepcopy(A) : A
5548
end
5649
else
57-
error("Not Implemented Yet!!!")
50+
if alg isa DefaultLinearSolver
51+
A_ = deepcopy(A)
52+
else
53+
A_ = alias_A ? deepcopy(A) : A
54+
end
5855
end
5956

60-
# Forward Solve
61-
sol = solve!(cache, alg, args...; kwargs...)
57+
sol = solve!(cache)
58+
59+
function ∇linear_solve(∂sol)
60+
∂∅ = NoTangent()
6261

63-
function ∇solve!(∂sol)
64-
@assert !cache.isfresh "`cache.A` has been updated between the forward and the \
65-
reverse pass. This is not supported."
6662
∂u = ∂sol.u
67-
if sensealg.linsolve === nothing
63+
if sensealg.linsolve === missing
6864
λ = if cache.cacheval isa Factorization
6965
cache.cacheval' \ ∂u
7066
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
@@ -79,25 +75,23 @@ function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
7975
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
8076
end
8177
else
82-
error("Not Implemented Yet!!!")
78+
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
79+
λ = solve(
80+
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
8381
end
8482

8583
∂A = -λ * transpose(sol.u)
8684
∂b = λ
87-
∂∅ = NoTangent()
88-
89-
∂cache = LinearCache(∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache.isfresh, ∂∅, ∂∅, cache.abstol,
90-
cache.reltol, cache.maxiters, cache.verbose, cache.assumptions, cache.sensealg)
85+
∂prob = LinearProblem(∂A, ∂b, ∂∅)
9186

92-
return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
87+
return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
9388
end
94-
return sol, ∇solve!
89+
90+
return sol, ∇linear_solve
9591
end
9692

9793
function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
9894
prob = LinearProblem(A, b, p)
99-
function ∇prob(∂prob)
100-
return NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p
101-
end
95+
∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
10296
return prob, ∇prob
10397
end

src/common.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,15 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
180180
end
181181

182182
function SciMLBase.solve(prob::LinearProblem, args...; kwargs...)
183-
solve!(init(prob, nothing, args...; kwargs...))
183+
return solve(prob, nothing, args...; kwargs...)
184184
end
185185

186-
function SciMLBase.solve(prob::LinearProblem,
187-
alg::Union{SciMLLinearSolveAlgorithm, Nothing},
186+
function SciMLBase.solve(prob::LinearProblem, ::Nothing, args...;
187+
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
188+
return solve(prob, defaultalg(prob.A, prob.b, assump), args...; kwargs...)
189+
end
190+
191+
function SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
188192
args...; kwargs...)
189193
solve!(init(prob, alg, args...; kwargs...))
190194
end

test/adjoint.jl

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
using Zygote, ForwardDiff
2+
using LinearSolve, LinearAlgebra, Test
3+
using FiniteDiff
4+
5+
n = 4
6+
A = rand(n, n);
7+
b1 = rand(n);
8+
9+
function f(A, b1; alg = LUFactorization())
10+
prob = LinearProblem(A, b1)
11+
12+
sol1 = solve(prob, alg)
13+
14+
s1 = sol1.u
15+
norm(s1)
16+
end
17+
18+
f(A, b1) # Uses BLAS
19+
20+
dA, db1 = Zygote.gradient(f, A, b1)
21+
22+
dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
23+
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
24+
25+
@test dA ≈ dA2
26+
@test db1 ≈ db12
27+
28+
A = rand(n, n);
29+
b1 = rand(n);
30+
31+
_ff = (x, y) -> f(x,
32+
y;
33+
alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization))
34+
_ff(copy(A), copy(b1))
35+
36+
dA, db1 = Zygote.gradient(_ff, copy(A), copy(b1))
37+
38+
dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
39+
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
40+
41+
@test dA ≈ dA2
42+
@test db1 ≈ db12
43+
44+
function f3(A, b1, b2; alg = KrylovJL_GMRES())
45+
prob = LinearProblem(A, b1)
46+
sol1 = solve(prob, alg)
47+
prob = LinearProblem(A, b2)
48+
sol2 = solve(prob, alg)
49+
norm(sol1.u .+ sol2.u)
50+
end
51+
52+
dA, db1, db2 = Zygote.gradient(f3, A, b1, b1)
53+
54+
#= Needs ForwardDiff rules
55+
dA2 = ForwardDiff.gradient(x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
56+
db12 = ForwardDiff.gradient(x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
57+
db22 = ForwardDiff.gradient(x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1))
58+
59+
@test dA ≈ dA2 atol=5e-5
60+
@test db1 ≈ db12
61+
@test db2 ≈ db22
62+
=#
63+
64+
A = rand(n, n);
65+
b1 = rand(n);
66+
for alg in (
67+
LUFactorization(),
68+
RFLUFactorization(),
69+
KrylovJL_GMRES()
70+
)
71+
@show alg
72+
function fb(b)
73+
prob = LinearProblem(A, b)
74+
75+
sol1 = solve(prob, alg)
76+
77+
sum(sol1.u)
78+
end
79+
fb(b1)
80+
81+
fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
82+
@show fd_jac
83+
84+
zyg_jac = Zygote.jacobian(fb, b1) |> first |> vec
85+
@show zyg_jac
86+
87+
@test zyg_jac ≈ fd_jac rtol=1e-4
88+
89+
function fA(A)
90+
prob = LinearProblem(A, b1)
91+
92+
sol1 = solve(prob, alg)
93+
94+
sum(sol1.u)
95+
end
96+
fA(A)
97+
98+
fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
99+
@show fd_jac
100+
101+
zyg_jac = Zygote.jacobian(fA, A) |> first |> vec
102+
@show zyg_jac
103+
104+
@test zyg_jac ≈ fd_jac rtol=1e-4
105+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ if GROUP == "All" || GROUP == "Core"
1515
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
1616
@time @safetestset "Default Alg Tests" include("default_algs.jl")
1717
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
18+
@time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
1819
@time @safetestset "Traits" include("traits.jl")
1920
@time @safetestset "BandedMatrices" include("banded.jl")
2021
@time @safetestset "Static Arrays" include("static_arrays.jl")

0 commit comments

Comments
 (0)