Skip to content

Commit 06c09a3

Browse files
committed
Finish part of the implementation
1 parent 7e61692 commit 06c09a3

File tree

1 file changed

+65
-9
lines changed

1 file changed

+65
-9
lines changed

src/adjoint.jl

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.
2-
# TODO: Document the options in LinearSolveAdjoint
32

43
@doc doc"""
54
LinearSolveAdjoint(; linsolve = nothing)
@@ -29,19 +28,76 @@ specific structure distinct from ``A`` then passing in a `linsolve` will be more
2928
linsolve::L = nothing
3029
end
3130

32-
CRC.@non_differentiable SciMLBase.init(::LinearProblem, ::Any...)
31+
function CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem,
32+
alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
33+
cache = init(prob, alg, args...; kwargs...)
34+
function ∇init(∂cache)
35+
∂∅ = NoTangent()
36+
∂p = prob.p isa SciMLBase.NullParameters ? prob.p : ProjectTo(prob.p)(∂cache.p)
37+
∂prob = LinearProblem(∂cache.A, ∂cache.b, ∂p)
38+
return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
39+
end
40+
return cache, ∇init
41+
end
3342

34-
function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache)
35-
sensealg = cache.sensealg
43+
function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
44+
kwargs...)
45+
(; A, b, sensealg) = cache
3646

37-
# Decide if we need to cache the
47+
# Decide if we need to cache `A` and `b` for the reverse pass
48+
if sensealg.linsolve === nothing
49+
# We can reuse the factorization so no copy is needed
50+
# Krylov Methods don't modify `A`, so it's safe to just reuse it
51+
# No Copy is needed even for the default case
52+
if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
53+
alg isa DefaultLinearSolver)
54+
A_ = cache.alias_A ? deepcopy(A) : A
55+
end
56+
else
57+
error("Not Implemented Yet!!!")
58+
end
59+
60+
# Forward Solve
61+
sol = solve!(cache, alg, args...; kwargs...)
3862

39-
sol = solve!(cache)
4063
function ∇solve!(∂sol)
41-
@assert !cache.isfresh "`cache.A` has been updated between the forward and the reverse pass. This is not supported."
64+
@assert !cache.isfresh "`cache.A` has been updated between the forward and the \
65+
reverse pass. This is not supported."
66+
∂u = ∂sol.u
67+
if sensealg.linsolve === nothing
68+
λ = if cache.cacheval isa Factorization
69+
cache.cacheval' \ ∂u
70+
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
71+
first(cache.cacheval)' \ ∂u
72+
elseif alg isa AbstractKrylovSubspaceMethod
73+
invprob = LinearProblem(transpose(cache.A), ∂u)
74+
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
75+
elseif alg isa DefaultLinearSolver
76+
LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
77+
else
78+
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
79+
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
80+
end
81+
else
82+
error("Not Implemented Yet!!!")
83+
end
84+
85+
∂A = -λ * transpose(sol.u)
86+
∂b = λ
87+
∂∅ = NoTangent()
4288

43-
∂cache = NoTangent()
44-
return NoTangent(), ∂cache
89+
∂cache = LinearCache(∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache.isfresh, ∂∅, ∂∅, cache.abstol,
90+
cache.reltol, cache.maxiters, cache.verbose, cache.assumptions, cache.sensealg)
91+
92+
return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
4593
end
4694
return sol, ∇solve!
4795
end
96+
97+
function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
98+
prob = LinearProblem(A, b, p)
99+
function ∇prob(∂prob)
100+
return NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p
101+
end
102+
return prob, ∇prob
103+
end

0 commit comments

Comments
 (0)