# Patch summary (text before the first `diff --git` header is not part of any hunk):
#   * ext/TensorKitChainRulesCoreExt/linalg.jl: adds a ChainRulesCore.rrule for `twist`.
#     The pullback maps the output cotangent ΔA to twist(ΔA, is; inv=!inv) and returns
#     NoTangent() for the function itself and for the index argument `is`.
#   * test/ad.jl: adds a FibonacciAnyon space list to Vlist, tests the new rule with
#     test_rrule(twist, ...), selects the tested scalar types through `eltypes`, and runs
#     the permutedcopy_oftype / convert / TensorMap / permute / ⊗ / TensorOperations
#     tests only when `symmetricbraiding` is true.
diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl
index 1b254700e..922900b6f 100644
--- a/ext/TensorKitChainRulesCoreExt/linalg.jl
+++ b/ext/TensorKitChainRulesCoreExt/linalg.jl
@@ -77,6 +77,12 @@ function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap)
     return adjoint(A), adjoint_pullback
 end
 
+function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool=false)  # reverse rule for twist(A, is; inv)
+    tA = twist(A, is; inv)  # primal output
+    twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv=!inv), NoTangent()  # cotangent of A: ΔA twisted with `inv` flipped
+    return tA, twist_pullback
+end
+
 function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
     dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd)
     return dot(a, b), dot_pullback
diff --git a/test/ad.jl b/test/ad.jl
index 96d8f28af..93251eb03 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -122,10 +122,18 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
           ℂ[SU2Irrep](0 => 1, 1 => 1),
           ℂ[SU2Irrep](1 // 2 => 1, 1 => 1)',
           ℂ[SU2Irrep](1 // 2 => 2),
-          ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)'))
+          ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)'),
+         (ℂ[FibonacciAnyon](:I => 1, :τ => 1),
+          ℂ[FibonacciAnyon](:I => 1, :τ => 2)',
+          ℂ[FibonacciAnyon](:I => 3, :τ => 2)',
+          ℂ[FibonacciAnyon](:I => 2, :τ => 3),
+          ℂ[FibonacciAnyon](:I => 2, :τ => 2)))
 
 @timedtestset "Automatic Differentiation with spacetype $(TensorKit.type_repr(eltype(V)))" verbose = true for V in
                                                                                                                    Vlist
+    eltypes = isreal(sectortype(eltype(V))) ? (Float64, ComplexF64) : (ComplexF64,)  # Float64 is tested only when isreal(sectortype) is true
+    symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding  # guards the tests that are skipped otherwise
+
     @timedtestset "Basic utility" begin
         T1 = randn(Float64, V[1] ⊗ V[2] ← V[3] ⊗ V[4])
         T2 = randn(ComplexF64, V[1] ⊗ V[2] ← V[3] ⊗ V[4])
@@ -137,14 +145,16 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
         test_rrule(copy, T1)
         test_rrule(copy, T2)
         test_rrule(TensorKit.copy_oftype, T1, ComplexF64)
-        test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))
+        if symmetricbraiding
+            test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))
 
-        test_rrule(convert, Array, T1)
-        test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
-                   fkwargs=(; tol=Inf))
+            test_rrule(convert, Array, T1)
+            test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
+                       fkwargs=(; tol=Inf))
+        end
     end
 
-    @timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64)
+    @timedtestset "Basic Linear Algebra with scalartype $T" for T in eltypes
         A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
         B = randn(T, space(A))
 
@@ -162,14 +172,16 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
         C = randn(T, domain(A), codomain(A))
         test_rrule(*, A, C)
 
-        test_rrule(permute, A, ((1, 3, 2), (5, 4)))
+        symmetricbraiding && test_rrule(permute, A, ((1, 3, 2), (5, 4)))
+        test_rrule(twist, A, 1)  # new twist rule, single index
+        test_rrule(twist, A, [1, 3])  # new twist rule, vector of indices
 
         D = randn(T, V[1] ⊗ V[2] ← V[3])
         E = randn(T, V[4] ← V[5])
