11using EllipsisNotation: var".."
22using LinearAlgebra: norm, qr
3- using TensorAlgebra: TensorAlgebra, fusedims, splitdims
4- default_rtol (elt:: Type ) = 10 ^ (0.75 * log10 (eps (real (elt))))
3+ using StableRNGs: StableRNG
4+ using TensorAlgebra: contract, contract!, fusedims, splitdims
5+ using TensorOperations: TensorOperations
56using Test: @test , @test_broken , @testset
7+
8+ default_rtol (elt:: Type ) = 10 ^ (0.75 * log10 (eps (real (elt))))
69const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
710
811@testset " TensorAlgebra" begin
@@ -90,14 +93,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
9093 labels_dest = map (i -> labels[i], d_dests)
9194
9295 # Don't specify destination labels
93- a_dest, labels_dest′ = TensorAlgebra . contract (a1, labels1, a2, labels2)
96+ a_dest, labels_dest′ = contract (a1, labels1, a2, labels2)
9497 a_dest_tensoroperations = TensorOperations. tensorcontract (
9598 labels_dest′, a1, labels1, a2, labels2
9699 )
97100 @test a_dest ≈ a_dest_tensoroperations
98101
99102 # Specify destination labels
100- a_dest = TensorAlgebra . contract (labels_dest, a1, labels1, a2, labels2)
103+ a_dest = contract (labels_dest, a1, labels1, a2, labels2)
101104 a_dest_tensoroperations = TensorOperations. tensorcontract (
102105 labels_dest, a1, labels1, a2, labels2
103106 )
@@ -111,7 +114,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
111114 β = elt_dest (2.4 ) # randn(elt_dest)
112115 a_dest_init = randn (elt_dest, map (i -> dims[i], d_dests))
113116 a_dest = copy (a_dest_init)
114- TensorAlgebra . contract! (a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
117+ contract! (a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
115118 a_dest_tensoroperations = TensorOperations. tensorcontract (
116119 labels_dest, a1, labels1, a2, labels2
117120 )
@@ -124,28 +127,90 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
124127 @testset " outer product contraction (eltype1=$elt1 , eltype2=$elt2 )" for elt1 in elts,
125128 elt2 in elts
126129
127- a1 = randn (elt1, 2 , 3 )
128- a2 = randn (elt2, 4 , 5 )
129-
130130 elt_dest = promote_type (elt1, elt2)
131131
132- a_dest, labels = TensorAlgebra. contract (a1, (" i" , " j" ), a2, (" k" , " l" ))
132+ rng = StableRNG (123 )
133+ a1 = randn (rng, elt1, 2 , 3 )
134+ a2 = randn (rng, elt2, 4 , 5 )
135+
136+ a_dest, labels = contract (a1, (" i" , " j" ), a2, (" k" , " l" ))
133137 @test labels == (" i" , " j" , " k" , " l" )
134138 @test eltype (a_dest) === elt_dest
135139 @test a_dest ≈ reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... ))
136140
137- a_dest = TensorAlgebra . contract ((" i" , " k" , " j" , " l" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
141+ a_dest = contract ((" i" , " k" , " j" , " l" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
138142 @test eltype (a_dest) === elt_dest
139143 @test a_dest ≈ permutedims (
140144 reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... )), (1 , 3 , 2 , 4 )
141145 )
142146
143147 a_dest = zeros (elt_dest, 2 , 5 , 3 , 4 )
144- TensorAlgebra . contract! (a_dest, (" i" , " l" , " j" , " k" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
148+ contract! (a_dest, (" i" , " l" , " j" , " k" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
145149 @test a_dest ≈ permutedims (
146150 reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... )), (1 , 4 , 2 , 3 )
147151 )
148152 end
153+ @testset " scalar contraction (eltype1=$elt1 , eltype2=$elt2 )" for elt1 in elts,
154+ elt2 in elts
155+
156+ elt_dest = promote_type (elt1, elt2)
157+
158+ rng = StableRNG (123 )
159+ a = randn (rng, elt1, (2 , 3 , 4 , 5 ))
160+ s = randn (rng, elt2, ())
161+ t = randn (rng, elt2, ())
162+
163+ labels_a = (" i" , " j" , " k" , " l" )
164+
165+ # Array-scalar contraction.
166+ a_dest, labels_dest = contract (a, labels_a, s, ())
167+ @test labels_dest == labels_a
168+ @test a_dest ≈ a * s[]
169+
170+ # Scalar-array contraction.
171+ a_dest, labels_dest = contract (s, (), a, labels_a)
172+ @test labels_dest == labels_a
173+ @test a_dest ≈ a * s[]
174+
175+ # Scalar-scalar contraction.
176+ a_dest, labels_dest = contract (s, (), t, ())
177+ @test labels_dest == ()
178+ @test a_dest[] ≈ s[] * t[]
179+
180+ # Specify output labels.
181+ labels_dest_example = (" j" , " l" , " i" , " k" )
182+ size_dest_example = (3 , 5 , 2 , 4 )
183+
184+ # Array-scalar contraction.
185+ a_dest = contract (labels_dest_example, a, labels_a, s, ())
186+ @test size (a_dest) == size_dest_example
187+ @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
188+
189+ # Scalar-array contraction.
190+ a_dest = contract (labels_dest_example, s, (), a, labels_a)
191+ @test size (a_dest) == size_dest_example
192+ @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
193+
194+ # Scalar-scalar contraction.
195+ a_dest = contract ((), s, (), t, ())
196+ @test size (a_dest) == ()
197+ @test a_dest[] ≈ s[] * t[]
198+
199+ # Array-scalar contraction.
200+ a_dest = zeros (elt_dest, size_dest_example)
201+ contract! (a_dest, labels_dest_example, a, labels_a, s, ())
202+ @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
203+
204+ # Scalar-array contraction.
205+ a_dest = zeros (elt_dest, size_dest_example)
206+ contract! (a_dest, labels_dest_example, s, (), a, labels_a)
207+ @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
208+
209+ # Scalar-scalar contraction.
210+ a_dest = zeros (elt_dest, ())
211+ contract! (a_dest, (), s, (), t, ())
212+ @test a_dest[] ≈ s[] * t[]
213+ end
149214end
150215@testset " qr (eltype=$elt )" for elt in elts
151216 a = randn (elt, 5 , 4 , 3 , 2 )
154219 labels_r = (:d , :c )
155220 q, r = qr (a, labels_a, labels_q, labels_r)
156221 label_qr = :qr
157- a′ = TensorAlgebra. contract (
158- labels_a, q, (labels_q... , label_qr), r, (label_qr, labels_r... )
159- )
222+ a′ = contract (labels_a, q, (labels_q... , label_qr), r, (label_qr, labels_r... ))
160223 @test a ≈ a′
161224end
0 commit comments