Skip to content

Commit f8ca1fd

Browse files
committed
Add rrules matrix functions
1 parent ccb7c93 commit f8ca1fd

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,21 @@ function ChainRulesCore.rrule(::typeof(imag), a::AbstractTensorMap)
106106
end
107107
return a_imag, imag_pullback
108108
end
109+
110+
# define rrules for matrix functions for DiagonalTensorMap, since they access data directly.
111+
for f in
112+
(:exp, :cos, :sin, :tan, :cot, :cosh, :sinh, :tanh, :coth, :atan, :acot, :asinh, :sqrt,
113+
:log, :asin, :acos, :acosh, :atanh, :acoth)
114+
f_pullback = Symbol(f, :_pullback)
115+
@eval function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof($f),
116+
t::DiagonalTensorMap)
117+
P = ProjectTo(t) # unsure if this is necessary, should already be in pullback
118+
d, pullback = rrule_via_ad(cfg, broadcast, $f, t.data)
119+
function $f_pullback(Δd_)
120+
Δd = P(unthunk(Δd_))
121+
_, _, ∂data = pullback(Δd.data)
122+
return NoTangent(), DiagonalTensorMap(∂data, t.domain)
123+
end
124+
return DiagonalTensorMap(d, t.domain), $f_pullback
125+
end
126+
end

0 commit comments

Comments
 (0)