Skip to content

Commit 33bca87

Browse files
ebelnikolalkdvosJutho
authored
DiagonalTensorMap constructor rrule (#208)
* adds `ProjectTo` for `DiagonalTensorMap` * adds an `rrule` for `DiagonalTensorMap` constructor * Corrects bug in the DiagonalTensorMap rrule, adds tests for the new code, adds a proper generator of random tangents for DiagonalTensorMap * @test missing in the constructor test added... * wait no, @test did not belong there * Update ext/TensorKitChainRulesCoreExt/utility.jl Co-authored-by: Lukas Devos <[email protected]> * mixed type tests for ProjectTo * + rrule test on complex tensors. * correct data length for DiagonalTensor in tests * correct data length in DiagonalTensorMap for random tangents * Comment on the test failure * Jutho's corrections * Add `DiagonalTensorMap(::AbstractTensorMap)` * Specialize `to_vec(::DiagonalTensorMap)` * Add rrules matrix functions * Add tests AD of matrixfunctions * Remove duplicate methods * disable broken tests * Fix CI check * Adapt rrules for constructors and getproperty to include qdims * exchange sqrt and invsqrt in hope of fixing without thinking * Actually think to fix the problem * Simplify positive data generation * simplify CI detection * Fix bad merge * uncomment non-ad tests --------- Co-authored-by: Lukas Devos <[email protected]> Co-authored-by: Jutho <[email protected]>
1 parent 33f10bc commit 33bca87

File tree

7 files changed

+176
-3
lines changed

7 files changed

+176
-3
lines changed

dev/KrylovKit

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 8bccac88a9474b47ce49bac72ae19b3806ce129f

ext/TensorKitChainRulesCoreExt/constructors.jl

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,86 @@
44
@non_differentiable TensorKit.isometry(args...)
55
@non_differentiable TensorKit.unitary(args...)
66

7-
function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwargs...)
7+
# Reverse rule for constructing a `TensorMap` from a dense array `d`.
# The pullback converts the incoming tensor cotangent back to an `Array` and hands out
# `NoTangent()` for the constructor type and for each of the extra positional `args`.
function ChainRulesCore.rrule(::Type{TensorMap}, d::DenseArray, args...; kwargs...)
    t = TensorMap(d, args...; kwargs...)
    nargs = length(args)
    function TensorMap_pullback(Δt)
        ∂d = convert(Array, unthunk(Δt))
        return (NoTangent(), ∂d, ntuple(Returns(NoTangent()), nargs)...)
    end
    return t, TensorMap_pullback
end
1414

15+
# these are not the conversion to/from array, but actually take in data parameters
16+
# -- as a result, requires quantum dimensions to keep inner product the same:
17+
# ⟨Δdata, ∂data⟩ = ⟨Δtensor, ∂tensor⟩ = ∑_c d_c ⟨Δtensor_c, ∂tensor_c⟩
18+
# ⟹ Δdata = d_c Δtensor_c
19+
# Reverse rule for the data-vector constructor `TensorMap{T}(data, V)`.
# In the pullback every block of the incoming cotangent is scaled by `dim(c)` of its
# sector `c`; the data vector of the scaled tensor is then projected onto `data`.
function ChainRulesCore.rrule(::Type{TensorMap{T}}, data::DenseVector,
                              V::TensorMapSpace) where {T}
    project_data = ProjectTo(data)
    function TensorMap_pullback(Δt_)
        # scale a copy, so the cotangent object passed in is not modified
        Δt = copy(unthunk(Δt_))
        for (sector, blk) in blocks(Δt)
            scale!(blk, dim(sector))
        end
        return NoTangent(), project_data(Δt.data), NoTangent()
    end
    return TensorMap{T}(data, V), TensorMap_pullback
end
33+
34+
# Reverse rule for constructing a `DiagonalTensorMap` from its data vector.
# In the pullback every block of the incoming cotangent is scaled by `dim(c)` of its
# sector `c`; the data vector of the scaled tensor is then projected onto `data`.
# One `NoTangent()` is returned per extra positional argument, so the number of returned
# tangents always matches the number of primal arguments.
function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, data::DenseVector, args...;
                              kwargs...)
    D = DiagonalTensorMap(data, args...; kwargs...)
    P = ProjectTo(data)
    function DiagonalTensorMap_pullback(Δt_)
        # unclear if we're allowed to modify/take ownership of the input
        Δt = copy(unthunk(Δt_))
        for (c, b) in blocks(Δt)
            scale!(b, dim(c))
        end
        ∂data = P(Δt.data)
        return NoTangent(), ∂data, ntuple(_ -> NoTangent(), length(args))...
    end
    return D, DiagonalTensorMap_pullback
