Skip to content

Commit c4de4ac

Browse files
committed
remove label_dest
1 parent bd74481 commit c4de4ac

File tree

5 files changed

+55
-149
lines changed

5 files changed

+55
-149
lines changed

src/contract/allocate_output.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ function check_input(::typeof(contract), a1, labels1, a2, labels2)
77
throw(ArgumentError("Invalid permutation for right tensor"))
88
end
99

10-
function check_input(::typeof(contract), a_dest, labels_dest, a1, labels1, a2, labels2)
11-
ndims(a_dest) == length(labels_dest) ||
10+
function check_input(::typeof(contract), a_dest, labels_out, a1, labels1, a2, labels2)
11+
ndims(a_dest) == length(labels_out) ||
1212
throw(ArgumentError("Invalid permutation for destination tensor"))
1313
return check_input(contract, a1, labels1, a2, labels2)
1414
end
@@ -17,7 +17,6 @@ end
1717
# i.e. `ContractAdd`?
1818
function output_axes(
1919
::typeof(contract),
20-
biperm_dest::AbstractBlockPermutation{2},
2120
a1::AbstractArray,
2221
biperm1::AbstractBlockPermutation{2},
2322
a2::AbstractArray,
@@ -26,24 +25,22 @@ function output_axes(
2625
)
2726
axes_codomain, axes_contracted = blocks(axes(a1)[biperm1])
2827
axes_contracted2, axes_domain = blocks(axes(a2)[biperm2])
28+
biperm_out = blockedtrivialperm((length(biperm1[Block(1)]), length(biperm2[Block(2)])))
2929
@assert axes_contracted == axes_contracted2
30-
return genperm((axes_codomain..., axes_domain...), invperm(Tuple(biperm_dest)))
30+
return genperm((axes_codomain..., axes_domain...), Tuple(biperm_out))
3131
end
3232

3333
# TODO: Use `ArrayLayouts`-like `MulAdd` object,
3434
# i.e. `ContractAdd`?
3535
function allocate_output(
3636
::typeof(contract),
37-
biperm_dest::AbstractBlockPermutation,
3837
a1::AbstractArray,
3938
biperm1::AbstractBlockPermutation,
4039
a2::AbstractArray,
4140
biperm2::AbstractBlockPermutation,
4241
α::Number=one(Bool),
4342
)
4443
check_input(contract, a1, biperm1, a2, biperm2)
45-
blocklengths(biperm_dest) == (length(biperm1[Block(1)]), length(biperm2[Block(2)])) ||
46-
throw(ArgumentError("Invalid permutation for destination tensor"))
47-
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
44+
axes_dest = output_axes(contract, a1, biperm1, a2, biperm2, α)
4845
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
4946
end

src/contract/blockedperms.jl

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,29 @@
11
using .BaseExtensions: BaseExtensions
22

3-
function blockedperms(
4-
f::typeof(contract), alg::Algorithm, dimnames_dest, dimnames1, dimnames2
5-
)
6-
return blockedperms(f, dimnames_dest, dimnames1, dimnames2)
3+
function blockedperms(f::typeof(contract), ::Algorithm, dimnames1, dimnames2)
4+
return blockedperms(f, dimnames1, dimnames2)
75
end
86

97
# codomain <-- domain
10-
function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
11-
dimnames = collect(Iterators.flatten((dimnames_dest, dimnames1, dimnames2)))
8+
function blockedperms(::typeof(contract), dimnames1, dimnames2)
9+
dimnames = collect(Iterators.flatten((dimnames1, dimnames2)))
1210
for i in unique(dimnames)
13-
count(==(i), dimnames) == 2 || throw(ArgumentError("Invalid contraction labels"))
11+
count(==(i), dimnames) in (1, 2) || throw(ArgumentError("Invalid contraction labels"))
1412
end
1513

1614
codomain = Tuple(setdiff(dimnames1, dimnames2))
1715
contracted = Tuple(intersect(dimnames1, dimnames2))
1816
domain = Tuple(setdiff(dimnames2, dimnames1))
1917

20-
perm_codomain_dest = BaseExtensions.indexin(codomain, dimnames_dest)
21-
perm_domain_dest = BaseExtensions.indexin(domain, dimnames_dest)
22-
2318
perm_codomain1 = BaseExtensions.indexin(codomain, dimnames1)
2419
perm_domain1 = BaseExtensions.indexin(contracted, dimnames1)
2520

2621
perm_codomain2 = BaseExtensions.indexin(contracted, dimnames2)
2722
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)
2823

