1+ using Random: randn!
12using Test: @test , @test_broken , @test_throws , @testset
23
34using EllipsisNotation: var".."
@@ -134,41 +135,33 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
134135 elt_dest = promote_type (elt1, elt2)
135136 a1 = ones (elt1, (1 , 1 ))
136137 a2 = ones (elt2, (1 , 1 ))
137- a_dest = ones (elt_dest, (1 , 1 ))
138+ a_dest = ones (elt_dest, (1 , 1 , 1 ))
138139 @test_throws ArgumentError contract (a1, (1 , 2 , 4 ), a2, (2 , 3 ))
139140 @test_throws ArgumentError contract (a1, (1 , 2 ), a2, (2 , 3 , 4 ))
140- @test_throws ArgumentError contract ((1 , 3 , 4 ), a1, (1 , 2 ), a2, (2 , 3 ))
141- @test_throws ArgumentError contract ((1 , 3 ), a1, (1 , 2 ), a2, (2 , 4 ))
142- @test_throws ArgumentError contract! (a_dest, (1 , 3 , 4 ), a1, (1 , 2 ), a2, (2 , 3 ))
141+ @test_throws ArgumentError contract! (a_dest, a1, (1 , 2 ), a2, (2 , 3 ))
143142
144143 dims = (2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 )
145144 labels = (:a , :b , :c , :d , :e , :f , :g , :h , :i )
146- for (d1s, d2s, d_dests) in (
147- ((1 , 2 ), (1 , 2 ), ()),
148- ((1 , 2 ), (2 , 1 ), ()),
149- ((1 , 2 ), (2 , 1 , 3 ), (3 ,)),
150- ((1 , 2 , 3 ), (2 , 1 ), (3 ,)),
151- ((1 , 2 ), (2 , 3 ), (1 , 3 )),
152- ((1 , 2 ), (2 , 3 ), (3 , 1 )),
153- ((2 , 1 ), (2 , 3 ), (3 , 1 )),
154- ((1 , 2 , 3 ), (2 , 3 , 4 ), (1 , 4 )),
155- ((1 , 2 , 3 ), (2 , 3 , 4 ), (4 , 1 )),
156- ((3 , 2 , 1 ), (4 , 2 , 3 ), (4 , 1 )),
157- ((1 , 2 , 3 ), (3 , 4 ), (1 , 2 , 4 )),
158- ((1 , 2 , 3 ), (3 , 4 ), (4 , 1 , 2 )),
159- ((1 , 2 , 3 ), (3 , 4 ), (2 , 4 , 1 )),
160- ((3 , 1 , 2 ), (3 , 4 ), (2 , 4 , 1 )),
161- ((3 , 2 , 1 ), (4 , 3 ), (2 , 4 , 1 )),
162- ((1 , 2 , 3 , 4 , 5 , 6 ), (4 , 5 , 6 , 7 , 8 , 9 ), (1 , 2 , 3 , 7 , 8 , 9 )),
163- ((2 , 4 , 5 , 1 , 6 , 3 ), (6 , 4 , 9 , 8 , 5 , 7 ), (1 , 7 , 2 , 8 , 3 , 9 )),
145+ for (d1s, d2s) in (
146+ ((1 , 2 ), (1 , 2 )),
147+ ((1 , 2 ), (2 , 1 )),
148+ ((1 , 2 ), (2 , 1 , 3 )),
149+ ((1 , 2 , 3 ), (2 , 1 )),
150+ ((1 , 2 ), (2 , 3 )),
151+ ((2 , 1 ), (2 , 3 )),
152+ ((1 , 2 , 3 ), (2 , 3 , 4 )),
153+ ((3 , 2 , 1 ), (4 , 2 , 3 )),
154+ ((1 , 2 , 3 ), (3 , 4 )),
155+ ((3 , 1 , 2 ), (3 , 4 )),
156+ ((3 , 2 , 1 ), (4 , 3 )),
157+ ((1 , 2 , 3 , 4 , 5 , 6 ), (4 , 5 , 6 , 7 , 8 , 9 )),
158+ ((2 , 4 , 5 , 1 , 6 , 3 ), (6 , 4 , 9 , 8 , 5 , 7 )),
164159 )
165160 a1 = randn (elt1, map (i -> dims[i], d1s))
166161 labels1 = map (i -> labels[i], d1s)
167162 a2 = randn (elt2, map (i -> dims[i], d2s))
168163 labels2 = map (i -> labels[i], d2s)
169- labels_dest = map (i -> labels[i], d_dests)
170164
171- # Don't specify destination labels
172165 a_dest, labels_dest′ = contract (a1, labels1, a2, labels2)
173166 @test labels_dest′ isa
174167 BlockedTuple{2 ,(length (setdiff (d1s, d2s)), length (setdiff (d2s, d1s)))}
@@ -177,35 +170,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
177170 )
178171 @test a_dest ≈ a_dest_tensoroperations
179172
180- # Specify destination labels
181- a_dest = contract (labels_dest, a1, labels1, a2, labels2)
182- a_dest_tensoroperations = TensorOperations. tensorcontract (
183- labels_dest, a1, labels1, a2, labels2
184- )
185- @test a_dest ≈ a_dest_tensoroperations
186-
187- # Specify with bituple
188- a_dest = contract (tuplemortar ((labels_dest, ())), a1, labels1, a2, labels2)
189- @test a_dest ≈ a_dest_tensoroperations
190- a_dest = contract (tuplemortar (((), labels_dest)), a1, labels1, a2, labels2)
191- @test a_dest ≈ a_dest_tensoroperations
192- a_dest = contract (labels_dest′, a1, labels1, a2, labels2)
193- a_dest_tensoroperations = TensorOperations. tensorcontract (
194- Tuple (labels_dest′), a1, labels1, a2, labels2
195- )
196- @test a_dest ≈ a_dest_tensoroperations
197-
198173 # Specify α and β
199174 # TODO : Using random `α`, `β` causing
200175 # random test failures, investigate why.
201176 α = elt_dest (1.2 ) # randn(elt_dest)
202177 β = elt_dest (2.4 ) # randn(elt_dest)
203- a_dest_init = randn (elt_dest, map (i -> dims[i], d_dests))
204- a_dest = copy (a_dest_init)
205- contract! (a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
206- a_dest_tensoroperations = TensorOperations. tensorcontract (
207- labels_dest, a1, labels1, a2, labels2
208- )
178+ randn! (a_dest)
179+ a_dest_init = copy (a_dest)
180+ contract! (a_dest, a1, labels1, a2, labels2, α, β)
181+ a_dest_tensoroperations = TensorOperations. tensorcontract (a1, labels1, a2, labels2)
209182 # # Here we loosened the tolerance because of some floating point roundoff issue.
210183 # # with Float32 numbers
211184 @test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol =
@@ -226,17 +199,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
226199 @test eltype (a_dest) === elt_dest
227200 @test a_dest ≈ reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... ))
228201
229- a_dest = contract ((" i" , " k" , " j" , " l" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
230- @test eltype (a_dest) === elt_dest
231- @test a_dest ≈ permutedims (
232- reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... )), (1 , 3 , 2 , 4 )
233- )
234-
235- a_dest = zeros (elt_dest, 2 , 5 , 3 , 4 )
236- contract! (a_dest, (" i" , " l" , " j" , " k" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
237- @test a_dest ≈ permutedims (
238- reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... )), (1 , 4 , 2 , 3 )
239- )
202+ a_dest = zeros (elt_dest, 2 , 3 , 4 , 5 )
203+ contract! (a_dest, a1, (" i" , " j" ), a2, (" k" , " l" ))
204+ @test a_dest ≈ reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... ))
240205 end
241206 @testset " scalar contraction (eltype1=$elt1 , eltype2=$elt2 )" for elt1 in elts,
242207 elt2 in elts
@@ -265,38 +230,19 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
265230 @test labels_dest == tuplemortar (((), ()))
266231 @test a_dest[] ≈ s[] * t[]
267232
268- # Specify output labels.
269- labels_dest_example = (" j" , " l" , " i" , " k" )
270- size_dest_example = (3 , 5 , 2 , 4 )
271-
272- # Array-scalar contraction.
273- a_dest = contract (labels_dest_example, a, labels_a, s, ())
274- @test size (a_dest) == size_dest_example
275- @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
276-
277- # Scalar-array contraction.
278- a_dest = contract (labels_dest_example, s, (), a, labels_a)
279- @test size (a_dest) == size_dest_example
280- @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
281-
282- # Scalar-scalar contraction.
283- a_dest = contract ((), s, (), t, ())
284- @test size (a_dest) == ()
285- @test a_dest[] ≈ s[] * t[]
286-
287233 # Array-scalar contraction.
288- a_dest = zeros (elt_dest, size_dest_example )
289- contract! (a_dest, labels_dest_example, a, labels_a , s, ())
290- @test a_dest ≈ permutedims (a, ( 2 , 4 , 1 , 3 )) * s[]
234+ a_dest = zeros (elt_dest, size (a) )
235+ contract! (a_dest, a, ( 1 , 2 , 3 , 4 ) , s, ())
236+ @test a_dest ≈ a * s[]
291237
292238 # Scalar-array contraction.
293- a_dest = zeros (elt_dest, size_dest_example )
294- contract! (a_dest, labels_dest_example, s, (), a, labels_a )
295- @test a_dest ≈ permutedims (a, ( 2 , 4 , 1 , 3 )) * s[]
239+ a_dest = zeros (elt_dest, size (a) )
240+ contract! (a_dest, s, (), a, ( 1 , 2 , 3 , 4 ) )
241+ @test a_dest ≈ a * s[]
296242
297243 # Scalar-scalar contraction.
298244 a_dest = zeros (elt_dest, ())
299- contract! (a_dest, (), s, (), t, ())
245+ contract! (a_dest, s, (), t, ())
300246 @test a_dest[] ≈ s[] * t[]
301247 end
302248end
0 commit comments