Skip to content

Commit 716ff70

Browse files
committed
Add tests
1 parent d0de984 commit 716ff70

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

test/ad.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
136136

137137
test_rrule(copy, T1)
138138
test_rrule(copy, T2)
139+
test_rrule(TensorKit.copy_oftype, T1, ComplexF64)
140+
test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))
139141

140142
test_rrule(convert, Array, T1)
141143
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);

test/bugfixes.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,32 @@
4343
@test storagetype(t5) == Vector{Float64}
4444
tensorfree!(t2)
4545
end
46+
47+
@testset "Issue #201" begin
48+
function f(A::AbstractTensorMap)
49+
U, S, V, = tsvd(A)
50+
return tr(S)
51+
end
52+
function f(A::AbstractMatrix)
53+
S = LinearAlgebra.svdvals(A)
54+
return sum(S)
55+
end
56+
A₀ = randn(Z2Space(4, 4) Z2Space(4, 4))
57+
grad1, = Zygote.gradient(f, A₀)
58+
grad2, = Zygote.gradient(f, convert(Array, A₀))
59+
@test convert(Array, grad1) grad2
60+
61+
function g(A::AbstractTensorMap)
62+
U, S, V, = tsvd(A)
63+
return tr(U * V)
64+
end
65+
function g(A::AbstractMatrix)
66+
U, S, V, = LinearAlgebra.svd(A)
67+
return tr(U * V')
68+
end
69+
B₀ = randn(ComplexSpace(4) ComplexSpace(4))
70+
grad3, = Zygote.gradient(g, B₀)
71+
grad4, = Zygote.gradient(g, convert(Array, B₀))
72+
@test convert(Array, grad3) grad4
73+
end
4674
end

0 commit comments

Comments
 (0)