Skip to content

Commit 7e61692

Browse files
committed
Setup to handle adjoints
1 parent 3b4f4ed commit 7e61692

File tree

4 files changed

+63
-4
lines changed

4 files changed

+63
-4
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.21.1"
4+
version = "2.22.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
910
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1011
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
@@ -16,6 +17,7 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1617
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1819
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
20+
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1921
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2022
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2123
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"

src/LinearSolve.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ PrecompileTools.@recompile_invalidations begin
2424
using DocStringExtensions
2525
using EnumX
2626
using Requires
27+
using Markdown
28+
using ChainRulesCore
2729
import InteractiveUtils
2830

2931
import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix
@@ -43,6 +45,8 @@ PrecompileTools.@recompile_invalidations begin
4345
import Preferences
4446
end
4547

48+
const CRC = ChainRulesCore
49+
4650
if Preferences.@load_preference("LoadMKL_JLL", true)
4751
using MKL_jll
4852
const usemkl = MKL_jll.is_available()
@@ -124,6 +128,7 @@ include("solve_function.jl")
124128
include("default.jl")
125129
include("init.jl")
126130
include("extension_algs.jl")
131+
include("adjoint.jl")
127132
include("deprecated.jl")
128133

129134
@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
@@ -236,4 +241,6 @@ export MetalLUFactorization
236241

237242
export OperatorAssumptions, OperatorCondition
238243

244+
export LinearSolveAdjoint
245+
239246
end

src/adjoint.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# TODO: Preconditioners? For the adjoint solve, should Pl be transposed and used as Pr (and, similarly, Pr transposed and used as Pl)?
2+
# TODO: Document the options in LinearSolveAdjoint
3+
4+
@doc doc"""
    LinearSolveAdjoint(; linsolve = nothing)

Given a linear problem ``A x = b``, compute the sensitivities for ``A`` and ``b`` as:

```math
\begin{align}
A^T \lambda &= \partial x \\
\partial A &= -\lambda x^T \\
\partial b &= \lambda
\end{align}
```

For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf).

## Keyword Arguments

  - `linsolve`: the linear solver to use for the adjoint system
    ``A^T \lambda = \partial x``. Defaults to `nothing`, which means the same linear
    solver as the forward solve is used.

## Choice of Linear Solver

Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
forward solve (this is done by keeping the linsolve as `nothing`). For example, if the
forward solve was performed via a Factorization, then we can reuse the factorization for the
adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
"""
# `Base.@kwdef` (qualified) is used because the unqualified `@kwdef` is only exported
# from Base on Julia >= 1.9.
Base.@kwdef struct LinearSolveAdjoint{L} <:
                   SciMLBase.AbstractSensitivityAlgorithm{0, false, :central}
    linsolve::L = nothing
end
31+
32+
# Mark `init` on a `LinearProblem` (with any trailing positional arguments) as
# non-differentiable for ChainRulesCore, so no tangents are propagated through the
# construction of the cache.
CRC.@non_differentiable SciMLBase.init(::LinearProblem, ::Any...)
33+
34+
# Reverse-mode rule for `solve!` on a `LinearCache`.
#
# Runs the forward solve and returns `(sol, pullback)`. The pullback checks that the
# cache has not been marked fresh again since the forward pass and then returns
# `NoTangent()` both for the function itself and for the cache.
#
# NOTE(review): the pullback does not yet propagate any sensitivities for `A` or `b`;
# the `NoTangent()` returned for the cache looks like a placeholder until the adjoint
# solve is implemented.
function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache)
    # TODO: Decide, based on `cache.sensealg`, whether anything from the forward solve
    # needs to be cached for the reverse pass.

    sol = solve!(cache)

    function ∇solve!(∂sol)
        # Throw explicitly rather than via `@assert`, which may be disabled at some
        # optimization levels. The exception type and message are unchanged.
        if cache.isfresh
            throw(AssertionError("`cache.A` has been updated between the forward and the reverse pass. This is not supported."))
        end

        ∂cache = NoTangent()
        return NoTangent(), ∂cache
    end

    return sol, ∇solve!
end

src/common.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ end
6565
__issquare(assump::OperatorAssumptions) = assump.issq
6666
__conditioning(assump::OperatorAssumptions) = assump.condition
6767

68-
mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
68+
mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S}
6969
A::TA
7070
b::Tb
7171
u::Tu
@@ -80,6 +80,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
8080
maxiters::Int
8181
verbose::Bool
8282
assumptions::OperatorAssumptions{issq}
83+
sensealg::S
8384
end
8485

8586
function Base.setproperty!(cache::LinearCache, name::Symbol, x)
@@ -137,6 +138,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
137138
Pl = IdentityOperator(size(prob.A)[1]),
138139
Pr = IdentityOperator(size(prob.A)[2]),
139140
assumptions = OperatorAssumptions(issquare(prob.A)),
141+
sensealg = LinearSolveAdjoint(),
140142
kwargs...)
141143
@unpack A, b, u0, p = prob
142144

@@ -170,8 +172,9 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
170172
Tc = typeof(cacheval)
171173

172174
cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
173-
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_,
174-
p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions)
175+
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
176+
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
177+
maxiters, verbose, assumptions, sensealg)
175178
return cache
176179
end
177180

0 commit comments

Comments
 (0)