diff --git a/src/adjoint.jl b/src/adjoint.jl
index f5034a736..4781602fb 100644
--- a/src/adjoint.jl
+++ b/src/adjoint.jl
@@ -7,8 +7,8 @@ Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b``
 
 ```math
 \begin{align}
-A^T \lambda &= \partial x \\
-\partial A &= -\lambda x^T \\
+A' \lambda &= \partial x \\
+\partial A &= -\lambda x' \\
 \partial b &= \lambda
 \end{align}
 ```
@@ -20,7 +20,7 @@ For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoi
 Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
 forward solve (this is done by keeping the linsolve as `missing`). For example, if the
 forward solve was performed via a Factorization, then we can reuse the factorization for the
-adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
+adjoint solve. However, for specific structured matrices if ``A'`` is known to have a
 specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
 """
 @kwdef struct LinearSolveAdjoint{L} <:
@@ -62,21 +62,21 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
         elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
             first(cache.cacheval)' \ ∂u
         elseif alg isa AbstractKrylovSubspaceMethod
-            invprob = LinearProblem(transpose(cache.A), ∂u)
+            invprob = LinearProblem(adjoint(cache.A), ∂u)
             solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
         elseif alg isa DefaultLinearSolver
             LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
         else
-            invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
+            invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
             solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
         end
     else
-        invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
+        invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
         λ = solve(
             invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
     end
 
-    tu = transpose(sol.u)
+    tu = adjoint(sol.u)
     ∂A = BroadcastArray(@~ .-(λ .* tu))
     ∂b = λ
     ∂prob = LinearProblem(∂A, ∂b, ∂∅)
diff --git a/test/adjoint.jl b/test/adjoint.jl
index e1c18ec0b..b31c447e8 100644
--- a/test/adjoint.jl
+++ b/test/adjoint.jl
@@ -44,6 +44,10 @@ db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
 @test dA ≈ dA2
 @test db1 ≈ db12
 
+# Test complex numbers
+A = rand(n, n) + 1im*rand(n, n);
+b1 = rand(n) + 1im*rand(n);
+
 function f3(A, b1, b2; alg = KrylovJL_GMRES())
     prob = LinearProblem(A, b1)
     sol1 = solve(prob, alg)
@@ -66,6 +70,9 @@ db22 = FiniteDiff.finite_difference_gradient(
 @test db1 ≈ db12
 @test db2 ≈ db22
 
+A = rand(n, n);
+b1 = rand(n);
+
 function f4(A, b1, b2; alg = LUFactorization())
     prob = LinearProblem(A, b1)
     sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR()))
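
Reviewer note (not part of the patch): the whole change hinges on `transpose` and `adjoint` coinciding for real matrices but differing for complex ones, since `adjoint(A) == conj(transpose(A))`. The sketch below is illustrative only; the `loss` function and the `n * I` shift (used solely to keep the example well conditioned) are assumptions of the sketch, not code from this PR.

```julia
# Illustrative sketch only (not part of this patch). Shows why the rrule must
# solve with `adjoint` rather than `transpose` once complex inputs are allowed.
using LinearAlgebra, LinearSolve, Zygote

n = 4
A = rand(ComplexF64, n, n) + n * I   # shift keeps A well conditioned (assumption)
b = rand(ComplexF64, n)
c = rand(ComplexF64, n)              # stand-in for the cotangent ∂u

# The two candidate adjoint systems have different solutions for complex A:
λ_adj = adjoint(A) \ c     # what the patched rrule solves: A' λ = ∂x
λ_tr = transpose(A) \ c    # what the old code solved
@show λ_adj ≈ λ_tr         # false for generic complex A

# Differentiating a real-valued loss through a complex solve exercises the
# rrule on `solve` (mirrors the new tests added in test/adjoint.jl):
loss(A, b) = sum(abs2, solve(LinearProblem(A, b), LUFactorization()).u)
dA, db = Zygote.gradient(loss, A, b)
```

For real inputs `adjoint` reduces to `transpose`, so the pre-existing real-valued tests are unaffected; the reset of `A` and `b1` before `f4` keeps those later tests on real inputs.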