Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ext/TensorKitChainRulesCoreExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@
return adjoint(A), adjoint_pullback
end

function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool=false)
tA = twist(A, is; inv)
twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv=!inv), NoTangent()
return tA, twist_pullback

Check warning on line 83 in ext/TensorKitChainRulesCoreExt/linalg.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/linalg.jl#L80-L83

Added lines #L80 - L83 were not covered by tests
end

function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd)
return dot(a, b), dot_pullback
Expand Down
191 changes: 103 additions & 88 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,18 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
ℂ[SU2Irrep](0 => 1, 1 => 1),
ℂ[SU2Irrep](1 // 2 => 1, 1 => 1)',
ℂ[SU2Irrep](1 // 2 => 2),
ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)'))
ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)'),
(ℂ[FibonacciAnyon](:I => 1, :τ => 1),
ℂ[FibonacciAnyon](:I => 1, :τ => 2)',
ℂ[FibonacciAnyon](:I => 3, :τ => 2)',
ℂ[FibonacciAnyon](:I => 2, :τ => 3),
ℂ[FibonacciAnyon](:I => 2, :τ => 2)))

@timedtestset "Automatic Differentiation with spacetype $(TensorKit.type_repr(eltype(V)))" verbose = true for V in
Vlist
eltypes = isreal(sectortype(eltype(V))) ? (Float64, ComplexF64) : (ComplexF64,)
symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding

@timedtestset "Basic utility" begin
T1 = randn(Float64, V[1] ⊗ V[2] ← V[3] ⊗ V[4])
T2 = randn(ComplexF64, V[1] ⊗ V[2] ← V[3] ⊗ V[4])
Expand All @@ -137,14 +145,16 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(copy, T1)
test_rrule(copy, T2)
test_rrule(TensorKit.copy_oftype, T1, ComplexF64)
test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))
if symmetricbraiding
test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))

test_rrule(convert, Array, T1)
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
fkwargs=(; tol=Inf))
test_rrule(convert, Array, T1)
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
fkwargs=(; tol=Inf))
end
end

@timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64)
@timedtestset "Basic Linear Algebra with scalartype $T" for T in eltypes
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
B = randn(T, space(A))

Expand All @@ -162,14 +172,16 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
C = randn(T, domain(A), codomain(A))
test_rrule(*, A, C)

test_rrule(permute, A, ((1, 3, 2), (5, 4)))
symmetricbraiding && test_rrule(permute, A, ((1, 3, 2), (5, 4)))
test_rrule(twist, A, 1)
test_rrule(twist, A, [1, 3])

D = randn(T, V[1] ⊗ V[2] ← V[3])
E = randn(T, V[4] ← V[5])
test_rrule(⊗, D, E)
symmetricbraiding && test_rrule(⊗, D, E)
end

@timedtestset "Linear Algebra part II with scalartype $T" for T in (Float64, ComplexF64)
@timedtestset "Linear Algebra part II with scalartype $T" for T in eltypes
for i in 1:3
E = randn(T, ⊗(V[1:i]...) ← ⊗(V[1:i]...))
test_rrule(LinearAlgebra.tr, E)
Expand All @@ -184,97 +196,100 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(LinearAlgebra.dot, A, B)
end

@timedtestset "TensorOperations with scalartype $T" for T in (Float64, ComplexF64)
atol = precision(T)
rtol = precision(T)

@timedtestset "tensortrace!" begin
for _ in 1:5
k1 = rand(0:3)
k2 = k1 == 3 ? 1 : rand(1:2)
V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1))
V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2))

(_p, _q) = randindextuple(k1 + 2 * k2, k1)
p = _repartition(_p, rand(0:k1))
q = _repartition(_q, k2)
ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2)))
A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip))
symmetricbraiding &&
@timedtestset "TensorOperations with scalartype $T" for T in eltypes
atol = precision(T)
rtol = precision(T)

α = randn(T)
β = randn(T)
for conjA in (false, true)
C = randn!(TensorOperations.tensoralloc_add(T, A, p, conjA, Val(false)))
test_rrule(tensortrace!, C, A, p, q, conjA, α, β; atol, rtol)
@timedtestset "tensortrace!" begin
for _ in 1:5
k1 = rand(0:3)
k2 = k1 == 3 ? 1 : rand(1:2)
V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1))
V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2))

(_p, _q) = randindextuple(k1 + 2 * k2, k1)
p = _repartition(_p, rand(0:k1))
q = _repartition(_q, k2)
ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2)))
A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip))

α = randn(T)
β = randn(T)
for conjA in (false, true)
C = randn!(TensorOperations.tensoralloc_add(T, A, p, conjA,
Val(false)))
test_rrule(tensortrace!, C, A, p, q, conjA, α, β; atol, rtol)
end
end
end
end