-        test_rrule(⊗, D, E)
+        symmetricbraiding && test_rrule(⊗, D, E)
     end
 
-    @timedtestset "Linear Algebra part II with scalartype $T" for T in (Float64, ComplexF64)
+    @timedtestset "Linear Algebra part II with scalartype $T" for T in eltypes
         for i in 1:3
             E = randn(T, ⊗(V[1:i]...) ← ⊗(V[1:i]...))
             test_rrule(LinearAlgebra.tr, E)
@@ -184,97 +196,100 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
         test_rrule(LinearAlgebra.dot, A, B)
     end
 
-    @timedtestset "TensorOperations with scalartype $T" for T in (Float64, ComplexF64)
-        atol = precision(T)
-        rtol = precision(T)
-
-        @timedtestset "tensortrace!" begin
-            for _ in 1:5
-                k1 = rand(0:3)
-                k2 = k1 == 3 ? 1 : rand(1:2)
-                V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1))
-                V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2))
-
-                (_p, _q) = randindextuple(k1 + 2 * k2, k1)
-                p = _repartition(_p, rand(0:k1))
-                q = _repartition(_q, k2)
-                ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2)))
-                A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip))
+    symmetricbraiding &&  # the whole TensorOperations testset runs only when this is true
+        @timedtestset "TensorOperations with scalartype $T" for T in eltypes
+            atol = precision(T)
+            rtol = precision(T)
 
-                α = randn(T)
-                β = randn(T)
-                for conjA in (false, true)
-                    C = randn!(TensorOperations.tensoralloc_add(T, A, p, conjA, Val(false)))
-                    test_rrule(tensortrace!, C, A, p, q, conjA, α, β; atol, rtol)
+            @timedtestset "tensortrace!" begin
+                for _ in 1:5
+                    k1 = rand(0:3)
+                    k2 = k1 == 3 ? 1 : rand(1:2)
+                    V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1))
+                    V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2))
+
+                    (_p, _q) = randindextuple(k1 + 2 * k2, k1)
+                    p = _repartition(_p, rand(0:k1))
+                    q = _repartition(_q, k2)
+                    ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2)))
+                    A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip))
+
+                    α = randn(T)
+                    β = randn(T)
+                    for conjA in (false, true)
+                        C = randn!(TensorOperations.tensoralloc_add(T, A, p, conjA,
+                                                                    Val(false)))
+                        test_rrule(tensortrace!, C, A, p, q, conjA, α, β; atol, rtol)
+                    end
                 end
             end
-        end
 
-        @timedtestset "tensoradd!" begin
-            A = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[4] ⊗ V[5])
-            α = randn(T)
-            β = randn(T)
-
-            # repeat a couple times to get some distribution of arrows
-            for _ in 1:5
-                p = randindextuple(length(V))
+            @timedtestset "tensoradd!" begin
+                A = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[4] ⊗ V[5])
+                α = randn(T)
+                β = randn(T)
 
-                C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false)))
-                test_rrule(tensoradd!, C1, A, p, false, α, β; atol, rtol)
+                # repeat a couple times to get some distribution of arrows
+                for _ in 1:5
+                    p = randindextuple(length(V))
 
-                C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, Val(false)))
-                test_rrule(tensoradd!, C2, A, p, true, α, β; atol, rtol)
+                    C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false,
+                                                                 Val(false)))
+                    test_rrule(tensoradd!, C1, A, p, false, α, β; atol, rtol)
 
-                A = rand(Bool) ? C1 : C2
-            end
-        end
+                    C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, Val(false)))
+                    test_rrule(tensoradd!, C2, A, p, true, α, β; atol, rtol)
 
-        @timedtestset "tensorcontract!" begin
-            for _ in 1:5
-                d = 0
-                local V1, V2, V3
-                # retry a couple times to make sure there are at least some nonzero elements
-                for _ in 1:10
-                    k1 = rand(0:3)
-                    k2 = rand(0:2)
-                    k3 = rand(0:2)
-                    V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init=one(V[1]))
-                    V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init=one(V[1]))
-                    V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init=one(V[1]))
-                    d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3))
-                    d > 0 && break
+                    A = rand(Bool) ? C1 : C2
                 end