29-
permblocks_dest = (perm_codomain_dest, perm_domain_dest)
30-
biperm_dest = blockedpermvcat(permblocks_dest...)
3124
permblocks1 = (perm_codomain1, perm_domain1)
3225
biperm1 = blockedpermvcat(permblocks1...)
3326
permblocks2 = (perm_codomain2, perm_domain2)
3427
biperm2 = blockedpermvcat(permblocks2...)
35-
return biperm_dest, biperm1, biperm2
28+
return biperm1, biperm2
3629
end

src/contract/contract.jl

Lines changed: 9 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ default_contract_alg() = Matricize()
1313
function contract!(
1414
alg::Algorithm,
1515
a_dest::AbstractArray,
16-
biperm_dest::AbstractBlockPermutation,
1716
a1::AbstractArray,
1817
biperm1::AbstractBlockPermutation,
1918
a2::AbstractArray,
@@ -36,35 +35,8 @@ function contract(
3635
return contract(Algorithm(alg), a1, labels1, a2, labels2, α; kwargs...)
3736
end
3837

39-
function contract(
40-
alg::Algorithm,
41-
a1::AbstractArray,
42-
labels1,
43-
a2::AbstractArray,
44-
labels2,
45-
α::Number=one(Bool);
46-
kwargs...,
47-
)
48-
labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2, α; kwargs...)
49-
return contract(alg, labels_dest, a1, labels1, a2, labels2, α; kwargs...), labels_dest
50-
end
51-
52-
function contract(
53-
labels_dest,
54-
a1::AbstractArray,
55-
labels1,
56-
a2::AbstractArray,
57-
labels2,
58-
α::Number=one(Bool);
59-
alg=default_contract_alg(),
60-
kwargs...,
61-
)
62-
return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2, α; kwargs...)
63-
end
64-
6538
function contract!(
6639
a_dest::AbstractArray,
67-
labels_dest,
6840
a1::AbstractArray,
6941
labels1,
7042
a2::AbstractArray,
@@ -74,13 +46,12 @@ function contract!(
7446
alg=default_contract_alg(),
7547
kwargs...,
7648
)
77-
contract!(Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...)
49+
contract!(Algorithm(alg), a_dest, a1, labels1, a2, labels2, α, β; kwargs...)
7850
return a_dest
7951
end
8052

8153
function contract(
8254
alg::Algorithm,
83-
labels_dest,
8455
a1::AbstractArray,
8556
labels1,
8657
a2::AbstractArray,
@@ -89,14 +60,14 @@ function contract(
8960
kwargs...,
9061
)
9162
check_input(contract, a1, labels1, a2, labels2)
92-
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
93-
return contract(alg, biperm_dest, a1, biperm1, a2, biperm2, α; kwargs...)
63+
biperm1, biperm2 = blockedperms(contract, labels1, labels2)
64+
labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2, α; kwargs...)
65+
return contract(alg, a1, biperm1, a2, biperm2, α; kwargs...), labels_dest
9466
end
9567

9668
function contract!(
9769
alg::Algorithm,
9870
a_dest::AbstractArray,
99-
labels_dest,
10071
a1::AbstractArray,
10172
labels1,
10273
a2::AbstractArray,
@@ -105,14 +76,13 @@ function contract!(
10576
β::Number;
10677
kwargs...,
10778
)
108-
check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2)
109-
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
110-
return contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
79+
check_input(contract, a1, labels1, a2, labels2)
80+
biperm1, biperm2 = blockedperms(contract, labels1, labels2)
81+
return contract!(alg, a_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
11182
end
11283

11384
function contract(
11485
alg::Algorithm,
115-
biperm_dest::AbstractBlockPermutation,
11686
a1::AbstractArray,
11787
biperm1::AbstractBlockPermutation,
11888
a2::AbstractArray,
@@ -121,7 +91,7 @@ function contract(
12191
kwargs...,
12292
)
12393
check_input(contract, a1, biperm1, a2, biperm2)
124-
a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
125-
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...)
94+
a_dest = allocate_output(contract, a1, biperm1, a2, biperm2, α)
95+
contract!(alg, a_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...)
12696
return a_dest
12797
end

src/contract/contract_matricize/contract.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@ using LinearAlgebra: mul!
33
function contract!(
44
::Matricize,
55
a_dest::AbstractArray,
6-
biperm_dest::AbstractBlockPermutation{2},
76
a1::AbstractArray,
87
biperm1::AbstractBlockPermutation{2},
98
a2::AbstractArray,
109
biperm2::AbstractBlockPermutation{2},
1110
α::Number,
1211
β::Number,
1312
)
14-
check_input(contract, a_dest, biperm_dest, a1, biperm1, a2, biperm2)
15-
a_dest_mat = matricize(a_dest, biperm_dest)
13+
biperm_out = blockedtrivialperm((length(biperm1[Block(1)]), length(biperm2[Block(2)])))
14+
check_input(contract, a_dest, biperm_out, a1, biperm1, a2, biperm2)
15+
a_dest_mat = matricize(a_dest, biperm_out)
1616
a1_mat = matricize(a1, biperm1)
1717
a2_mat = matricize(a2, biperm2)
1818
mul!(a_dest_mat, a1_mat, a2_mat, α, β)
19-
unmatricize!(a_dest, a_dest_mat, biperm_dest)
19+
unmatricize!(a_dest, a_dest_mat, biperm_out)
2020
return a_dest
2121
end

test/test_basics.jl

Lines changed: 31 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using Random: randn!
12
using Test: @test, @test_broken, @test_throws, @testset
23

34
using EllipsisNotation: var".."
@@ -134,41 +135,33 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
134135
elt_dest = promote_type(elt1, elt2)
135136
a1 = ones(elt1, (1, 1))
136137
a2 = ones(elt2, (1, 1))
137-
a_dest = ones(elt_dest, (1, 1))
138+
a_dest = ones(elt_dest, (1, 1, 1))
138139
@test_throws ArgumentError contract(a1, (1, 2, 4), a2, (2, 3))
139140
@test_throws ArgumentError contract(a1, (1, 2), a2, (2, 3, 4))
140-
@test_throws ArgumentError contract((1, 3, 4), a1, (1, 2), a2, (2, 3))
141-
@test_throws ArgumentError contract((1, 3), a1, (1, 2), a2, (2, 4))
142-
@test_throws ArgumentError contract!(a_dest, (1, 3, 4), a1, (1, 2), a2, (2, 3))
141+
@test_throws ArgumentError contract!(a_dest, a1, (1, 2), a2, (2, 3))
143142

144143
dims = (2, 3, 4, 5, 6, 7, 8, 9, 10)
145144
labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i)
146-
for (d1s, d2s, d_dests) in (
147-
((1, 2), (1, 2), ()),
148-
((1, 2), (2, 1), ()),
149-
((1, 2), (2, 1, 3), (3,)),
150-
((1, 2, 3), (2, 1), (3,)),
151-
((1, 2), (2, 3), (1, 3)),
152-
((1, 2), (2, 3), (3, 1)),
153-
((2, 1), (2, 3), (3, 1)),
154-
((1, 2, 3), (2, 3, 4), (1, 4)),
155-
((1, 2, 3), (2, 3, 4), (4, 1)),
156-
((3, 2, 1), (4, 2, 3), (4, 1)),
157-
((1, 2, 3), (3, 4), (1, 2, 4)),
158-
((1, 2, 3), (3, 4), (4, 1, 2)),
159-
((1, 2, 3), (3, 4), (2, 4, 1)),
160-
((3, 1, 2), (3, 4), (2, 4, 1)),
161-
((3, 2, 1), (4, 3), (2, 4, 1)),
162-
((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9), (1, 2, 3, 7, 8, 9)),
163-
((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7), (1, 7, 2, 8, 3, 9)),
145+
for (d1s, d2s) in (
146+
((1, 2), (1, 2)),
147+
((1, 2), (2, 1)),
148+
((1, 2), (2, 1, 3)),
149+
((1, 2, 3), (2, 1)),
150+
((1, 2), (2, 3)),
151+
((2, 1), (2, 3)),
152+
((1, 2, 3), (2, 3, 4)),
153+
((3, 2, 1), (4, 2, 3)),
154+
((1, 2, 3), (3, 4)),
155+
((3, 1, 2), (3, 4)),
156+
((3, 2, 1), (4, 3)),
157+
((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9)),
158+
((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7)),
164159
)
165160
a1 = randn(elt1, map(i -> dims[i], d1s))
166161
labels1 = map(i -> labels[i], d1s)
167162
a2 = randn(elt2, map(i -> dims[i], d2s))
168163
labels2 = map(i -> labels[i], d2s)
169-
labels_dest = map(i -> labels[i], d_dests)
170164

171-
# Don't specify destination labels
172165
a_dest, labels_dest′ = contract(a1, labels1, a2, labels2)
173166
@test labels_dest′ isa
174167
BlockedTuple{2,(length(setdiff(d1s, d2s)), length(setdiff(d2s, d1s)))}
@@ -177,35 +170,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
177170
)
178171
@test a_dest a_dest_tensoroperations
179172

180-
# Specify destination labels
181-
a_dest = contract(labels_dest, a1, labels1, a2, labels2)
182-
a_dest_tensoroperations = TensorOperations.tensorcontract(
183-
labels_dest, a1, labels1, a2, labels2
184-
)
185-
@test a_dest a_dest_tensoroperations
186-
187-
# Specify with bituple
188-
a_dest = contract(tuplemortar((labels_dest, ())), a1, labels1, a2, labels2)
189-
@test a_dest a_dest_tensoroperations
190-
a_dest = contract(tuplemortar(((), labels_dest)), a1, labels1, a2, labels2)
191-
@test a_dest a_dest_tensoroperations
192-
a_dest = contract(labels_dest′, a1, labels1, a2, labels2)
193-
a_dest_tensoroperations = TensorOperations.tensorcontract(
194-
Tuple(labels_dest′), a1, labels1, a2, labels2
195-
)
196-
@test a_dest a_dest_tensoroperations
197-
198173
# Specify α and β
199174
# TODO: Using random `α`, `β` causing
200175
# random test failures, investigate why.
201176
α = elt_dest(1.2) # randn(elt_dest)
202177
β = elt_dest(2.4) # randn(elt_dest)
203-
a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests))
204-
a_dest = copy(a_dest_init)
205-
contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
206-
a_dest_tensoroperations = TensorOperations.tensorcontract(
207-
labels_dest, a1, labels1, a2, labels2
208-
)
178+
randn!(a_dest)
179+
a_dest_init = copy(a_dest)
180+
contract!(a_dest, a1, labels1, a2, labels2, α, β)
181+
a_dest_tensoroperations = TensorOperations.tensorcontract(a1, labels1, a2, labels2)
209182
## Here we loosened the tolerance because of some floating point roundoff issue.
210183
## with Float32 numbers
211184
@test a_dest α * a_dest_tensoroperations + β * a_dest_init rtol =
@@ -226,17 +199,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
226199
@test eltype(a_dest) === elt_dest
227200
@test a_dest reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...))
228201

