diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index 67dc5a980..fe779bd56 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -97,6 +97,14 @@ function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2) return n, norm_pullback end +function ChainRulesCore.rrule(::typeof(inv), A::AbstractTensorMap) + Ainv = inv(A) + inv_pullback = let Ainv = Ainv + inv_pullback(ΔAinv) = NoTangent(), -Ainv' * unthunk(ΔAinv) * Ainv' + end + return Ainv, inv_pullback +end + function ChainRulesCore.rrule(::typeof(real), a::AbstractTensorMap) a_real = real(a) real_pullback(Δa) = NoTangent(), eltype(a) <: Real ? Δa : complex(unthunk(Δa)) diff --git a/test/ad.jl b/test/ad.jl index 496ff9e45..dd1f9823a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -234,6 +234,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), E = randn(T, ⊗(V[1:i]...) ← ⊗(V[1:i]...)) test_rrule(LinearAlgebra.tr, E) test_rrule(exp, E; check_inferred=false) + test_rrule(inv, E) end A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])