Skip to content

Commit 3f10b81

Browse files
committed
Add tests
1 parent 5a48e78 commit 3f10b81

File tree

4 files changed

+34
-1
lines changed

4 files changed

+34
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Test = "1"
4141
TestExtras = "0.2,0.3"
4242
TupleTools = "1.1"
4343
VectorInterface = "0.4, 0.5"
44+
Zygote = "0.7"
4445
julia = "1.10"
4546

4647
[extras]
@@ -53,6 +54,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5354
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
5455
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5556
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
57+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5658

5759
[targets]
58-
test = ["Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences"]
60+
test = ["Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]

test/ad.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
136136

137137
test_rrule(copy, T1)
138138
test_rrule(copy, T2)
139+
test_rrule(TensorKit.copy_oftype, T1, ComplexF64)
140+
test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))
139141

140142
test_rrule(convert, Array, T1)
141143
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);

test/bugfixes.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,32 @@
4343
@test storagetype(t5) == Vector{Float64}
4444
tensorfree!(t2)
4545
end
46+
47+
@testset "Issue #201" begin
48+
function f(A::AbstractTensorMap)
49+
U, S, V, = tsvd(A)
50+
return tr(S)
51+
end
52+
function f(A::AbstractMatrix)
53+
S = LinearAlgebra.svdvals(A)
54+
return sum(S)
55+
end
56+
A₀ = randn(Z2Space(4, 4) Z2Space(4, 4))
57+
grad1, = Zygote.gradient(f, A₀)
58+
grad2, = Zygote.gradient(f, convert(Array, A₀))
59+
@test convert(Array, grad1) grad2
60+
61+
function g(A::AbstractTensorMap)
62+
U, S, V, = tsvd(A)
63+
return tr(U * V)
64+
end
65+
function g(A::AbstractMatrix)
66+
U, S, V, = LinearAlgebra.svd(A)
67+
return tr(U * V')
68+
end
69+
B₀ = randn(ComplexSpace(4) ComplexSpace(4))
70+
grad3, = Zygote.gradient(g, B₀)
71+
grad4, = Zygote.gradient(g, convert(Array, B₀))
72+
@test convert(Array, grad3) grad4
73+
end
4674
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using Base.Iterators: take, product
99
# using SUNRepresentations: SUNIrrep
1010
# const SU3Irrep = SUNIrrep{3}
1111
using LinearAlgebra: LinearAlgebra
12+
using Zygote: Zygote
1213

1314
const TK = TensorKit
1415

0 commit comments

Comments
 (0)