diff --git a/Project.toml b/Project.toml index 724161b..20f8803 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.3.7" +version = "0.3.8" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/contract/allocate_output.jl b/src/contract/allocate_output.jl index 3fa1c02..bfaa7ef 100644 --- a/src/contract/allocate_output.jl +++ b/src/contract/allocate_output.jl @@ -1,5 +1,18 @@ using Base.PermutedDimsArrays: genperm +function check_input(::typeof(contract), a1, labels1, a2, labels2) + ndims(a1) == length(labels1) || + throw(ArgumentError("Invalid permutation for left tensor")) + return ndims(a2) == length(labels2) || + throw(ArgumentError("Invalid permutation for right tensor")) +end + +function check_input(::typeof(contract), a_dest, labels_dest, a1, labels1, a2, labels2) + ndims(a_dest) == length(labels_dest) || + throw(ArgumentError("Invalid permutation for destination tensor")) + return check_input(contract, a1, labels1, a2, labels2) +end + # TODO: Use `ArrayLayouts`-like `MulAdd` object, # i.e. `ContractAdd`? function output_axes( @@ -28,6 +41,9 @@ function allocate_output( biperm2::AbstractBlockPermutation, α::Number=one(Bool), ) + check_input(contract, a1, biperm1, a2, biperm2) + blocklengths(biperm_dest) == (length(biperm1[Block(1)]), length(biperm2[Block(2)])) || + throw(ArgumentError("Invalid permutation for destination tensor")) axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α) return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest) end diff --git a/src/contract/blockedperms.jl b/src/contract/blockedperms.jl index a41033a..5f8d78c 100644 --- a/src/contract/blockedperms.jl +++ b/src/contract/blockedperms.jl @@ -8,6 +8,11 @@ end # codomain <-- domain function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) + dimnames = collect(Iterators.flatten((dimnames_dest, dimnames1, dimnames2))) + for i in unique(dimnames) + count(==(i), dimnames) == 2 || throw(ArgumentError("Invalid contraction labels")) + end + codomain = Tuple(setdiff(dimnames1, dimnames2)) contracted = Tuple(intersect(dimnames1, dimnames2)) domain = Tuple(setdiff(dimnames2, dimnames1)) diff --git a/src/contract/contract.jl b/src/contract/contract.jl index e47cc89..02665ca 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -88,6 +88,7 @@ function contract( α::Number=one(Bool); kwargs..., ) + check_input(contract, a1, labels1, a2, labels2) biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) return contract(alg, biperm_dest, a1, biperm1, a2, biperm2, α; kwargs...) end @@ -104,6 +105,7 @@ function contract!( β::Number; kwargs..., ) + check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2) biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) return contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...) end @@ -118,6 +120,7 @@ function contract( α::Number; kwargs..., ) + check_input(contract, a1, biperm1, a2, biperm2) a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2, α) contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...) return a_dest diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index 98978bd..1bf6f70 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -11,6 +11,7 @@ function contract!( α::Number, β::Number, ) + check_input(contract, a_dest, biperm_dest, a1, biperm1, a2, biperm2) a_dest_mat = matricize(a_dest, biperm_dest) a1_mat = matricize(a1, biperm1) a2_mat = matricize(a2, biperm2) diff --git a/src/contract/output_labels.jl b/src/contract/output_labels.jl index c2ffd6b..68525c5 100644 --- a/src/contract/output_labels.jl +++ b/src/contract/output_labels.jl @@ -10,10 +10,12 @@ function output_labels( return output_labels(f, alg, labels1, labels2) end -function output_labels(f::typeof(contract), alg::Algorithm, labels1, labels2) +function output_labels(f::typeof(contract), ::Algorithm, labels1, labels2) return output_labels(f, labels1, labels2) end function output_labels(::typeof(contract), labels1, labels2) - return Tuple(symdiff(labels1, labels2)) + diff1 = Tuple(setdiff(labels1, labels2)) + diff2 = Tuple(setdiff(labels2, labels1)) + return tuplemortar((diff1, diff2)) end diff --git a/src/matricize.jl b/src/matricize.jl index 666f058..85470a4 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -46,6 +46,7 @@ end # maybe: copy=false kwarg function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2}) + ndims(a) == length(biperm) || throw(ArgumentError("Invalid bipermutation")) return matricize(FusionStyle(a), a, biperm) end @@ -78,6 +79,7 @@ function unmatricize( axes::Tuple{Vararg{AbstractUnitRange}}, biperm::AbstractBlockPermutation{2}, ) + length(axes) == length(biperm) || throw(ArgumentError("axes do not match permutation")) return unmatricize(FusionStyle(m), m, axes, biperm) end @@ -122,6 +124,8 @@ end function unmatricize!( a::AbstractArray, m::AbstractMatrix, biperm::AbstractBlockPermutation{2} ) + ndims(a) == length(biperm) || + throw(ArgumentError("destination does not match permutation")) blocked_axes = axes(a)[biperm] a_perm = unmatricize(m, blocked_axes) return permuteblockeddims!(a, a_perm, invperm(biperm)) diff --git a/test/test_basics.jl b/test/test_basics.jl index 53d48a9..26c52b9 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -5,6 +5,7 @@ using StableRNGs: StableRNG using TensorOperations: TensorOperations using TensorAlgebra: + BlockedTuple, blockedpermvcat, permuteblockeddims, permuteblockeddims!, @@ -61,6 +62,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test_throws MethodError matricize(a, (1, 2), (3,), (4,)) @test_throws MethodError matricize(a, (1, 2, 3, 4)) + @test_throws ArgumentError matricize(a, blockedpermvcat((1, 2), (3,))) v = ones(elt, 2) a_fused = matricize(v, (1,), ()) @@ -122,10 +124,23 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a = unmatricize(m, (), ()) @test a isa Array{elt,0} @test a[] == m[1, 1] + + @test_throws ArgumentError unmatricize(m, (), blockedpermvcat((1, 2), (3,))) + @test_throws ArgumentError unmatricize!(m, m, blockedpermvcat((1, 2), (3,))) end using TensorOperations: TensorOperations @testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts + elt_dest = promote_type(elt1, elt2) + a1 = ones(elt1, (1, 1)) + a2 = ones(elt2, (1, 1)) + a_dest = ones(elt_dest, (1, 1)) + @test_throws ArgumentError contract(a1, (1, 2, 4), a2, (2, 3)) + @test_throws ArgumentError contract(a1, (1, 2), a2, (2, 3, 4)) + @test_throws ArgumentError contract((1, 3, 4), a1, (1, 2), a2, (2, 3)) + @test_throws ArgumentError contract((1, 3), a1, (1, 2), a2, (2, 4)) + @test_throws ArgumentError contract!(a_dest, (1, 3, 4), a1, (1, 2), a2, (2, 3)) + dims = (2, 3, 4, 5, 6, 7, 8, 9, 10) labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i) for (d1s, d2s, d_dests) in ( @@ -155,8 +170,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # Don't specify destination labels a_dest, labels_dest′ = contract(a1, labels1, a2, labels2) + @test labels_dest′ isa + BlockedTuple{2,(length(setdiff(d1s, d2s)), length(setdiff(d2s, d1s)))} a_dest_tensoroperations = TensorOperations.tensorcontract( - labels_dest′, a1, labels1, a2, labels2 + Tuple(labels_dest′), a1, labels1, a2, labels2 ) @test a_dest ≈ a_dest_tensoroperations @@ -167,8 +184,18 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) ) @test a_dest ≈ a_dest_tensoroperations + # Specify with bituple + a_dest = contract(tuplemortar((labels_dest, ())), a1, labels1, a2, labels2) + @test a_dest ≈ a_dest_tensoroperations + a_dest = contract(tuplemortar(((), labels_dest)), a1, labels1, a2, labels2) + @test a_dest ≈ a_dest_tensoroperations + a_dest = contract(labels_dest′, a1, labels1, a2, labels2) + a_dest_tensoroperations = TensorOperations.tensorcontract( + Tuple(labels_dest′), a1, labels1, a2, labels2 + ) + @test a_dest ≈ a_dest_tensoroperations + # Specify α and β - elt_dest = promote_type(elt1, elt2) # TODO: Using random `α`, `β` causing # random test failures, investigate why. α = elt_dest(1.2) # randn(elt_dest) @@ -195,7 +222,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a2 = randn(rng, elt2, 4, 5) a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l")) - @test labels == ("i", "j", "k", "l") + @test labels == tuplemortar((("i", "j"), ("k", "l"))) @test eltype(a_dest) === elt_dest @test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)) @@ -225,17 +252,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # Array-scalar contraction. a_dest, labels_dest = contract(a, labels_a, s, ()) - @test labels_dest == labels_a + @test labels_dest == tuplemortar((labels_a, ())) @test a_dest ≈ a * s[] # Scalar-array contraction. a_dest, labels_dest = contract(s, (), a, labels_a) - @test labels_dest == labels_a + @test labels_dest == tuplemortar(((), labels_a)) @test a_dest ≈ a * s[] # Scalar-scalar contraction. a_dest, labels_dest = contract(s, (), t, ()) - @test labels_dest == () + @test labels_dest == tuplemortar(((), ())) @test a_dest[] ≈ s[] * t[] # Specify output labels.