From 5df707f6b5edbfc878e65cf0b830e19d2d33b6ac Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 5 Feb 2025 07:28:19 -0500 Subject: [PATCH] Add rrule for matrix exponential --- ext/TensorKitChainRulesCoreExt/linalg.jl | 21 +++++++++++++++++++++ test/ad.jl | 1 + 2 files changed, 22 insertions(+) diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index c13694673..1b254700e 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -106,3 +106,24 @@ function ChainRulesCore.rrule(::typeof(imag), a::AbstractTensorMap) end return a_imag, imag_pullback end + +function ChainRulesCore.rrule(cfg::RuleConfig, ::typeof(exp), A::AbstractTensorMap) + domain(A) == codomain(A) || + error("Exponential of a tensor only exist when domain == codomain.") + P_A = ProjectTo(A) + C = similar(A) + pullbacks = map(blocks(A)) do (c, b) + expB, pullback = rrule_via_ad(cfg, exp, b) + copy!(block(C, c), expB) + return c => pullback + end + function exp_pullback(ΔC_) + ΔC = unthunk(ΔC_) + dA = similar(A) + for (c, pb) in pullbacks + copy!(block(dA, c), last(pb(block(ΔC, c)))) + end + return NoTangent(), P_A(dA) + end + return C, exp_pullback +end diff --git a/test/ad.jl b/test/ad.jl index a684c4f83..96d8f28af 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -173,6 +173,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), for i in 1:3 E = randn(T, ⊗(V[1:i]...) ← ⊗(V[1:i]...)) test_rrule(LinearAlgebra.tr, E) + test_rrule(exp, E; check_inferred=false) end A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])