Skip to content

Commit 8f8a280

Browse files
lkdvos and Jutho authored
rrule for copy_oftype, permutedcopy_oftype (#202)
* Add `rrule` for `copy_oftype` and `permutedcopy_oftype`
* Consistently use `copy_oftype`
* Add `unthunk` in `tsvd!_pullback`
* Add tests
* fix typo
* fix permutedcopy_oftype rrule
* final fix attempt without running locally
* Formatter [no ci]
---------
Co-authored-by: Jutho <[email protected]>
1 parent 4b1b265 commit 8f8a280

File tree

7 files changed

+63
-11
lines changed

7 files changed

+63
-11
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Test = "1"
4141
TestExtras = "0.2,0.3"
4242
TupleTools = "1.1"
4343
VectorInterface = "0.4, 0.5"
44+
Zygote = "0.7"
4445
julia = "1.10"
4546

4647
[extras]
@@ -53,6 +54,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5354
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
5455
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5556
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
57+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5658

5759
[targets]
58-
test = ["Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences"]
60+
test = ["Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]

ext/TensorKitChainRulesCoreExt/constructors.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,24 @@ function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
1717
return copy(t), copy_pullback
1818
end
1919

20+
# Reverse-mode rule for `TensorKit.copy_oftype(t, T)`.
#
# Returns the primal `copy_oftype(t, T)` together with a pullback. The pullback
# unthunks the incoming cotangent and projects it with `ProjectTo(t)`; the
# function itself and the type argument `T` both receive `NoTangent()`.
function ChainRulesCore.rrule(::typeof(TensorKit.copy_oftype), t::AbstractTensorMap,
                              T::Type{<:Number})
    primal = TensorKit.copy_oftype(t, T)
    to_input = ProjectTo(t)
    function copy_oftype_pullback(Δt)
        ∂t = to_input(unthunk(Δt))
        return NoTangent(), ∂t, NoTangent()
    end
    return primal, copy_oftype_pullback
end
26+
27+
# Reverse-mode rule for `TensorKit.permutedcopy_oftype(t, T, p)`.
#
# Returns the primal `permutedcopy_oftype(t, T, p)` together with a pullback.
# The pullback permutes the unthunked cotangent with the inverse of `p` and
# projects the result with `ProjectTo(t)`; the function, the type `T` and the
# permutation `p` all receive `NoTangent()`.
function ChainRulesCore.rrule(::typeof(TensorKit.permutedcopy_oftype), t::AbstractTensorMap,
                              T::Type{<:Number}, p::Index2Tuple)
    to_input = ProjectTo(t)
    function permutedcopy_oftype_pullback(Δt)
        # Invert the flattened permutation, then hand it to `_canonicalize`
        # together with `t`.
        # NOTE(review): presumably this regroups the flat inverse permutation
        # into the codomain/domain layout of `t` — confirm against
        # `TensorKit._canonicalize`.
        flat_inverse = TupleTools.invperm(linearize(p))
        invp = TensorKit._canonicalize(flat_inverse, t)
        ∂t = to_input(TensorKit.permute(unthunk(Δt), invp))
        return NoTangent(), ∂t, NoTangent(), NoTangent()
    end
    return TensorKit.permutedcopy_oftype(t, T, p), permutedcopy_oftype_pullback
end
37+
2038
function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array},
2139
t::AbstractTensorMap)
2240
A = convert(T, t)

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
2020
Ũ, Σ̃, Ṽ⁺ = U, Σ, V⁺
2121
end
2222

23-
function tsvd!_pullback((ΔU, ΔΣ, ΔV⁺, Δϵ))
23+
function tsvd!_pullback(ΔUSVϵ)
24+
ΔU, ΔΣ, ΔV⁺, = unthunk.(ΔUSVϵ)
2425
Δt = similar(t)
2526
for (c, b) in blocks(Δt)
2627
Uc, Σc, V⁺c = block(U, c), block(Σ, c), block(V⁺, c)

src/tensors/factorizations.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -269,39 +269,39 @@ function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
269269
end
270270