end
49+
50+
# Reverse rule for property access on a `TensorMap`.
#  * `:data`  -- the data cotangent is wrapped in a tensor of the same type and space as
#                `t`, and each of its blocks is scaled by `inv(dim(c))`.
#  * `:space` -- returned together with a constant pullback giving `ZeroTangent()` for `t`.
# Any other property throws an `ArgumentError`.
function ChainRulesCore.rrule(::typeof(Base.getproperty), t::TensorMap, prop::Symbol)
    if prop === :space
        return t.space, Returns((NoTangent(), ZeroTangent(), NoTangent()))
    end
    prop === :data || throw(ArgumentError("unknown property $prop"))

    function getdata_pullback(Δdata)
        # copy first: it is unclear whether the incoming cotangent may be overwritten
        ∂t = typeof(t)(copy(unthunk(Δdata)), t.space)
        for (sector, blk) in blocks(∂t)
            scale!(blk, inv(dim(sector)))
        end
        return NoTangent(), ∂t, NoTangent()
    end
    return t.data, getdata_pullback
end
67+
68+
# Reverse rule for property access on a `DiagonalTensorMap`.
#  * `:data`   -- the data cotangent is wrapped in a tensor of the same type and domain
#                 as `t`, and each of its blocks is scaled by `inv(dim(c))`.
#  * `:domain` -- returned together with a constant pullback giving `ZeroTangent()` for `t`.
# Any other property throws an `ArgumentError`.
function ChainRulesCore.rrule(::typeof(Base.getproperty), t::DiagonalTensorMap,
                              prop::Symbol)
    if prop === :domain
        return t.domain, Returns((NoTangent(), ZeroTangent(), NoTangent()))
    end
    prop === :data || throw(ArgumentError("unknown property $prop"))

    function getdata_pullback(Δdata)
        # copy first: it is unclear whether the incoming cotangent may be overwritten
        ∂t = typeof(t)(copy(unthunk(Δdata)), t.domain)
        for (sector, blk) in blocks(∂t)
            scale!(blk, inv(dim(sector)))
        end
        return NoTangent(), ∂t, NoTangent()
    end
    return t.data, getdata_pullback
end
86+
1587
function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
1688
copy_pullback(Δt) = NoTangent(), Δt
1789
return copy(t), copy_pullback

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ function ChainRulesCore.rrule(::typeof(imag), a::AbstractTensorMap)
113113
return a_imag, imag_pullback
114114
end
115115

116-
function ChainRulesCore.rrule(cfg::RuleConfig, ::typeof(exp), A::AbstractTensorMap)
116+
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(exp),
117+
A::AbstractTensorMap)
117118
domain(A) == codomain(A) ||
118119
error("Exponential of a tensor only exist when domain == codomain.")
119120
P_A = ProjectTo(A)
@@ -133,3 +134,21 @@ function ChainRulesCore.rrule(cfg::RuleConfig, ::typeof(exp), A::AbstractTensorM
133134
end
134135
return C, exp_pullback
135136
end
137+
138+
# Matrix functions of a `DiagonalTensorMap` work on its data vector directly, so each of
# them gets a dedicated rule: the primal and the pullback are obtained by differentiating
# `broadcast(f, t.data)` through the AD backend that `cfg` provides.
for f in (:exp, :cos, :sin, :tan, :cot, :cosh, :sinh, :tanh, :coth, :atan, :acot, :asinh,
          :sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
    f_pullback = Symbol(f, :_pullback)
    @eval function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof($f),
                                        t::DiagonalTensorMap)
        # possibly redundant: the incoming cotangent may already be projected
        project_t = ProjectTo(t)
        fdata, data_pullback = rrule_via_ad(cfg, broadcast, $f, t.data)
        function $f_pullback(Δt_)
            Δt = project_t(unthunk(Δt_))
            # third entry is the cotangent w.r.t. `t.data` (after `broadcast` and `f`)
            ∂data = data_pullback(Δt.data)[3]
            return NoTangent(), DiagonalTensorMap(∂data, t.domain)
        end
        return DiagonalTensorMap(fdata, t.domain), $f_pullback
    end
end

ext/TensorKitChainRulesCoreExt/utility.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,15 @@ function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{<:Any,S,N
3232
end
3333
return y
3434
end
35+
36+
# Project an arbitrary tensor map `x` onto a `DiagonalTensorMap{T,S,A}`.
# A value that already has exactly this type is returned as is. Otherwise `x` must map
# `V` to `V`, with `V = space(x, 1)`, or a `SpaceMismatch` is thrown. A new diagonal
# tensor is then filled block by block, projecting `block(x, c)` onto the matching
# block of the result.
function (::ProjectTo{DiagonalTensorMap{T,S,A}})(x::AbstractTensorMap) where {T,S,A}
    x isa DiagonalTensorMap{T,S,A} && return x
    V = space(x, 1)
    space(x) == (V ← V) || throw(SpaceMismatch())
    y = DiagonalTensorMap{T,S,A}(undef, V)
    for (c, b) in blocks(y)
        p = ProjectTo(b)
        b .= p(block(x, c))
    end
    return y
end

ext/TensorKitFiniteDifferencesExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ function FiniteDifferences.to_vec(t::AbstractTensorMap)
2323
end
2424
FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t))
2525

