1
1
# TODO: Preconditioners? Should `Pl` be transposed and used as `Pr`, and similarly `Pr` transposed and used as `Pl`?
2
2
3
3
@doc doc"""
4
- LinearSolveAdjoint(; linsolve = nothing)
4
+ LinearSolveAdjoint(; linsolve = missing)
5
5
6
6
Given a linear problem ``A x = b``, computes the sensitivities for ``A`` and ``b`` as:
7
7
@@ -18,53 +18,49 @@ For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoi
18
18
## Choice of Linear Solver
19
19
20
20
Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
21
- forward solve (this is done by keeping `linsolve` as `nothing`). For example, if the
21
+ forward solve (this is done by keeping `linsolve` as `missing`). For example, if the
22
22
forward solve was performed via a Factorization, then we can reuse the factorization for the
23
23
adjoint solve. However, for specific structured matrices, if ``A^T`` is known to have a
24
24
specific structure distinct from ``A``, then passing in a `linsolve` will be more efficient.
25
25
"""
26
26
@kwdef struct LinearSolveAdjoint{L} < :
27
27
SciMLBase. AbstractSensitivityAlgorithm{0 , false , :central }
28
- linsolve:: L = nothing
28
+ linsolve:: L = missing
29
29
end
30
30
31
- function CRC. rrule (:: typeof (SciMLBase. init), prob:: LinearProblem ,
32
- alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... )
31
+ function CRC. rrule (:: typeof (SciMLBase. solve), prob:: LinearProblem ,
32
+ alg:: SciMLLinearSolveAlgorithm , args... ; alias_A = default_alias_A (
33
+ alg, prob. A, prob. b), kwargs... )
34
+ # sol = solve(prob, alg, args...; kwargs...)
33
35
cache = init (prob, alg, args... ; kwargs... )
34
- function ∇init (∂cache)
35
- ∂∅ = NoTangent ()
36
- ∂p = prob. p isa SciMLBase. NullParameters ? prob. p : ProjectTo (prob. p)(∂cache. p)
37
- ∂prob = LinearProblem (∂cache. A, ∂cache. b, ∂p)
38
- return (∂∅, ∂prob, ∂∅, ntuple (_ -> ∂∅, length (args))... )
39
- end
40
- return cache, ∇init
41
- end
36
+ (; A, sensealg) = cache
42
37
43
- function CRC. rrule (:: typeof (SciMLBase. solve!), cache:: LinearCache , alg, args... ;
44
- kwargs... )
45
- (; A, b, sensealg) = cache
38
+ @assert sensealg isa LinearSolveAdjoint " Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
46
39
47
40
# Decide if we need to cache a copy of `A` for the reverse pass
48
- if sensealg. linsolve === nothing
41
+ if sensealg. linsolve === missing
49
42
# We can reuse the factorization so no copy is needed
50
43
# Krylov Methods don't modify `A`, so it's safe to just reuse it
51
44
# No Copy is needed even for the default case
52
45
if ! (alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
53
46
alg isa DefaultLinearSolver)
54
- A_ = cache . alias_A ? deepcopy (A) : A
47
+ A_ = alias_A ? deepcopy (A) : A
55
48
end
56
49
else
57
- error (" Not Implemented Yet!!!" )
50
+ if alg isa DefaultLinearSolver
51
+ A_ = deepcopy (A)
52
+ else
53
+ A_ = alias_A ? deepcopy (A) : A
54
+ end
58
55
end
59
56
60
- # Forward Solve
61
- sol = solve! (cache, alg, args... ; kwargs... )
57
+ sol = solve! (cache)
58
+
59
+ function ∇linear_solve (∂sol)
60
+ ∂∅ = NoTangent ()
62
61
63
- function ∇solve! (∂sol)
64
- @assert ! cache. isfresh " `cache.A` has been updated between the forward and the \
65
- reverse pass. This is not supported."
66
62
∂u = ∂sol. u
67
- if sensealg. linsolve === nothing
63
+ if sensealg. linsolve === missing
68
64
λ = if cache. cacheval isa Factorization
69
65
cache. cacheval' \ ∂u
70
66
elseif cache. cacheval isa Tuple && cache. cacheval[1 ] isa Factorization
@@ -79,25 +75,23 @@ function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
79
75
solve (invprob, alg; cache. abstol, cache. reltol, cache. verbose). u
80
76
end
81
77
else
82
- error (" Not Implemented Yet!!!" )
78
+ invprob = LinearProblem (transpose (A_), ∂u) # We cached `A`
79
+ λ = solve (
80
+ invprob, sensealg. linsolve; cache. abstol, cache. reltol, cache. verbose). u
83
81
end
84
82
85
83
∂A = - λ * transpose (sol. u)
86
84
∂b = λ
87
- ∂∅ = NoTangent ()
88
-
89
- ∂cache = LinearCache (∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache. isfresh, ∂∅, ∂∅, cache. abstol,
90
- cache. reltol, cache. maxiters, cache. verbose, cache. assumptions, cache. sensealg)
85
+ ∂prob = LinearProblem (∂A, ∂b, ∂∅)
91
86
92
- return (∂∅, ∂cache , ∂∅, ntuple (_ -> ∂∅, length (args))... )
87
+ return (∂∅, ∂prob , ∂∅, ntuple (_ -> ∂∅, length (args))... )
93
88
end
94
- return sol, ∇solve!
89
+
90
+ return sol, ∇linear_solve
95
91
end
96
92
97
93
function CRC. rrule (:: Type{<:LinearProblem} , A, b, p; kwargs... )
98
94
prob = LinearProblem (A, b, p)
99
- function ∇prob (∂prob)
100
- return NoTangent (), ∂prob. A, ∂prob. b, ∂prob. p
101
- end
95
+ ∇prob (∂prob) = (NoTangent (), ∂prob. A, ∂prob. b, ∂prob. p)
102
96
return prob, ∇prob
103
97
end
0 commit comments