@@ -8,7 +8,6 @@ using NDTensors.TensorAlgebra:
88using NDTensors: NDTensors
99include (joinpath (pkgdir (NDTensors), " test" , " NDTensorsTestUtils" , " NDTensorsTestUtils.jl" ))
1010using . NDTensorsTestUtils: default_rtol
11- using TensorOperations: TensorOperations
1211using Test: @test , @test_broken , @testset
1312const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
1413@testset " BlockedPermutation" begin
@@ -111,62 +110,67 @@ end
111110 @test eltype (a_split) === elt
112111 @test a_split ≈ reshape (a, (2 , 3 , 20 ))
113112 end
114- @testset " contract (eltype1=$elt1 , eltype2=$elt2 )" for elt1 in elts, elt2 in elts
115- dims = (2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 )
116- labels = (:a , :b , :c , :d , :e , :f , :g , :h , :i )
117- for (d1s, d2s, d_dests) in (
118- ((1 , 2 ), (1 , 2 ), ()),
119- ((1 , 2 ), (2 , 1 ), ()),
120- ((1 , 2 ), (2 , 3 ), (1 , 3 )),
121- ((1 , 2 ), (2 , 3 ), (3 , 1 )),
122- ((2 , 1 ), (2 , 3 ), (3 , 1 )),
123- ((1 , 2 , 3 ), (2 , 3 , 4 ), (1 , 4 )),
124- ((1 , 2 , 3 ), (2 , 3 , 4 ), (4 , 1 )),
125- ((3 , 2 , 1 ), (4 , 2 , 3 ), (4 , 1 )),
126- ((1 , 2 , 3 ), (3 , 4 ), (1 , 2 , 4 )),
127- ((1 , 2 , 3 ), (3 , 4 ), (4 , 1 , 2 )),
128- ((1 , 2 , 3 ), (3 , 4 ), (2 , 4 , 1 )),
129- ((3 , 1 , 2 ), (3 , 4 ), (2 , 4 , 1 )),
130- ((3 , 2 , 1 ), (4 , 3 ), (2 , 4 , 1 )),
131- ((1 , 2 , 3 , 4 , 5 , 6 ), (4 , 5 , 6 , 7 , 8 , 9 ), (1 , 2 , 3 , 7 , 8 , 9 )),
132- ((2 , 4 , 5 , 1 , 6 , 3 ), (6 , 4 , 9 , 8 , 5 , 7 ), (1 , 7 , 2 , 8 , 3 , 9 )),
133- )
134- a1 = randn (elt1, map (i -> dims[i], d1s))
135- labels1 = map (i -> labels[i], d1s)
136- a2 = randn (elt2, map (i -> dims[i], d2s))
137- labels2 = map (i -> labels[i], d2s)
138- labels_dest = map (i -> labels[i], d_dests)
139-
140- # Don't specify destination labels
141- a_dest, labels_dest′ = TensorAlgebra. contract (a1, labels1, a2, labels2)
142- a_dest_tensoroperations = TensorOperations. tensorcontract (
143- labels_dest′, a1, labels1, a2, labels2
113+ # # Right now TensorOperations version is downgraded when using cuTENSOR to `v0.7` we
114+ # # are waiting for TensorOperations to support the breaking changes in cuTENSOR 2.x
115+ if ! (" cutensor" ∈ ARGS )
116+ using TensorOperations: TensorOperations
117+ @testset " contract (eltype1=$elt1 , eltype2=$elt2 )" for elt1 in elts, elt2 in elts
118+ dims = (2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 )
119+ labels = (:a , :b , :c , :d , :e , :f , :g , :h , :i )
120+ for (d1s, d2s, d_dests) in (
121+ ((1 , 2 ), (1 , 2 ), ()),
122+ ((1 , 2 ), (2 , 1 ), ()),
123+ ((1 , 2 ), (2 , 3 ), (1 , 3 )),
124+ ((1 , 2 ), (2 , 3 ), (3 , 1 )),
125+ ((2 , 1 ), (2 , 3 ), (3 , 1 )),
126+ ((1 , 2 , 3 ), (2 , 3 , 4 ), (1 , 4 )),
127+ ((1 , 2 , 3 ), (2 , 3 , 4 ), (4 , 1 )),
128+ ((3 , 2 , 1 ), (4 , 2 , 3 ), (4 , 1 )),
129+ ((1 , 2 , 3 ), (3 , 4 ), (1 , 2 , 4 )),
130+ ((1 , 2 , 3 ), (3 , 4 ), (4 , 1 , 2 )),
131+ ((1 , 2 , 3 ), (3 , 4 ), (2 , 4 , 1 )),
132+ ((3 , 1 , 2 ), (3 , 4 ), (2 , 4 , 1 )),
133+ ((3 , 2 , 1 ), (4 , 3 ), (2 , 4 , 1 )),
134+ ((1 , 2 , 3 , 4 , 5 , 6 ), (4 , 5 , 6 , 7 , 8 , 9 ), (1 , 2 , 3 , 7 , 8 , 9 )),
135+ ((2 , 4 , 5 , 1 , 6 , 3 ), (6 , 4 , 9 , 8 , 5 , 7 ), (1 , 7 , 2 , 8 , 3 , 9 )),
144136 )
145- @test a_dest ≈ a_dest_tensoroperations
137+ a1 = randn (elt1, map (i -> dims[i], d1s))
138+ labels1 = map (i -> labels[i], d1s)
139+ a2 = randn (elt2, map (i -> dims[i], d2s))
140+ labels2 = map (i -> labels[i], d2s)
141+ labels_dest = map (i -> labels[i], d_dests)
146142
147- # Specify destination labels
148- a_dest = TensorAlgebra. contract (labels_dest, a1, labels1, a2, labels2)
149- a_dest_tensoroperations = TensorOperations. tensorcontract (
150- labels_dest, a1, labels1, a2, labels2
151- )
152- @test a_dest ≈ a_dest_tensoroperations
143+ # Don't specify destination labels
144+ a_dest, labels_dest′ = TensorAlgebra. contract (a1, labels1, a2, labels2)
145+ a_dest_tensoroperations = TensorOperations. tensorcontract (
146+ labels_dest′ , a1, labels1, a2, labels2
147+ )
148+ @test a_dest ≈ a_dest_tensoroperations
153149
154- # Specify α and β
155- elt_dest = promote_type (elt1, elt2)
156- # TODO : Using random `α`, `β` causing
157- # random test failures, investigate why.
158- α = elt_dest (1.2 ) # randn(elt_dest)
159- β = elt_dest (2.4 ) # randn(elt_dest)
160- a_dest_init = randn (elt_dest, map (i -> dims[i], d_dests))
161- a_dest = copy (a_dest_init)
162- TensorAlgebra. contract! (a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
163- a_dest_tensoroperations = TensorOperations. tensorcontract (
164- labels_dest, a1, labels1, a2, labels2
165- )
166- # # Here we loosened the tolerance because of some floating point roundoff issue.
167- # # with Float32 numbers
168- @test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol =
169- 10 * default_rtol (elt_dest)
150+ # Specify destination labels
151+ a_dest = TensorAlgebra. contract (labels_dest, a1, labels1, a2, labels2)
152+ a_dest_tensoroperations = TensorOperations. tensorcontract (
153+ labels_dest, a1, labels1, a2, labels2
154+ )
155+ @test a_dest ≈ a_dest_tensoroperations
156+
157+ # Specify α and β
158+ elt_dest = promote_type (elt1, elt2)
159+ # TODO : Using random `α`, `β` causing
160+ # random test failures, investigate why.
161+ α = elt_dest (1.2 ) # randn(elt_dest)
162+ β = elt_dest (2.4 ) # randn(elt_dest)
163+ a_dest_init = randn (elt_dest, map (i -> dims[i], d_dests))
164+ a_dest = copy (a_dest_init)
165+ TensorAlgebra. contract! (a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
166+ a_dest_tensoroperations = TensorOperations. tensorcontract (
167+ labels_dest, a1, labels1, a2, labels2
168+ )
169+ # # Here we loosened the tolerance because of some floating point roundoff issue.
170+ # # with Float32 numbers
171+ @test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol =
172+ 10 * default_rtol (elt_dest)
173+ end
170174 end
171175 end
172176 @testset " qr (eltype=$elt )" for elt in elts
182186 @test a ≈ a′
183187 end
184188end
189+
185190end
0 commit comments