Skip to content

Commit f09e40c

Browse files
authored
Add rrule for matrix exponential (#214)
1 parent a0e3460 commit f09e40c

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,24 @@ function ChainRulesCore.rrule(::typeof(imag), a::AbstractTensorMap)
106106
end
107107
return a_imag, imag_pullback
108108
end
109+
110+
function ChainRulesCore.rrule(cfg::RuleConfig, ::typeof(exp), A::AbstractTensorMap)
111+
domain(A) == codomain(A) ||
112+
error("Exponential of a tensor only exist when domain == codomain.")
113+
P_A = ProjectTo(A)
114+
C = similar(A)
115+
pullbacks = map(blocks(A)) do (c, b)
116+
expB, pullback = rrule_via_ad(cfg, exp, b)
117+
copy!(block(C, c), expB)
118+
return c => pullback
119+
end
120+
function exp_pullback(ΔC_)
121+
ΔC = unthunk(ΔC_)
122+
dA = similar(A)
123+
for (c, pb) in pullbacks
124+
copy!(block(dA, c), last(pb(block(ΔC, c))))
125+
end
126+
return NoTangent(), P_A(dA)
127+
end
128+
return C, exp_pullback
129+
end

test/ad.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
173173
for i in 1:3
174174
E = randn(T, (V[1:i]...) (V[1:i]...))
175175
test_rrule(LinearAlgebra.tr, E)
176+
test_rrule(exp, E; check_inferred=false)
176177
end
177178

178179
A = randn(T, V[1] V[2] V[3] V[4] V[5])

0 commit comments

Comments
 (0)