26+
# Vectorise a `DiagonalTensorMap` by going through its `TensorMap` conversion; the
# returned closure rebuilds the `TensorMap` from a vector and converts it back to a
# `DiagonalTensorMap`.
function FiniteDifferences.to_vec(t::DiagonalTensorMap)
    vec_t, tensormap_from_vec = to_vec(TensorMap(t))
    DiagonalTensorMap_from_vec(v) = DiagonalTensorMap(tensormap_from_vec(v))
    return vec_t, DiagonalTensorMap_from_vec
end
33+
2634
end
2735

2836
# TODO: Investigate why the approach below doesn't work

test/ad.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ChainRulesTestUtils
33
using FiniteDifferences: FiniteDifferences
44
using Random
55
using LinearAlgebra
6+
using Zygote
67

78
const _repartition = @static if isdefined(Base, :get_extension)
89
Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)._repartition
@@ -15,6 +16,10 @@ end
1516
# Random tangent for a generic tensor map: a tensor similar to `x` filled with normally
# distributed entries drawn from the supplied `rng`, so that seeded test runs are
# reproducible.
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap)
    return randn!(rng, similar(x))
end
19+
# Random tangent for a `DiagonalTensorMap`: a diagonal tensor on the same domain whose
# data vector has `reduceddim(V)` normally distributed entries drawn from the supplied
# `rng`, so that seeded test runs are reproducible.
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap)
    V = x.domain
    return DiagonalTensorMap(randn(rng, eltype(x), reduceddim(V)), V)
end
1823
ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent()
1924
function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap,
2025
expected::AbstractTensorMap, msg=""; kwargs...)
@@ -152,6 +157,46 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
152157
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
153158
fkwargs=(; tol=Inf))
154159
end
160+
161+
test_rrule(Base.getproperty, T1, :data)
162+
test_rrule(TensorMap{scalartype(T1)}, T1.data, T1.space)
163+
test_rrule(Base.getproperty, T2, :data)
164+
test_rrule(TensorMap{scalartype(T2)}, T2.data, T2.space)
165+
end
166+
167+
@timedtestset "Basic utility (DiagonalTensor)" begin
    for v in V
        # real diagonal tensors, their complex combination, and the dense counterparts
        rdim = reduceddim(v)
        D1 = DiagonalTensorMap(randn(rdim), v)
        D2 = DiagonalTensorMap(randn(rdim), v)
        D = D1 + im * D2
        T1 = TensorMap(D1)
        T2 = TensorMap(D2)
        T = T1 + im * T2

        project_real = ProjectTo(D1)
        project_complex = ProjectTo(D)

        # real -> real
        for x in (D1, T1)
            @test project_real(x) == D1
        end

        # complex -> complex
        for x in (D, T)
            @test project_complex(x) == D
        end

        # real -> complex
        for x in (D1, T1)
            @test project_complex(x) == D1 + 0 * im * D1
        end

        # complex -> real
        for x in (D, T)
            @test project_real(x) == D1
        end

        # constructor and data accessor rules, for real and complex data
        test_rrule(DiagonalTensorMap, D1.data, D1.domain)
        test_rrule(DiagonalTensorMap, D.data, D.domain)
        test_rrule(Base.getproperty, D, :data)
        test_rrule(Base.getproperty, D1, :data)
    end
end
155200
end
156201

157202
@timedtestset "Basic Linear Algebra with scalartype $T" for T in eltypes
@@ -196,6 +241,21 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
196241
test_rrule(LinearAlgebra.dot, A, B)
197242
end
198243

244+
@timedtestset "Matrix functions ($T)" for T in eltypes
    for f in (sqrt, exp)
        check_inferred = false # !(T <: Real) # not type-stable for real functions
        # dense square tensors on the first two spaces
        t1 = randn(T, V[1] ← V[1])
        t2 = randn(T, V[2] ← V[2])
        # diagonal input `d` and output tangent `d2`; for a real `sqrt` the data is drawn
        # with `randexp!` instead of `randn!`
        d = DiagonalTensorMap{T}(undef, V[1])
        (T <: Real && f === sqrt) ? randexp!(d.data) : randn!(d.data)
        d2 = DiagonalTensorMap{T}(undef, V[1])
        (T <: Real && f === sqrt) ? randexp!(d2.data) : randn!(d2.data)
        test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred)
        test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred)
        test_rrule(f, d; check_inferred, output_tangent=d2)
    end
end
258+
199259
symmetricbraiding &&
200260
@timedtestset "TensorOperations with scalartype $T" for T in eltypes
201261
atol = precision(T)

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ include("spaces.jl")
6060
include("tensors.jl")
6161
include("diagonal.jl")
6262
include("planar.jl")
63-
if !(Sys.isapple()) # TODO: remove once we know why this is so slow on macOS
63+
# TODO: remove once we know why AD is slow on macOS CI
64+
if !(Sys.isapple() && get(ENV, "CI", "false") == "true")
6465
include("ad.jl")
6566
end
6667
include("bugfixes.jl")

0 commit comments

Comments
 (0)