Skip to content

Commit 40e74d7

Browse files
authored
Add and test rrules for real and imag (#183)
1 parent 6387f26 commit 40e74d7

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,19 @@ function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2)
9090
end
9191
return n, norm_pullback
9292
end
93+
94+
function ChainRulesCore.rrule(::typeof(real), a::AbstractTensorMap)
95+
a_real = real(a)
96+
real_pullback(Δa) = NoTangent(), eltype(a) <: Real ? Δa : complex(unthunk(Δa))
97+
return a_real, real_pullback
98+
end
99+
100+
function ChainRulesCore.rrule(::typeof(imag), a::AbstractTensorMap)
101+
a_imag = imag(a)
102+
function imag_pullback(Δa)
103+
Δa′ = unthunk(Δa)
104+
return NoTangent(),
105+
eltype(a) <: Real ? ZeroTangent() : complex(zerovector(Δa′), Δa′)
106+
end
107+
return a_imag, imag_pullback
108+
end

test/ad.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
146146
A = randn(T, V[1] V[2] V[3] V[4] V[5])
147147
B = randn(T, space(A))
148148

149+
test_rrule(real, A)
150+
test_rrule(imag, A)
151+
149152
test_rrule(+, A, B)
150153
test_rrule(-, A)
151154
test_rrule(-, A, B)

0 commit comments

Comments
 (0)