-                ipA = randindextuple(length(V1) + length(V2))
-                pA = _repartition(invperm(linearize(ipA)), length(V1))
-                ipB = randindextuple(length(V2) + length(V3))
-                pB = _repartition(invperm(linearize(ipB)), length(V2))
-                pAB = randindextuple(length(V1) + length(V3))
+            end
 
-                α = randn(T)
-                β = randn(T)
-                V2_conj = prod(conj, V2; init=one(V[1]))
-
-                for conjA in (false, true), conjB in (false, true)
-                    A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA))
-                    B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB))
-                    C = randn!(TensorOperations.tensoralloc_contract(T, A, pA,
-                                                                     conjA,
-                                                                     B, pB, conjB, pAB,
-                                                                     Val(false)))
-                    test_rrule(tensorcontract!, C,
-                               A, pA, conjA, B, pB, conjB, pAB,
-                               α, β; atol, rtol)
+            @timedtestset "tensorcontract!" begin
+                for _ in 1:5
+                    d = 0
+                    local V1, V2, V3
+                    # retry a couple times to make sure there are at least some nonzero elements
+                    for _ in 1:10
+                        k1 = rand(0:3)
+                        k2 = rand(0:2)
+                        k3 = rand(0:2)
+                        V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init=one(V[1]))
+                        V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init=one(V[1]))
+                        V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init=one(V[1]))
+                        d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3))
+                        d > 0 && break
+                    end
+                    ipA = randindextuple(length(V1) + length(V2))
+                    pA = _repartition(invperm(linearize(ipA)), length(V1))
+                    ipB = randindextuple(length(V2) + length(V3))
+                    pB = _repartition(invperm(linearize(ipB)), length(V2))
+                    pAB = randindextuple(length(V1) + length(V3))
+
+                    α = randn(T)
+                    β = randn(T)
+                    V2_conj = prod(conj, V2; init=one(V[1]))
+
+                    for conjA in (false, true), conjB in (false, true)
+                        A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA))
+                        B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB))
+                        C = randn!(TensorOperations.tensoralloc_contract(T, A, pA,
+                                                                         conjA,
+                                                                         B, pB, conjB, pAB,
+                                                                         Val(false)))
+                        test_rrule(tensorcontract!, C,
+                                   A, pA, conjA, B, pB, conjB, pAB,
+                                   α, β; atol, rtol)
+                    end
                 end
             end
-        end
 
-        @timedtestset "tensorscalar" begin
-            A = randn(T, ProductSpace{typeof(V[1]),0}())
-            test_rrule(tensorscalar, A)
+            @timedtestset "tensorscalar" begin
+                A = randn(T, ProductSpace{typeof(V[1]),0}())
+                test_rrule(tensorscalar, A)
+            end
         end
-    end
 
-    @timedtestset "Factorizations with scalartype $T" for T in (Float64, ComplexF64)
+    @timedtestset "Factorizations with scalartype $T" for T in eltypes
         A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
         B = randn(T, space(A)')
         C = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])
@@ -367,13 +382,13 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
             c, = TensorKit.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) *
                                                       maximum(diag(x[2])),
                                                  blocks(S))
-            U, S, V, ϵ = tsvd(C; trunc=truncdim(2 * dim(c)))
+            trunc = truncdim(round(Int, 2 * dim(c)))  # rounded to Int; presumably dim(c) can be non-integer for some sectors — TODO confirm
+            U, S, V, ϵ = tsvd(C; trunc)
             ΔU = randn(scalartype(U), space(U))
             ΔS = randn(scalartype(S), space(S))
             ΔV = randn(scalartype(V), space(V))
             T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
-            test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
-                       fkwargs=(; trunc=truncdim(2 * dim(c))))
+            test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0), fkwargs=(; trunc))
         end
 
         let D = LinearAlgebra.eigvals(C)