|
1 | 1 | # TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.
|
2 |
| -# TODO: Document the options in LinearSolveAdjoint |
3 | 2 |
|
4 | 3 | @doc doc"""
|
5 | 4 | LinearSolveAdjoint(; linsolve = nothing)
|
@@ -29,19 +28,76 @@ specific structure distinct from ``A`` then passing in a `linsolve` will be more
|
29 | 28 | linsolve::L = nothing
|
30 | 29 | end
|
31 | 30 |
|
32 |
| -CRC.@non_differentiable SciMLBase.init(::LinearProblem, ::Any...) |
| 31 | +function CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem, |
| 32 | + alg::SciMLLinearSolveAlgorithm, args...; kwargs...) |
| 33 | + cache = init(prob, alg, args...; kwargs...) |
| 34 | + function ∇init(∂cache) |
| 35 | + ∂∅ = NoTangent() |
| 36 | + ∂p = prob.p isa SciMLBase.NullParameters ? prob.p : ProjectTo(prob.p)(∂cache.p) |
| 37 | + ∂prob = LinearProblem(∂cache.A, ∂cache.b, ∂p) |
| 38 | + return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...) |
| 39 | + end |
| 40 | + return cache, ∇init |
| 41 | +end |
33 | 42 |
|
34 |
| -function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache) |
35 |
| - sensealg = cache.sensealg |
| 43 | +function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...; |
| 44 | + kwargs...) |
| 45 | + (; A, b, sensealg) = cache |
36 | 46 |
|
37 |
| - # Decide if we need to cache the |
| 47 | + # Decide if we need to cache `A` and `b` for the reverse pass |
| 48 | + if sensealg.linsolve === nothing |
| 49 | + # We can reuse the factorization so no copy is needed |
| 50 | + # Krylov Methods don't modify `A`, so it's safe to just reuse it |
| 51 | + # No Copy is needed even for the default case |
| 52 | + if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod || |
| 53 | + alg isa DefaultLinearSolver) |
| 54 | + A_ = cache.alias_A ? deepcopy(A) : A |
| 55 | + end |
| 56 | + else |
| 57 | + error("Not Implemented Yet!!!") |
| 58 | + end |
| 59 | + |
| 60 | + # Forward Solve |
| 61 | + sol = solve!(cache, alg, args...; kwargs...) |
38 | 62 |
|
39 |
| - sol = solve!(cache) |
40 | 63 | function ∇solve!(∂sol)
|
41 |
| - @assert !cache.isfresh "`cache.A` has been updated between the forward and the reverse pass. This is not supported." |
| 64 | + @assert !cache.isfresh "`cache.A` has been updated between the forward and the \ |
| 65 | + reverse pass. This is not supported." |
| 66 | + ∂u = ∂sol.u |
| 67 | + if sensealg.linsolve === nothing |
| 68 | + λ = if cache.cacheval isa Factorization |
| 69 | + cache.cacheval' \ ∂u |
| 70 | + elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization |
| 71 | + first(cache.cacheval)' \ ∂u |
| 72 | + elseif alg isa AbstractKrylovSubspaceMethod |
| 73 | + invprob = LinearProblem(transpose(cache.A), ∂u) |
| 74 | + solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u |
| 75 | + elseif alg isa DefaultLinearSolver |
| 76 | + LinearSolve.defaultalg_adjoint_eval(cache, ∂u) |
| 77 | + else |
| 78 | + invprob = LinearProblem(transpose(A_), ∂u) # We cached `A` |
| 79 | + solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u |
| 80 | + end |
| 81 | + else |
| 82 | + error("Not Implemented Yet!!!") |
| 83 | + end |
| 84 | + |
| 85 | + ∂A = -λ * transpose(sol.u) |
| 86 | + ∂b = λ |
| 87 | + ∂∅ = NoTangent() |
42 | 88 |
|
43 |
| - ∂cache = NoTangent() |
44 |
| - return NoTangent(), ∂cache |
| 89 | + ∂cache = LinearCache(∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache.isfresh, ∂∅, ∂∅, cache.abstol, |
| 90 | + cache.reltol, cache.maxiters, cache.verbose, cache.assumptions, cache.sensealg) |
| 91 | + |
| 92 | + return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...) |
45 | 93 | end
|
46 | 94 | return sol, ∇solve!
|
47 | 95 | end
|
| 96 | + |
| 97 | +function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...) |
| 98 | + prob = LinearProblem(A, b, p) |
| 99 | + function ∇prob(∂prob) |
| 100 | + return NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p |
| 101 | + end |
| 102 | + return prob, ∇prob |
| 103 | +end |
0 commit comments