@timedtestset "tensoradd!" begin
A = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[4] ⊗ V[5])
α = randn(T)
β = randn(T)

# repeat a couple times to get some distribution of arrows
for _ in 1:5
p = randindextuple(length(V))
@timedtestset "tensoradd!" begin
A = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[4] ⊗ V[5])
α = randn(T)
β = randn(T)

C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false)))
test_rrule(tensoradd!, C1, A, p, false, α, β; atol, rtol)
# repeat a couple times to get some distribution of arrows
for _ in 1:5
p = randindextuple(length(V))

C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, Val(false)))
test_rrule(tensoradd!, C2, A, p, true, α, β; atol, rtol)
C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false,
Val(false)))
test_rrule(tensoradd!, C1, A, p, false, α, β; atol, rtol)

A = rand(Bool) ? C1 : C2
end
end
C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, Val(false)))
test_rrule(tensoradd!, C2, A, p, true, α, β; atol, rtol)

@timedtestset "tensorcontract!" begin
for _ in 1:5
d = 0
local V1, V2, V3
# retry a couple times to make sure there are at least some nonzero elements
for _ in 1:10
k1 = rand(0:3)
k2 = rand(0:2)
k3 = rand(0:2)
V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init=one(V[1]))
V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init=one(V[1]))
V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init=one(V[1]))
d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3))
d > 0 && break
A = rand(Bool) ? C1 : C2
end
ipA = randindextuple(length(V1) + length(V2))
pA = _repartition(invperm(linearize(ipA)), length(V1))
ipB = randindextuple(length(V2) + length(V3))
pB = _repartition(invperm(linearize(ipB)), length(V2))
pAB = randindextuple(length(V1) + length(V3))
end

α = randn(T)
β = randn(T)
V2_conj = prod(conj, V2; init=one(V[1]))

for conjA in (false, true), conjB in (false, true)
A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA))
B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB))
C = randn!(TensorOperations.tensoralloc_contract(T, A, pA,
conjA,
B, pB, conjB, pAB,
Val(false)))
test_rrule(tensorcontract!, C,
A, pA, conjA, B, pB, conjB, pAB,
α, β; atol, rtol)
@timedtestset "tensorcontract!" begin
for _ in 1:5
d = 0
local V1, V2, V3
# retry a couple times to make sure there are at least some nonzero elements
for _ in 1:10
k1 = rand(0:3)
k2 = rand(0:2)
k3 = rand(0:2)
V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init=one(V[1]))
V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init=one(V[1]))
V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init=one(V[1]))
d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3))
d > 0 && break
end
ipA = randindextuple(length(V1) + length(V2))
pA = _repartition(invperm(linearize(ipA)), length(V1))
ipB = randindextuple(length(V2) + length(V3))
pB = _repartition(invperm(linearize(ipB)), length(V2))
pAB = randindextuple(length(V1) + length(V3))

α = randn(T)
β = randn(T)
V2_conj = prod(conj, V2; init=one(V[1]))

for conjA in (false, true), conjB in (false, true)
A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA))
B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB))
C = randn!(TensorOperations.tensoralloc_contract(T, A, pA,
conjA,
B, pB, conjB, pAB,
Val(false)))
test_rrule(tensorcontract!, C,
A, pA, conjA, B, pB, conjB, pAB,
α, β; atol, rtol)
end
end
end
end

@timedtestset "tensorscalar" begin
A = randn(T, ProductSpace{typeof(V[1]),0}())
test_rrule(tensorscalar, A)
@timedtestset "tensorscalar" begin
A = randn(T, ProductSpace{typeof(V[1]),0}())
test_rrule(tensorscalar, A)
end
end
end

@timedtestset "Factorizations with scalartype $T" for T in (Float64, ComplexF64)
@timedtestset "Factorizations with scalartype $T" for T in eltypes
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
B = randn(T, space(A)')
C = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])
Expand Down Expand Up @@ -367,13 +382,13 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),

c, = TensorKit.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])),
blocks(S))
U, S, V, ϵ = tsvd(C; trunc=truncdim(2 * dim(c)))
trunc = truncdim(round(Int, 2 * dim(c)))
U, S, V, ϵ = tsvd(C; trunc)
ΔU = randn(scalartype(U), space(U))
ΔS = randn(scalartype(S), space(S))
ΔV = randn(scalartype(V), space(V))
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
fkwargs=(; trunc=truncdim(2 * dim(c))))
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0), fkwargs=(; trunc))
end

let D = LinearAlgebra.eigvals(C)
Expand Down
Loading