Skip to content

Commit 7b090b4

Browse files
Merge pull request #449 from SciML/ap/adjoint
Adjoints for Linear Solve
2 parents a206054 + e937e67 commit 7b090b4

File tree

9 files changed

+261
-21
lines changed

9 files changed

+261
-21
lines changed

Project.toml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.24.0"
4+
version = "2.25.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
910
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1011
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
@@ -16,6 +17,7 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1617
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1819
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
20+
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1921
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2022
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2123
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
@@ -64,6 +66,7 @@ ArrayInterface = "7.7"
6466
BandedMatrices = "1.5"
6567
BlockDiagonals = "0.1.42"
6668
CUDA = "5"
69+
ChainRulesCore = "1.22"
6770
ConcreteStructs = "0.2.3"
6871
DocStringExtensions = "0.9.3"
6972
EnumX = "1.0.4"
@@ -85,6 +88,7 @@ KrylovKit = "0.6"
8588
Libdl = "1.10"
8689
LinearAlgebra = "1.10"
8790
MPI = "0.20"
91+
Markdown = "1.10"
8892
Metal = "0.5"
8993
MultiFloats = "1"
9094
Pardiso = "0.5"
@@ -96,7 +100,7 @@ RecursiveArrayTools = "3.8"
96100
RecursiveFactorization = "0.2.14"
97101
Reexport = "1"
98102
SafeTestsets = "0.1"
99-
SciMLBase = "2.23.0"
103+
SciMLBase = "2.26.3"
100104
SciMLOperators = "0.3.7"
101105
Setfield = "1"
102106
SparseArrays = "1.10"
@@ -106,6 +110,7 @@ StaticArrays = "1.5"
106110
StaticArraysCore = "1.4.2"
107111
Test = "1"
108112
UnPack = "1"
113+
Zygote = "0.6.69"
109114
julia = "1.10"
110115

111116
[extras]
@@ -133,6 +138,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
133138
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
134139
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
135140
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
141+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
136142

137143
[targets]
138-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs"]
144+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]

ext/LinearSolveHYPREExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using HYPRE.LibHYPRE: HYPRE_Complex
55
using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
66
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
77
OperatorAssumptions, default_tol, init_cacheval, __issquare,
8-
__conditioning
8+
__conditioning, LinearSolveAdjoint
99
using SciMLBase: LinearProblem, SciMLBase
1010
using UnPack: @unpack
1111
using Setfield: @set!
@@ -68,6 +68,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
6868
Pl = LinearAlgebra.I,
6969
Pr = LinearAlgebra.I,
7070
assumptions = OperatorAssumptions(),
71+
sensealg = LinearSolveAdjoint(),
7172
kwargs...)
7273
@unpack A, b, u0, p = prob
7374

@@ -89,10 +90,9 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
8990
cache = LinearCache{
9091
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
9192
typeof(Pl), typeof(Pr), typeof(reltol),
92-
typeof(__issquare(assumptions))
93+
typeof(__issquare(assumptions)), typeof(sensealg)
9394
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
94-
maxiters,
95-
verbose, assumptions)
95+
maxiters, verbose, assumptions, sensealg)
9696
return cache
9797
end
9898

src/LinearSolve.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ PrecompileTools.@recompile_invalidations begin
2323
using FastLapackInterface
2424
using DocStringExtensions
2525
using EnumX
26+
using Markdown
27+
using ChainRulesCore
2628
import InteractiveUtils
2729

2830
import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix
@@ -42,6 +44,8 @@ PrecompileTools.@recompile_invalidations begin
4244
import Preferences
4345
end
4446

47+
const CRC = ChainRulesCore
48+
4549
if Preferences.@load_preference("LoadMKL_JLL", true)
4650
using MKL_jll
4751
const usemkl = MKL_jll.is_available()
@@ -125,6 +129,7 @@ include("solve_function.jl")
125129
include("default.jl")
126130
include("init.jl")
127131
include("extension_algs.jl")
132+
include("adjoint.jl")
128133
include("deprecated.jl")
129134

130135
@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
@@ -240,4 +245,6 @@ export MetalLUFactorization
240245

241246
export OperatorAssumptions, OperatorCondition
242247

248+
export LinearSolveAdjoint
249+
243250
end

