1
+ using Random: randn!
1
2
using Test: @test , @test_broken , @test_throws , @testset
2
3
3
4
using EllipsisNotation: var".."
@@ -134,41 +135,33 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
134
135
elt_dest = promote_type (elt1, elt2)
135
136
a1 = ones (elt1, (1 , 1 ))
136
137
a2 = ones (elt2, (1 , 1 ))
137
- a_dest = ones (elt_dest, (1 , 1 ))
138
+ a_dest = ones (elt_dest, (1 , 1 , 1 ))
138
139
@test_throws ArgumentError contract (a1, (1 , 2 , 4 ), a2, (2 , 3 ))
139
140
@test_throws ArgumentError contract (a1, (1 , 2 ), a2, (2 , 3 , 4 ))
140
- @test_throws ArgumentError contract ((1 , 3 , 4 ), a1, (1 , 2 ), a2, (2 , 3 ))
141
- @test_throws ArgumentError contract ((1 , 3 ), a1, (1 , 2 ), a2, (2 , 4 ))
142
- @test_throws ArgumentError contract! (a_dest, (1 , 3 , 4 ), a1, (1 , 2 ), a2, (2 , 3 ))
141
+ @test_throws ArgumentError contract! (a_dest, a1, (1 , 2 ), a2, (2 , 3 ))
143
142
144
143
dims = (2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 )
145
144
labels = (:a , :b , :c , :d , :e , :f , :g , :h , :i )
146
- for (d1s, d2s, d_dests) in (
147
- ((1 , 2 ), (1 , 2 ), ()),
148
- ((1 , 2 ), (2 , 1 ), ()),
149
- ((1 , 2 ), (2 , 1 , 3 ), (3 ,)),
150
- ((1 , 2 , 3 ), (2 , 1 ), (3 ,)),
151
- ((1 , 2 ), (2 , 3 ), (1 , 3 )),
152
- ((1 , 2 ), (2 , 3 ), (3 , 1 )),
153
- ((2 , 1 ), (2 , 3 ), (3 , 1 )),
154
- ((1 , 2 , 3 ), (2 , 3 , 4 ), (1 , 4 )),
155
- ((1 , 2 , 3 ), (2 , 3 , 4 ), (4 , 1 )),
156
- ((3 , 2 , 1 ), (4 , 2 , 3 ), (4 , 1 )),
157
- ((1 , 2 , 3 ), (3 , 4 ), (1 , 2 , 4 )),
158
- ((1 , 2 , 3 ), (3 , 4 ), (4 , 1 , 2 )),
159
- ((1 , 2 , 3 ), (3 , 4 ), (2 , 4 , 1 )),
160
- ((3 , 1 , 2 ), (3 , 4 ), (2 , 4 , 1 )),
161
- ((3 , 2 , 1 ), (4 , 3 ), (2 , 4 , 1 )),
162
- ((1 , 2 , 3 , 4 , 5 , 6 ), (4 , 5 , 6 , 7 , 8 , 9 ), (1 , 2 , 3 , 7 , 8 , 9 )),
163
- ((2 , 4 , 5 , 1 , 6 , 3 ), (6 , 4 , 9 , 8 , 5 , 7 ), (1 , 7 , 2 , 8 , 3 , 9 )),
145
+ for (d1s, d2s) in (
146
+ ((1 , 2 ), (1 , 2 )),
147
+ ((1 , 2 ), (2 , 1 )),
148
+ ((1 , 2 ), (2 , 1 , 3 )),
149
+ ((1 , 2 , 3 ), (2 , 1 )),
150
+ ((1 , 2 ), (2 , 3 )),
151
+ ((2 , 1 ), (2 , 3 )),
152
+ ((1 , 2 , 3 ), (2 , 3 , 4 )),
153
+ ((3 , 2 , 1 ), (4 , 2 , 3 )),
154
+ ((1 , 2 , 3 ), (3 , 4 )),
155
+ ((3 , 1 , 2 ), (3 , 4 )),
156
+ ((3 , 2 , 1 ), (4 , 3 )),
157
+ ((1 , 2 , 3 , 4 , 5 , 6 ), (4 , 5 , 6 , 7 , 8 , 9 )),
158
+ ((2 , 4 , 5 , 1 , 6 , 3 ), (6 , 4 , 9 , 8 , 5 , 7 )),
164
159
)
165
160
a1 = randn (elt1, map (i -> dims[i], d1s))
166
161
labels1 = map (i -> labels[i], d1s)
167
162
a2 = randn (elt2, map (i -> dims[i], d2s))
168
163
labels2 = map (i -> labels[i], d2s)
169
- labels_dest = map (i -> labels[i], d_dests)
170
164
171
- # Don't specify destination labels
172
165
a_dest, labels_dest′ = contract (a1, labels1, a2, labels2)
173
166
@test labels_dest′ isa
174
167
BlockedTuple{2 ,(length (setdiff (d1s, d2s)), length (setdiff (d2s, d1s)))}
@@ -177,35 +170,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
177
170
)
178
171
@test a_dest ≈ a_dest_tensoroperations
179
172
180
- # Specify destination labels
181
- a_dest = contract (labels_dest, a1, labels1, a2, labels2)
182
- a_dest_tensoroperations = TensorOperations. tensorcontract (
183
- labels_dest, a1, labels1, a2, labels2
184
- )
185
- @test a_dest ≈ a_dest_tensoroperations
186
-
187
- # Specify with bituple
188
- a_dest = contract (tuplemortar ((labels_dest, ())), a1, labels1, a2, labels2)
189
- @test a_dest ≈ a_dest_tensoroperations
190
- a_dest = contract (tuplemortar (((), labels_dest)), a1, labels1, a2, labels2)
191
- @test a_dest ≈ a_dest_tensoroperations
192
- a_dest = contract (labels_dest′, a1, labels1, a2, labels2)
193
- a_dest_tensoroperations = TensorOperations. tensorcontract (
194
- Tuple (labels_dest′), a1, labels1, a2, labels2
195
- )
196
- @test a_dest ≈ a_dest_tensoroperations
197
-
198
173
# Specify α and β
199
174
# TODO : Using random `α`, `β` causing
200
175
# random test failures, investigate why.
201
176
α = elt_dest (1.2 ) # randn(elt_dest)
202
177
β = elt_dest (2.4 ) # randn(elt_dest)
203
- a_dest_init = randn (elt_dest, map (i -> dims[i], d_dests))
204
- a_dest = copy (a_dest_init)
205
- contract! (a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
206
- a_dest_tensoroperations = TensorOperations. tensorcontract (
207
- labels_dest, a1, labels1, a2, labels2
208
- )
178
+ randn! (a_dest)
179
+ a_dest_init = copy (a_dest)
180
+ contract! (a_dest, a1, labels1, a2, labels2, α, β)
181
+ a_dest_tensoroperations = TensorOperations. tensorcontract (a1, labels1, a2, labels2)
209
182
# # Here we loosened the tolerance because of some floating point roundoff issue.
210
183
# # with Float32 numbers
211
184
@test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol =
@@ -226,17 +199,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
226
199
@test eltype (a_dest) === elt_dest
227
200
@test a_dest ≈ reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... ))
228
201
229
- a_dest = contract ((" i" , " k" , " j" , " l" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
230
- @test eltype (a_dest) === elt_dest
231
- @test a_dest ≈ permutedims (
232
- reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... )), (1 , 3 , 2 , 4 )
233
- )
234
-
235
- a_dest = zeros (elt_dest, 2 , 5 , 3 , 4 )
236
- contract! (a_dest, (" i" , " l" , " j" , " k" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
237
- @test a_dest ≈ permutedims (
238
- reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... )), (1 , 4 , 2 , 3 )
239
- )
202
+ a_dest = zeros (elt_dest, 2 , 3 , 4 , 5 )
203
+ contract! (a_dest, a1, (" i" , " j" ), a2, (" k" , " l" ))
204
+ @test a_dest ≈ reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... ))
240
205
end
241
206
@testset " scalar contraction (eltype1=$elt1 , eltype2=$elt2 )" for elt1 in elts,
242
207
elt2 in elts
@@ -265,38 +230,19 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
265
230
@test labels_dest == tuplemortar (((), ()))
266
231
@test a_dest[] ≈ s[] * t[]
267
232
268
- # Specify output labels.
269
- labels_dest_example = (" j" , " l" , " i" , " k" )
270
- size_dest_example = (3 , 5 , 2 , 4 )
271
-
272
- # Array-scalar contraction.
273
- a_dest = contract (labels_dest_example, a, labels_a, s, ())
274
- @test size (a_dest) == size_dest_example
275
- @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
276
-
277
- # Scalar-array contraction.
278
- a_dest = contract (labels_dest_example, s, (), a, labels_a)
279
- @test size (a_dest) == size_dest_example
280
- @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
281
-
282
- # Scalar-scalar contraction.
283
- a_dest = contract ((), s, (), t, ())
284
- @test size (a_dest) == ()
285
- @test a_dest[] ≈ s[] * t[]
286
-
287
233
# Array-scalar contraction.
288
- a_dest = zeros (elt_dest, size_dest_example )
289
- contract! (a_dest, labels_dest_example, a, labels_a , s, ())
290
- @test a_dest ≈ permutedims (a, ( 2 , 4 , 1 , 3 )) * s[]
234
+ a_dest = zeros (elt_dest, size (a) )
235
+ contract! (a_dest, a, ( 1 , 2 , 3 , 4 ) , s, ())
236
+ @test a_dest ≈ a * s[]
291
237
292
238
# Scalar-array contraction.
293
- a_dest = zeros (elt_dest, size_dest_example )
294
- contract! (a_dest, labels_dest_example, s, (), a, labels_a )
295
- @test a_dest ≈ permutedims (a, ( 2 , 4 , 1 , 3 )) * s[]
239
+ a_dest = zeros (elt_dest, size (a) )
240
+ contract! (a_dest, s, (), a, ( 1 , 2 , 3 , 4 ) )
241
+ @test a_dest ≈ a * s[]
296
242
297
243
# Scalar-scalar contraction.
298
244
a_dest = zeros (elt_dest, ())
299
- contract! (a_dest, (), s, (), t, ())
245
+ contract! (a_dest, s, (), t, ())
300
246
@test a_dest[] ≈ s[] * t[]
301
247
end
302
248
end
0 commit comments