229-
a_dest = contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l"))
230-
@test eltype(a_dest) === elt_dest
231-
@test a_dest permutedims(
232-
reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4)
233-
)
234-
235-
a_dest = zeros(elt_dest, 2, 5, 3, 4)
236-
contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l"))
237-
@test a_dest permutedims(
238-
reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3)
239-
)
202+
a_dest = zeros(elt_dest, 2, 3, 4, 5)
203+
contract!(a_dest, a1, ("i", "j"), a2, ("k", "l"))
204+
@test a_dest reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...))
240205
end
241206
@testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts,
242207
elt2 in elts
@@ -265,38 +230,19 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
265230
@test labels_dest == tuplemortar(((), ()))
266231
@test a_dest[] s[] * t[]
267232

268-
# Specify output labels.
269-
labels_dest_example = ("j", "l", "i", "k")
270-
size_dest_example = (3, 5, 2, 4)
271-
272-
# Array-scalar contraction.
273-
a_dest = contract(labels_dest_example, a, labels_a, s, ())
274-
@test size(a_dest) == size_dest_example
275-
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]
276-
277-
# Scalar-array contraction.
278-
a_dest = contract(labels_dest_example, s, (), a, labels_a)
279-
@test size(a_dest) == size_dest_example
280-
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]
281-
282-
# Scalar-scalar contraction.
283-
a_dest = contract((), s, (), t, ())
284-
@test size(a_dest) == ()
285-
@test a_dest[] s[] * t[]
286-
287233
# Array-scalar contraction.
288-
a_dest = zeros(elt_dest, size_dest_example)
289-
contract!(a_dest, labels_dest_example, a, labels_a, s, ())
290-
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]
234+
a_dest = zeros(elt_dest, size(a))
235+
contract!(a_dest, a, (1, 2, 3, 4), s, ())
236+
@test a_dest a * s[]
291237

292238
# Scalar-array contraction.
293-
a_dest = zeros(elt_dest, size_dest_example)
294-
contract!(a_dest, labels_dest_example, s, (), a, labels_a)
295-
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]
239+
a_dest = zeros(elt_dest, size(a))
240+
contract!(a_dest, s, (), a, (1, 2, 3, 4))
241+
@test a_dest a * s[]
296242

297243
# Scalar-scalar contraction.
298244
a_dest = zeros(elt_dest, ())
299-
contract!(a_dest, (), s, (), t, ())
245+
contract!(a_dest, s, (), t, ())
300246
@test a_dest[] s[] * t[]
301247
end
302248
end

0 commit comments

Comments
 (0)