src/adjoint.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.
2+
3+
@doc doc"""
4+
LinearSolveAdjoint(; linsolve = missing)
5+
6+
Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`` as:
7+
8+
```math
9+
\begin{align}
10+
A^T \lambda &= \partial x \\
11+
\partial A &= -\lambda x^T \\
12+
\partial b &= \lambda
13+
\end{align}
14+
```
15+
16+
For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf).
17+
18+
## Choice of Linear Solver
19+
20+
Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
21+
forward solve (this is done by keeping the linsolve as `missing`). For example, if the
22+
forward solve was performed via a Factorization, then we can reuse the factorization for the
23+
adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
24+
specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
25+
"""
26+
@kwdef struct LinearSolveAdjoint{L} <:
27+
SciMLBase.AbstractSensitivityAlgorithm{0, false, :central}
28+
linsolve::L = missing
29+
end
30+
31+
function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
32+
alg::SciMLLinearSolveAlgorithm, args...; alias_A = default_alias_A(
33+
alg, prob.A, prob.b), kwargs...)
34+
# sol = solve(prob, alg, args...; kwargs...)
35+
cache = init(prob, alg, args...; kwargs...)
36+
(; A, sensealg) = cache
37+
38+
@assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
39+
40+
# Decide if we need to cache `A` and `b` for the reverse pass
41+
if sensealg.linsolve === missing
42+
# We can reuse the factorization so no copy is needed
43+
# Krylov Methods don't modify `A`, so it's safe to just reuse it
44+
# No Copy is needed even for the default case
45+
if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
46+
alg isa DefaultLinearSolver)
47+
A_ = alias_A ? deepcopy(A) : A
48+
end
49+
else
50+
A_ = deepcopy(A)
51+
end
52+
53+
sol = solve!(cache)
54+
55+
function ∇linear_solve(∂sol)
56+
∂∅ = NoTangent()
57+
58+
∂u = ∂sol.u
59+
if sensealg.linsolve === missing
60+
λ = if cache.cacheval isa Factorization
61+
cache.cacheval' \ ∂u
62+
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
63+
first(cache.cacheval)' \ ∂u
64+
elseif alg isa AbstractKrylovSubspaceMethod
65+
invprob = LinearProblem(transpose(cache.A), ∂u)
66+
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
67+
elseif alg isa DefaultLinearSolver
68+
LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
69+
else
70+
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
71+
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
72+
end
73+
else
74+
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
75+
λ = solve(
76+
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
77+
end
78+
79+
∂A = -λ * transpose(sol.u)
80+
∂b = λ
81+
∂prob = LinearProblem(∂A, ∂b, ∂∅)
82+
83+
return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
84+
end
85+
86+
return sol, ∇linear_solve
87+
end
88+
89+
function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
90+
prob = LinearProblem(A, b, p)
91+
∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
92+
return prob, ∇prob
93+
end

src/common.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ end
6565
__issquare(assump::OperatorAssumptions) = assump.issq
6666
__conditioning(assump::OperatorAssumptions) = assump.condition
6767

68-
mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
68+
mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S}
6969
A::TA
7070
b::Tb
7171
u::Tu
@@ -80,6 +80,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
8080
maxiters::Int
8181
verbose::Bool
8282
assumptions::OperatorAssumptions{issq}
83+
sensealg::S
8384
end
8485

8586
function Base.setproperty!(cache::LinearCache, name::Symbol, x)
@@ -138,6 +139,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
138139
Pl = IdentityOperator(size(prob.A)[1]),
139140
Pr = IdentityOperator(size(prob.A)[2]),
140141
assumptions = OperatorAssumptions(issquare(prob.A)),
142+
sensealg = LinearSolveAdjoint(),
141143
kwargs...)
142144
@unpack A, b, u0, p = prob
143145

@@ -171,17 +173,22 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
171173
Tc = typeof(cacheval)
172174

173175
cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
174-
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_,
175-
p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions)
176+
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
177+
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
178+
maxiters, verbose, assumptions, sensealg)
176179
return cache
177180
end
178181

179182
function SciMLBase.solve(prob::LinearProblem, args...; kwargs...)
180-
solve!(init(prob, nothing, args...; kwargs...))
183+
return solve(prob, nothing, args...; kwargs...)
181184
end
182185

183-
function SciMLBase.solve(prob::LinearProblem,
184-
alg::Union{SciMLLinearSolveAlgorithm, Nothing},
186+
function SciMLBase.solve(prob::LinearProblem, ::Nothing, args...;
187+
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
188+
return solve(prob, defaultalg(prob.A, prob.b, assump), args...; kwargs...)
189+
end
190+
191+
function SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
185192
args...; kwargs...)
186193
solve!(init(prob, alg, args...; kwargs...))
187194
end

src/factorization.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -779,26 +779,30 @@ function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs.
779779
cacheval.colptr &&
780780
SparseArrays.decrement(SparseArrays.getrowval(A)) ==
781781
cacheval.rowval)
782-
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
783-
nonzeros(A)), check=false)
782+
fact = lu(
783+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
784+
nonzeros(A)),
785+
check = false)
784786
else
785787
fact = lu!(cacheval,
786788
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
787-
nonzeros(A)), check=false)
789+
nonzeros(A)), check = false)
788790
end
789791
else
790-
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), check=false)
792+
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
793+
check = false)
791794
end
792795
cache.cacheval = fact
793796
cache.isfresh = false
794797
end
795798

796-
F = @get_cacheval(cache, :UMFPACKFactorization)
799+
F = @get_cacheval(cache, :UMFPACKFactorization)
797800
if F.status == SparseArrays.UMFPACK.UMFPACK_OK
798801
y = ldiv!(cache.u, F, cache.b)
799802
SciMLBase.build_linear_solution(alg, y, nothing, cache)
800803
else
801-
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache; retcode=ReturnCode.Infeasible)
804+
SciMLBase.build_linear_solution(
805+
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
802806
end
803807
end
804808

0 commit comments

Comments
 (0)