271271
function tsvd(t::AbstractTensorMap; kwargs...)
272-
tcopy = copy!(similar(t, float(scalartype(t))), t)
272+
tcopy = copy_oftype(t, float(scalartype(t)))
273273
return tsvd!(tcopy; kwargs...)
274274
end
275275
function leftorth(t::AbstractTensorMap; alg::OFA=QRpos(), kwargs...)
276-
tcopy = copy!(similar(t, float(scalartype(t))), t)
276+
tcopy = copy_oftype(t, float(scalartype(t)))
277277
return leftorth!(tcopy; alg=alg, kwargs...)
278278
end
279279
function rightorth(t::AbstractTensorMap; alg::OFA=LQpos(), kwargs...)
280-
tcopy = copy!(similar(t, float(scalartype(t))), t)
280+
tcopy = copy_oftype(t, float(scalartype(t)))
281281
return rightorth!(tcopy; alg=alg, kwargs...)
282282
end
283283
function leftnull(t::AbstractTensorMap; alg::OFA=QR(), kwargs...)
284-
tcopy = copy!(similar(t, float(scalartype(t))), t)
284+
tcopy = copy_oftype(t, float(scalartype(t)))
285285
return leftnull!(tcopy; alg=alg, kwargs...)
286286
end
287287
function rightnull(t::AbstractTensorMap; alg::OFA=LQ(), kwargs...)
288-
tcopy = copy!(similar(t, float(scalartype(t))), t)
288+
tcopy = copy_oftype(t, float(scalartype(t)))
289289
return rightnull!(tcopy; alg=alg, kwargs...)
290290
end
291291
function LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...)
292-
tcopy = copy!(similar(t, float(scalartype(t))), t)
292+
tcopy = copy_oftype(t, float(scalartype(t)))
293293
return eigen!(tcopy; kwargs...)
294294
end
295295
function eig(t::AbstractTensorMap; kwargs...)
296-
tcopy = copy!(similar(t, float(scalartype(t))), t)
296+
tcopy = copy_oftype(t, float(scalartype(t)))
297297
return eig!(tcopy; kwargs...)
298298
end
299299
function eigh(t::AbstractTensorMap; kwargs...)
300-
tcopy = copy!(similar(t, float(scalartype(t))), t)
300+
tcopy = copy_oftype(t, float(scalartype(t)))
301301
return eigh!(tcopy; kwargs...)
302302
end
303303
function LinearAlgebra.isposdef(t::AbstractTensorMap)
304-
tcopy = copy!(similar(t, float(scalartype(t))), t)
304+
tcopy = copy_oftype(t, float(scalartype(t)))
305305
return isposdef!(tcopy)
306306
end
307307

test/ad.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
136136

137137
test_rrule(copy, T1)
138138
test_rrule(copy, T2)
139+
test_rrule(TensorKit.copy_oftype, T1, ComplexF64)
140+
test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))
139141

140142
test_rrule(convert, Array, T1)
141143
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);

test/bugfixes.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,32 @@
4343
@test storagetype(t5) == Vector{Float64}
4444
tensorfree!(t2)
4545
end
46+
47+
# Regression test for issue #201: differentiate through `tsvd` with Zygote and
# compare each gradient with the one obtained from the equivalent dense-matrix
# computation in LinearAlgebra.
@testset "Issue #201" begin
    # Sum of singular values, tensor-map version ...
    function f(A::AbstractTensorMap)
        U, S, V, = tsvd(A)
        return tr(S)
    end
    # ... and the dense reference implementation.
    function f(A::AbstractMatrix)
        S = LinearAlgebra.svdvals(A)
        return sum(S)
    end
    A₀ = randn(Z2Space(4, 4) ← Z2Space(4, 4))
    grad1, = Zygote.gradient(f, A₀)
    grad2, = Zygote.gradient(f, convert(Array, A₀))
    @test convert(Array, grad1) ≈ grad2

    # Trace of the product of the two isometric factors, tensor-map version ...
    function g(A::AbstractTensorMap)
        U, S, V, = tsvd(A)
        return tr(U * V)
    end
    # ... and the dense reference. The third factor is adjointed here, whereas
    # the tensor version multiplies by the third `tsvd` output directly.
    function g(A::AbstractMatrix)
        U, S, V, = LinearAlgebra.svd(A)
        return tr(U * V')
    end
    B₀ = randn(ComplexSpace(4) ← ComplexSpace(4))
    grad3, = Zygote.gradient(g, B₀)
    grad4, = Zygote.gradient(g, convert(Array, B₀))
    @test convert(Array, grad3) ≈ grad4
end
4674
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using Base.Iterators: take, product
99
# using SUNRepresentations: SUNIrrep
1010
# const SU3Irrep = SUNIrrep{3}
1111
using LinearAlgebra: LinearAlgebra
12+
using Zygote: Zygote
1213

1314
const TK = TensorKit
1415

0 commit comments

Comments
 (0)