From c4de4ac4ece3365e055771147f6190a578c8ee58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 4 Aug 2025 12:17:51 -0400 Subject: [PATCH 1/3] remove label_dest --- src/contract/allocate_output.jl | 13 +-- src/contract/blockedperms.jl | 19 +--- src/contract/contract.jl | 48 ++------ src/contract/contract_matricize/contract.jl | 8 +- test/test_basics.jl | 116 ++++++-------------- 5 files changed, 55 insertions(+), 149 deletions(-) diff --git a/src/contract/allocate_output.jl b/src/contract/allocate_output.jl index bfaa7ef..bedbfe7 100644 --- a/src/contract/allocate_output.jl +++ b/src/contract/allocate_output.jl @@ -7,8 +7,8 @@ function check_input(::typeof(contract), a1, labels1, a2, labels2) throw(ArgumentError("Invalid permutation for right tensor")) end -function check_input(::typeof(contract), a_dest, labels_dest, a1, labels1, a2, labels2) - ndims(a_dest) == length(labels_dest) || +function check_input(::typeof(contract), a_dest, labels_out, a1, labels1, a2, labels2) + ndims(a_dest) == length(labels_out) || throw(ArgumentError("Invalid permutation for destination tensor")) return check_input(contract, a1, labels1, a2, labels2) end @@ -17,7 +17,6 @@ end # i.e. `ContractAdd`? function output_axes( ::typeof(contract), - biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, @@ -26,15 +25,15 @@ function output_axes( ) axes_codomain, axes_contracted = blocks(axes(a1)[biperm1]) axes_contracted2, axes_domain = blocks(axes(a2)[biperm2]) + biperm_out = blockedtrivialperm((length(biperm1[Block(1)]), length(biperm2[Block(2)]))) @assert axes_contracted == axes_contracted2 - return genperm((axes_codomain..., axes_domain...), invperm(Tuple(biperm_dest))) + return genperm((axes_codomain..., axes_domain...), Tuple(biperm_out)) end # TODO: Use `ArrayLayouts`-like `MulAdd` object, # i.e. `ContractAdd`? function allocate_output( ::typeof(contract), - biperm_dest::AbstractBlockPermutation, a1::AbstractArray, biperm1::AbstractBlockPermutation, a2::AbstractArray, @@ -42,8 +41,6 @@ function allocate_output( α::Number=one(Bool), ) check_input(contract, a1, biperm1, a2, biperm2) - blocklengths(biperm_dest) == (length(biperm1[Block(1)]), length(biperm2[Block(2)])) || - throw(ArgumentError("Invalid permutation for destination tensor")) - axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α) + axes_dest = output_axes(contract, a1, biperm1, a2, biperm2, α) return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest) end diff --git a/src/contract/blockedperms.jl b/src/contract/blockedperms.jl index 5f8d78c..b4f113b 100644 --- a/src/contract/blockedperms.jl +++ b/src/contract/blockedperms.jl @@ -1,36 +1,29 @@ using .BaseExtensions: BaseExtensions -function blockedperms( - f::typeof(contract), alg::Algorithm, dimnames_dest, dimnames1, dimnames2 -) - return blockedperms(f, dimnames_dest, dimnames1, dimnames2) +function blockedperms(f::typeof(contract), ::Algorithm, dimnames1, dimnames2) + return blockedperms(f, dimnames1, dimnames2) end # codomain <-- domain -function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) - dimnames = collect(Iterators.flatten((dimnames_dest, dimnames1, dimnames2))) +function blockedperms(::typeof(contract), dimnames1, dimnames2) + dimnames = collect(Iterators.flatten((dimnames1, dimnames2))) for i in unique(dimnames) - count(==(i), dimnames) == 2 || throw(ArgumentError("Invalid contraction labels")) + count(==(i), dimnames) in (1, 2) || throw(ArgumentError("Invalid contraction labels")) end codomain = Tuple(setdiff(dimnames1, dimnames2)) contracted = Tuple(intersect(dimnames1, dimnames2)) domain = Tuple(setdiff(dimnames2, dimnames1)) - perm_codomain_dest = BaseExtensions.indexin(codomain, dimnames_dest) - perm_domain_dest = BaseExtensions.indexin(domain, dimnames_dest) - perm_codomain1 = BaseExtensions.indexin(codomain, dimnames1) perm_domain1 = BaseExtensions.indexin(contracted, dimnames1) perm_codomain2 = BaseExtensions.indexin(contracted, dimnames2) perm_domain2 = BaseExtensions.indexin(domain, dimnames2) - permblocks_dest = (perm_codomain_dest, perm_domain_dest) - biperm_dest = blockedpermvcat(permblocks_dest...) permblocks1 = (perm_codomain1, perm_domain1) biperm1 = blockedpermvcat(permblocks1...) permblocks2 = (perm_codomain2, perm_domain2) biperm2 = blockedpermvcat(permblocks2...) - return biperm_dest, biperm1, biperm2 + return biperm1, biperm2 end diff --git a/src/contract/contract.jl b/src/contract/contract.jl index 02665ca..5c8ab3a 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -13,7 +13,6 @@ default_contract_alg() = Matricize() function contract!( alg::Algorithm, a_dest::AbstractArray, - biperm_dest::AbstractBlockPermutation, a1::AbstractArray, biperm1::AbstractBlockPermutation, a2::AbstractArray, @@ -36,35 +35,8 @@ function contract( return contract(Algorithm(alg), a1, labels1, a2, labels2, α; kwargs...) end -function contract( - alg::Algorithm, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2, - α::Number=one(Bool); - kwargs..., -) - labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2, α; kwargs...) - return contract(alg, labels_dest, a1, labels1, a2, labels2, α; kwargs...), labels_dest -end - -function contract( - labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2, - α::Number=one(Bool); - alg=default_contract_alg(), - kwargs..., -) - return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2, α; kwargs...) -end - function contract!( a_dest::AbstractArray, - labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, @@ -74,13 +46,12 @@ function contract!( alg=default_contract_alg(), kwargs..., ) - contract!(Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...) + contract!(Algorithm(alg), a_dest, a1, labels1, a2, labels2, α, β; kwargs...) return a_dest end function contract( alg::Algorithm, - labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, @@ -89,14 +60,14 @@ function contract( kwargs..., ) check_input(contract, a1, labels1, a2, labels2) - biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) - return contract(alg, biperm_dest, a1, biperm1, a2, biperm2, α; kwargs...) + biperm1, biperm2 = blockedperms(contract, labels1, labels2) + labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2, α; kwargs...) + return contract(alg, a1, biperm1, a2, biperm2, α; kwargs...), labels_dest end function contract!( alg::Algorithm, a_dest::AbstractArray, - labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, @@ -105,14 +76,13 @@ function contract!( β::Number; kwargs..., ) - check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2) - biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) - return contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...) + check_input(contract, a1, labels1, a2, labels2) + biperm1, biperm2 = blockedperms(contract, labels1, labels2) + return contract!(alg, a_dest, a1, biperm1, a2, biperm2, α, β; kwargs...) end function contract( alg::Algorithm, - biperm_dest::AbstractBlockPermutation, a1::AbstractArray, biperm1::AbstractBlockPermutation, a2::AbstractArray, @@ -121,7 +91,7 @@ function contract( kwargs..., ) check_input(contract, a1, biperm1, a2, biperm2) - a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2, α) - contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...) + a_dest = allocate_output(contract, a1, biperm1, a2, biperm2, α) + contract!(alg, a_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...) return a_dest end diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index 1bf6f70..c758ecc 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -3,7 +3,6 @@ using LinearAlgebra: mul! function contract!( ::Matricize, a_dest::AbstractArray, - biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, @@ -11,11 +10,12 @@ function contract!( α::Number, β::Number, ) - check_input(contract, a_dest, biperm_dest, a1, biperm1, a2, biperm2) - a_dest_mat = matricize(a_dest, biperm_dest) + biperm_out = blockedtrivialperm((length(biperm1[Block(1)]), length(biperm2[Block(2)]))) + check_input(contract, a_dest, biperm_out, a1, biperm1, a2, biperm2) + a_dest_mat = matricize(a_dest, biperm_out) a1_mat = matricize(a1, biperm1) a2_mat = matricize(a2, biperm2) mul!(a_dest_mat, a1_mat, a2_mat, α, β) - unmatricize!(a_dest, a_dest_mat, biperm_dest) + unmatricize!(a_dest, a_dest_mat, biperm_out) return a_dest end diff --git a/test/test_basics.jl b/test/test_basics.jl index 26c52b9..51b7f50 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,3 +1,4 @@ +using Random: randn! using Test: @test, @test_broken, @test_throws, @testset using EllipsisNotation: var".." @@ -134,41 +135,33 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) elt_dest = promote_type(elt1, elt2) a1 = ones(elt1, (1, 1)) a2 = ones(elt2, (1, 1)) - a_dest = ones(elt_dest, (1, 1)) + a_dest = ones(elt_dest, (1, 1, 1)) @test_throws ArgumentError contract(a1, (1, 2, 4), a2, (2, 3)) @test_throws ArgumentError contract(a1, (1, 2), a2, (2, 3, 4)) - @test_throws ArgumentError contract((1, 3, 4), a1, (1, 2), a2, (2, 3)) - @test_throws ArgumentError contract((1, 3), a1, (1, 2), a2, (2, 4)) - @test_throws ArgumentError contract!(a_dest, (1, 3, 4), a1, (1, 2), a2, (2, 3)) + @test_throws ArgumentError contract!(a_dest, a1, (1, 2), a2, (2, 3)) dims = (2, 3, 4, 5, 6, 7, 8, 9, 10) labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i) - for (d1s, d2s, d_dests) in ( - ((1, 2), (1, 2), ()), - ((1, 2), (2, 1), ()), - ((1, 2), (2, 1, 3), (3,)), - ((1, 2, 3), (2, 1), (3,)), - ((1, 2), (2, 3), (1, 3)), - ((1, 2), (2, 3), (3, 1)), - ((2, 1), (2, 3), (3, 1)), - ((1, 2, 3), (2, 3, 4), (1, 4)), - ((1, 2, 3), (2, 3, 4), (4, 1)), - ((3, 2, 1), (4, 2, 3), (4, 1)), - ((1, 2, 3), (3, 4), (1, 2, 4)), - ((1, 2, 3), (3, 4), (4, 1, 2)), - ((1, 2, 3), (3, 4), (2, 4, 1)), - ((3, 1, 2), (3, 4), (2, 4, 1)), - ((3, 2, 1), (4, 3), (2, 4, 1)), - ((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9), (1, 2, 3, 7, 8, 9)), - ((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7), (1, 7, 2, 8, 3, 9)), + for (d1s, d2s) in ( + ((1, 2), (1, 2)), + ((1, 2), (2, 1)), + ((1, 2), (2, 1, 3)), + ((1, 2, 3), (2, 1)), + ((1, 2), (2, 3)), + ((2, 1), (2, 3)), + ((1, 2, 3), (2, 3, 4)), + ((3, 2, 1), (4, 2, 3)), + ((1, 2, 3), (3, 4)), + ((3, 1, 2), (3, 4)), + ((3, 2, 1), (4, 3)), + ((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9)), + ((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7)), ) a1 = randn(elt1, map(i -> dims[i], d1s)) labels1 = map(i -> labels[i], d1s) a2 = randn(elt2, map(i -> dims[i], d2s)) labels2 = map(i -> labels[i], d2s) - labels_dest = map(i -> labels[i], d_dests) - # Don't specify destination labels a_dest, labels_dest′ = contract(a1, labels1, a2, labels2) @test labels_dest′ isa BlockedTuple{2,(length(setdiff(d1s, d2s)), length(setdiff(d2s, d1s)))} @@ -177,35 +170,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) ) @test a_dest ≈ a_dest_tensoroperations - # Specify destination labels - a_dest = contract(labels_dest, a1, labels1, a2, labels2) - a_dest_tensoroperations = TensorOperations.tensorcontract( - labels_dest, a1, labels1, a2, labels2 - ) - @test a_dest ≈ a_dest_tensoroperations - - # Specify with bituple - a_dest = contract(tuplemortar((labels_dest, ())), a1, labels1, a2, labels2) - @test a_dest ≈ a_dest_tensoroperations - a_dest = contract(tuplemortar(((), labels_dest)), a1, labels1, a2, labels2) - @test a_dest ≈ a_dest_tensoroperations - a_dest = contract(labels_dest′, a1, labels1, a2, labels2) - a_dest_tensoroperations = TensorOperations.tensorcontract( - Tuple(labels_dest′), a1, labels1, a2, labels2 - ) - @test a_dest ≈ a_dest_tensoroperations - # Specify α and β # TODO: Using random `α`, `β` causing # random test failures, investigate why. α = elt_dest(1.2) # randn(elt_dest) β = elt_dest(2.4) # randn(elt_dest) - a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests)) - a_dest = copy(a_dest_init) - contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β) - a_dest_tensoroperations = TensorOperations.tensorcontract( - labels_dest, a1, labels1, a2, labels2 - ) + randn!(a_dest) + a_dest_init = copy(a_dest) + contract!(a_dest, a1, labels1, a2, labels2, α, β) + a_dest_tensoroperations = TensorOperations.tensorcontract(a1, labels1, a2, labels2) ## Here we loosened the tolerance because of some floating point roundoff issue. ## with Float32 numbers @test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol = @@ -226,17 +199,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test eltype(a_dest) === elt_dest @test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)) - a_dest = contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l")) - @test eltype(a_dest) === elt_dest - @test a_dest ≈ permutedims( - reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4) - ) - - a_dest = zeros(elt_dest, 2, 5, 3, 4) - contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l")) - @test a_dest ≈ permutedims( - reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3) - ) + a_dest = zeros(elt_dest, 2, 3, 4, 5) + contract!(a_dest, a1, ("i", "j"), a2, ("k", "l")) + @test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)) end @testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts @@ -265,38 +230,19 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test labels_dest == tuplemortar(((), ())) @test a_dest[] ≈ s[] * t[] - # Specify output labels. - labels_dest_example = ("j", "l", "i", "k") - size_dest_example = (3, 5, 2, 4) - - # Array-scalar contraction. - a_dest = contract(labels_dest_example, a, labels_a, s, ()) - @test size(a_dest) == size_dest_example - @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] - - # Scalar-array contraction. - a_dest = contract(labels_dest_example, s, (), a, labels_a) - @test size(a_dest) == size_dest_example - @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] - - # Scalar-scalar contraction. - a_dest = contract((), s, (), t, ()) - @test size(a_dest) == () - @test a_dest[] ≈ s[] * t[] - # Array-scalar contraction. - a_dest = zeros(elt_dest, size_dest_example) - contract!(a_dest, labels_dest_example, a, labels_a, s, ()) - @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + a_dest = zeros(elt_dest, size(a)) + contract!(a_dest, a, (1, 2, 3, 4), s, ()) + @test a_dest ≈ a * s[] # Scalar-array contraction. - a_dest = zeros(elt_dest, size_dest_example) - contract!(a_dest, labels_dest_example, s, (), a, labels_a) - @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + a_dest = zeros(elt_dest, size(a)) + contract!(a_dest, s, (), a, (1, 2, 3, 4)) + @test a_dest ≈ a * s[] # Scalar-scalar contraction. a_dest = zeros(elt_dest, ()) - contract!(a_dest, (), s, (), t, ()) + contract!(a_dest, s, (), t, ()) @test a_dest[] ≈ s[] * t[] end end From a664432ad7b070d24a56b8c446205cfd7f6c359e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 4 Aug 2025 14:04:53 -0400 Subject: [PATCH 2/3] fix tests --- test/test_factorizations.jl | 76 ++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index efb81fc..e727093 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -35,8 +35,8 @@ elts = (Float64, ComplexF64) Acopy = deepcopy(A) Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full=true) @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) - @test A ≈ A′ + A′, legs = contract(Q, (labels_Q..., :q), R, (:q, labels_R...)) + @test A ≈ permutedims(A′, (2, 1, 4, 3)) @test size(Q, 1) * size(Q, 2) == size(Q, 3) # Q is unitary end @@ -49,8 +49,8 @@ end Acopy = deepcopy(A) Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full=false) @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) - @test A ≈ A′ + A′, legs = contract(Q, (labels_Q..., :q), R, (:q, labels_R...)) + @test A ≈ permutedims(A′, (2, 1, 4, 3)) @test size(Q, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) end @@ -65,8 +65,8 @@ end Acopy = deepcopy(A) L, Q = @constinferred lq(A, labels_A, labels_L, labels_Q; full=true) @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, L, (labels_L..., :q), Q, (:q, labels_Q...)) - @test A ≈ A′ + A′, legs = contract(L, (labels_L..., :q), Q, (:q, labels_Q...)) + @test A ≈ permutedims(A′, (2, 1, 4, 3)) @test size(Q, 1) == size(Q, 2) * size(Q, 3) # Q is unitary end @@ -79,8 +79,8 @@ end Acopy = deepcopy(A) L, Q = @constinferred lq(A, labels_A, labels_L, labels_Q; full=false) @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, L, (labels_L..., :q), Q, (:q, labels_Q...)) - @test A ≈ A′ + A′, legs = contract(L, (labels_L..., :q), Q, (:q, labels_Q...)) + @test A ≈ permutedims(A′, (2, 1, 4, 3)) @test size(Q, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) # Q is unitary end @@ -98,9 +98,9 @@ end @test A == Acopy # should not have altered initial array @test eltype(D) == eltype(V) && eltype(D) <: Complex - AV = contract((:a, :b, :D), A, labels_A, V, (labels_V′..., :D)) - VD = contract((:a, :b, :D), V, (labels_V..., :D′), D, (:D′, :D)) - @test AV ≈ VD + AV, _ = contract(A, labels_A, V, (labels_V′..., :D)) + VD, _ = contract(V, (labels_V..., :D′), D, (:D′, :D)) + @test AV ≈ permutedims(VD, (2, 1, 3)) # type-unstable because of `ishermitian` difference Dvals = eigvals(A, labels_A, labels_V, labels_V′; ishermitian=false) @@ -122,9 +122,9 @@ end @test eltype(D) <: Real @test eltype(V) == eltype(A) - AV = contract((:a, :b, :D), A, labels_A, V, (labels_V′..., :D)) - VD = contract((:a, :b, :D), V, (labels_V..., :D′), D, (:D′, :D)) - @test AV ≈ VD + AV, _ = contract(A, labels_A, V, (labels_V′..., :D)) + VD, _ = contract(V, (labels_V..., :D′), D, (:D′, :D)) + @test AV ≈ permutedims(VD, (2, 1, 3)) # type-unstable because of `ishermitian` difference Dvals = eigvals(A, labels_A, labels_V, labels_V′; ishermitian=true) @@ -144,22 +144,22 @@ end U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=true) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) - @test A ≈ A′ + A′, _ = contract(US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) + @test A ≈ permutedims(A′, (2, 1, 4, 3)) @test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary @test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # V is unitary U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full=true) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) + A′, _ = contract(US, labels_US, Vᴴ, (:v,)) @test A ≈ A′ @test size(Vᴴ, 1) == 1 U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full=true) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (:u,), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) + A′, _ = contract(US, labels_US, Vᴴ, (:v, labels_A...)) @test A ≈ A′ @test size(U, 2) == 1 end @@ -174,8 +174,8 @@ end U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=false) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) - @test A ≈ A′ + A′, _ = contract(US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) + @test A ≈ permutedims(A′, (2, 1, 4, 3)) k = min(size(S)...) @test size(U, 3) == k == size(Vᴴ, 1) @@ -185,14 +185,14 @@ end U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full=false) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) + A′, _ = contract(US, labels_US, Vᴴ, (:v,)) @test A ≈ A′ @test size(U, ndims(U)) == 1 == size(Vᴴ, 1) U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full=false) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (:u,), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) + A′, _ = contract(US, labels_US, Vᴴ, (:v, labels_A...)) @test A ≈ A′ @test size(U, 1) == 1 == size(Vᴴ, 1) end @@ -212,8 +212,8 @@ end @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) - A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) - @test norm(A - A′) ≈ S_untrunc[end] + A′, _ = contract(US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) + @test norm(A - permutedims(A′, (2, 1, 4, 3))) ≈ S_untrunc[end] @test size(S, 1) == size(S_untrunc, 1) - 1 end @@ -227,17 +227,17 @@ end N = @constinferred left_null(A, labels_A, labels_codomain, labels_domain) @test A == Acopy # should not have altered initial array # N^ba_n' * A^ba_dc = 0 - NA = contract((:n, labels_domain...), conj(N), (labels_codomain..., :n), A, labels_A) + NA, _ = contract(conj(N), (labels_codomain..., :n), A, labels_A) @test norm(NA) ≈ 0 atol = 1e-14 - NN = contract((:n, :n′), conj(N), (labels_codomain..., :n), N, (labels_codomain..., :n′)) + NN, _ = contract(conj(N), (labels_codomain..., :n), N, (labels_codomain..., :n′)) @test NN ≈ LinearAlgebra.I Nᴴ = @constinferred right_null(A, labels_A, labels_codomain, labels_domain) @test A == Acopy # should not have altered initial array # A^ba_dc * N^dc_n' = 0 - AN = contract((labels_codomain..., :n), A, labels_A, conj(Nᴴ), (:n, labels_domain...)) + AN, _ = contract(A, labels_A, conj(Nᴴ), (:n, labels_domain...)) @test norm(AN) ≈ 0 atol = 1e-14 - NN = contract((:n, :n′), Nᴴ, (:n, labels_domain...), Nᴴ, (:n′, labels_domain...)) + NN, _ = contract(Nᴴ, (:n, labels_domain...), Nᴴ, (:n′, labels_domain...)) end @testset "Left polar ($T)" for T in elts @@ -253,8 +253,8 @@ end polar(A, labels_A, labels_W, labels_P), ) @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) - @test A ≈ A′ + A′, _ = contract(W, (labels_W..., :w), P, (:w, labels_P...)) + @test A ≈ permutedims(A′, (2, 1, 4, 3)) @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) end end @@ -271,8 +271,8 @@ end polar(A, labels_A, labels_P, labels_W; side=:right), ) @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) - @test A ≈ A′ + A′, _ = contract(P, (labels_P..., :w), W, (:w, labels_W...)) + @test A ≈ permutedims(A′, (2, 1, 4, 3)) @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) end end @@ -290,8 +290,8 @@ end orth(A, labels_A, labels_W, labels_P), ) @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) - @test A ≈ A′ + A′, _ = contract(W, (labels_W..., :w), P, (:w, labels_P...)) + @test A ≈ permutedims(A′, (2, 1, 4, 3)) @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) end end @@ -308,8 +308,8 @@ end orth(A, labels_A, labels_P, labels_W; side=:right), ) @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) - @test A ≈ A′ + A′, _ = contract(P, (labels_P..., :w), W, (:w, labels_W...)) + @test A ≈ permutedims(A′, (2, 1, 4, 3)) @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) end end @@ -324,8 +324,8 @@ end for orth in (:left, :right) X, Y = factorize(A, labels_A, labels_X, labels_Y; orth) @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, X, (labels_X..., :x), Y, (:x, labels_Y...)) - @test A ≈ A′ + A′, _ = contract(X, (labels_X..., :x), Y, (:x, labels_Y...)) + @test A ≈ permutedims(A′, (2, 1, 4, 3)) @test size(X, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) end end From 356deff0eaed2e862163fea26347054478ac858a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 4 Aug 2025 14:07:47 -0400 Subject: [PATCH 3/3] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 570e05e..1c77028 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.3.10" +version = "0.3.11" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"