@@ -5,6 +5,7 @@ using StableRNGs: StableRNG
55using TensorOperations: TensorOperations
66
77using TensorAlgebra:
8+ BlockedTuple,
89 blockedpermvcat,
910 permuteblockeddims,
1011 permuteblockeddims!,
@@ -61,6 +62,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
6162
6263 @test_throws MethodError matricize (a, (1 , 2 ), (3 ,), (4 ,))
6364 @test_throws MethodError matricize (a, (1 , 2 , 3 , 4 ))
65+ @test_throws ArgumentError matricize (a, blockedpermvcat ((1 , 2 ), (3 ,)))
6466
6567 v = ones (elt, 2 )
6668 a_fused = matricize (v, (1 ,), ())
@@ -122,10 +124,23 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
122124 a = unmatricize (m, (), ())
123125 @test a isa Array{elt,0 }
124126 @test a[] == m[1 , 1 ]
127+
128+ @test_throws ArgumentError unmatricize (m, (), blockedpermvcat ((1 , 2 ), (3 ,)))
129+ @test_throws ArgumentError unmatricize! (m, m, blockedpermvcat ((1 , 2 ), (3 ,)))
125130 end
126131
127132 using TensorOperations: TensorOperations
128133 @testset " contract (eltype1=$elt1 , eltype2=$elt2 )" for elt1 in elts, elt2 in elts
134+ elt_dest = promote_type (elt1, elt2)
135+ a1 = ones (elt1, (1 , 1 ))
136+ a2 = ones (elt2, (1 , 1 ))
137+ a_dest = ones (elt_dest, (1 , 1 ))
138+ @test_throws ArgumentError contract (a1, (1 , 2 , 4 ), a2, (2 , 3 ))
139+ @test_throws ArgumentError contract (a1, (1 , 2 ), a2, (2 , 3 , 4 ))
140+ @test_throws ArgumentError contract ((1 , 3 , 4 ), a1, (1 , 2 ), a2, (2 , 3 ))
141+ @test_throws ArgumentError contract ((1 , 3 ), a1, (1 , 2 ), a2, (2 , 4 ))
142+ @test_throws ArgumentError contract! (a_dest, (1 , 3 , 4 ), a1, (1 , 2 ), a2, (2 , 3 ))
143+
129144 dims = (2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 )
130145 labels = (:a , :b , :c , :d , :e , :f , :g , :h , :i )
131146 for (d1s, d2s, d_dests) in (
@@ -155,8 +170,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
155170
156171 # Don't specify destination labels
157172 a_dest, labels_dest′ = contract (a1, labels1, a2, labels2)
173+ @test labels_dest′ isa
174+ BlockedTuple{2 ,(length (setdiff (d1s, d2s)), length (setdiff (d2s, d1s)))}
158175 a_dest_tensoroperations = TensorOperations. tensorcontract (
159- labels_dest′, a1, labels1, a2, labels2
176+ Tuple ( labels_dest′) , a1, labels1, a2, labels2
160177 )
161178 @test a_dest ≈ a_dest_tensoroperations
162179
@@ -167,8 +184,18 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
167184 )
168185 @test a_dest ≈ a_dest_tensoroperations
169186
187+ # Specify with bituple
188+ a_dest = contract (tuplemortar ((labels_dest, ())), a1, labels1, a2, labels2)
189+ @test a_dest ≈ a_dest_tensoroperations
190+ a_dest = contract (tuplemortar (((), labels_dest)), a1, labels1, a2, labels2)
191+ @test a_dest ≈ a_dest_tensoroperations
192+ a_dest = contract (labels_dest′, a1, labels1, a2, labels2)
193+ a_dest_tensoroperations = TensorOperations. tensorcontract (
194+ Tuple (labels_dest′), a1, labels1, a2, labels2
195+ )
196+ @test a_dest ≈ a_dest_tensoroperations
197+
170198 # Specify α and β
171- elt_dest = promote_type (elt1, elt2)
172199 # TODO : Using random `α`, `β` causing
173200 # random test failures, investigate why.
174201 α = elt_dest (1.2 ) # randn(elt_dest)
@@ -195,7 +222,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
195222 a2 = randn (rng, elt2, 4 , 5 )
196223
197224 a_dest, labels = contract (a1, (" i" , " j" ), a2, (" k" , " l" ))
198- @test labels == ( " i" , " j" , " k" , " l" )
225+ @test labels == tuplemortar ((( " i" , " j" ), ( " k" , " l" )) )
199226 @test eltype (a_dest) === elt_dest
200227 @test a_dest ≈ reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... ))
201228
@@ -225,17 +252,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
225252
226253 # Array-scalar contraction.
227254 a_dest, labels_dest = contract (a, labels_a, s, ())
228- @test labels_dest == labels_a
255+ @test labels_dest == tuplemortar (( labels_a, ()))
229256 @test a_dest ≈ a * s[]
230257
231258 # Scalar-array contraction.
232259 a_dest, labels_dest = contract (s, (), a, labels_a)
233- @test labels_dest == labels_a
260+ @test labels_dest == tuplemortar (((), labels_a))
234261 @test a_dest ≈ a * s[]
235262
236263 # Scalar-scalar contraction.
237264 a_dest, labels_dest = contract (s, (), t, ())
238- @test labels_dest == ( )
265+ @test labels_dest == tuplemortar (((), ()) )
239266 @test a_dest[] ≈ s[] * t[]
240267
241268 # Specify output labels.
0 commit comments