Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ Given a Linear Problem ``A x = b``, computes the sensitivities for ``A`` and ``b``

```math
\begin{align}
A^T \lambda &= \partial x \\
\partial A &= -\lambda x^T \\
A' \lambda &= \partial x \\
\partial A &= -\lambda x' \\
\partial b &= \lambda
\end{align}
```
Expand All @@ -20,7 +20,7 @@ For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoi
Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
forward solve (this is done by keeping the linsolve as `missing`). For example, if the
forward solve was performed via a Factorization, then we can reuse the factorization for the
adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
adjoint solve. However, for specific structured matrices, if ``A'`` is known to have a
specific structure distinct from ``A``, then passing in a `linsolve` will be more efficient.
"""
@kwdef struct LinearSolveAdjoint{L} <:
Expand Down Expand Up @@ -62,21 +62,21 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
first(cache.cacheval)' \ ∂u
elseif alg isa AbstractKrylovSubspaceMethod
invprob = LinearProblem(transpose(cache.A), ∂u)
invprob = LinearProblem(adjoint(cache.A), ∂u)
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
elseif alg isa DefaultLinearSolver
LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
else
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
end
else
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
λ = solve(
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
end

tu = transpose(sol.u)
tu = adjoint(sol.u)
∂A = BroadcastArray(@~ .-(λ .* tu))
∂b = λ
∂prob = LinearProblem(∂A, ∂b, ∂∅)
Expand Down
7 changes: 7 additions & 0 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
@test dA ≈ dA2
@test db1 ≈ db12

# Test complex numbers
A = rand(n, n) + 1im*rand(n, n);
b1 = rand(n) + 1im*rand(n);

function f3(A, b1, b2; alg = KrylovJL_GMRES())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg)
Expand All @@ -66,6 +70,9 @@ db22 = FiniteDiff.finite_difference_gradient(
@test db1 ≈ db12
@test db2 ≈ db22

A = rand(n, n);
b1 = rand(n);

function f4(A, b1, b2; alg = LUFactorization())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR()))
Expand Down
Loading