Skip to content

Commit d3512a4

Browse files
authored
use BlockedTuple for default label_dest (#64)
1 parent 1e90db0 commit d3512a4

File tree

8 files changed

+67
-9
lines changed

8 files changed

+67
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.7"
4+
version = "0.3.8"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/contract/allocate_output.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
using Base.PermutedDimsArrays: genperm
22

3+
function check_input(::typeof(contract), a1, labels1, a2, labels2)
4+
ndims(a1) == length(labels1) ||
5+
throw(ArgumentError("Invalid permutation for left tensor"))
6+
return ndims(a2) == length(labels2) ||
7+
throw(ArgumentError("Invalid permutation for right tensor"))
8+
end
9+
10+
function check_input(::typeof(contract), a_dest, labels_dest, a1, labels1, a2, labels2)
11+
ndims(a_dest) == length(labels_dest) ||
12+
throw(ArgumentError("Invalid permutation for destination tensor"))
13+
return check_input(contract, a1, labels1, a2, labels2)
14+
end
15+
316
# TODO: Use `ArrayLayouts`-like `MulAdd` object,
417
# i.e. `ContractAdd`?
518
function output_axes(
@@ -28,6 +41,9 @@ function allocate_output(
2841
biperm2::AbstractBlockPermutation,
2942
α::Number=one(Bool),
3043
)
44+
check_input(contract, a1, biperm1, a2, biperm2)
45+
blocklengths(biperm_dest) == (length(biperm1[Block(1)]), length(biperm2[Block(2)])) ||
46+
throw(ArgumentError("Invalid permutation for destination tensor"))
3147
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
3248
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
3349
end

src/contract/blockedperms.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ end
88

99
# codomain <-- domain
1010
function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
11+
dimnames = collect(Iterators.flatten((dimnames_dest, dimnames1, dimnames2)))
12+
for i in unique(dimnames)
13+
count(==(i), dimnames) == 2 || throw(ArgumentError("Invalid contraction labels"))
14+
end
15+
1116
codomain = Tuple(setdiff(dimnames1, dimnames2))
1217
contracted = Tuple(intersect(dimnames1, dimnames2))
1318
domain = Tuple(setdiff(dimnames2, dimnames1))

src/contract/contract.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ function contract(
8888
α::Number=one(Bool);
8989
kwargs...,
9090
)
91+
check_input(contract, a1, labels1, a2, labels2)
9192
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
9293
return contract(alg, biperm_dest, a1, biperm1, a2, biperm2, α; kwargs...)
9394
end
@@ -104,6 +105,7 @@ function contract!(
104105
β::Number;
105106
kwargs...,
106107
)
108+
check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2)
107109
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
108110
return contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
109111
end
@@ -118,6 +120,7 @@ function contract(
118120
α::Number;
119121
kwargs...,
120122
)
123+
check_input(contract, a1, biperm1, a2, biperm2)
121124
a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
122125
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...)
123126
return a_dest

src/contract/contract_matricize/contract.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ function contract!(
1111
α::Number,
1212
β::Number,
1313
)
14+
check_input(contract, a_dest, biperm_dest, a1, biperm1, a2, biperm2)
1415
a_dest_mat = matricize(a_dest, biperm_dest)
1516
a1_mat = matricize(a1, biperm1)
1617
a2_mat = matricize(a2, biperm2)

src/contract/output_labels.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ function output_labels(
1010
return output_labels(f, alg, labels1, labels2)
1111
end
1212

13-
function output_labels(f::typeof(contract), alg::Algorithm, labels1, labels2)
13+
function output_labels(f::typeof(contract), ::Algorithm, labels1, labels2)
1414
return output_labels(f, labels1, labels2)
1515
end
1616

1717
function output_labels(::typeof(contract), labels1, labels2)
18-
return Tuple(symdiff(labels1, labels2))
18+
diff1 = Tuple(setdiff(labels1, labels2))
19+
diff2 = Tuple(setdiff(labels2, labels1))
20+
return tuplemortar((diff1, diff2))
1921
end

src/matricize.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ end
4646
# maybe: copy=false kwarg
4747

4848
function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2})
49+
ndims(a) == length(biperm) || throw(ArgumentError("Invalid bipermutation"))
4950
return matricize(FusionStyle(a), a, biperm)
5051
end
5152

@@ -78,6 +79,7 @@ function unmatricize(
7879
axes::Tuple{Vararg{AbstractUnitRange}},
7980
biperm::AbstractBlockPermutation{2},
8081
)
82+
length(axes) == length(biperm) || throw(ArgumentError("axes do not match permutation"))
8183
return unmatricize(FusionStyle(m), m, axes, biperm)
8284
end
8385

@@ -122,6 +124,8 @@ end
122124
function unmatricize!(
123125
a::AbstractArray, m::AbstractMatrix, biperm::AbstractBlockPermutation{2}
124126
)
127+
ndims(a) == length(biperm) ||
128+
throw(ArgumentError("destination does not match permutation"))
125129
blocked_axes = axes(a)[biperm]
126130
a_perm = unmatricize(m, blocked_axes)
127131
return permuteblockeddims!(a, a_perm, invperm(biperm))

test/test_basics.jl

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using StableRNGs: StableRNG
55
using TensorOperations: TensorOperations
66

77
using TensorAlgebra:
8+
BlockedTuple,
89
blockedpermvcat,
910
permuteblockeddims,
1011
permuteblockeddims!,
@@ -61,6 +62,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
6162

6263
@test_throws MethodError matricize(a, (1, 2), (3,), (4,))
6364
@test_throws MethodError matricize(a, (1, 2, 3, 4))
65+
@test_throws ArgumentError matricize(a, blockedpermvcat((1, 2), (3,)))
6466

6567
v = ones(elt, 2)
6668
a_fused = matricize(v, (1,), ())
@@ -122,10 +124,23 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
122124
a = unmatricize(m, (), ())
123125
@test a isa Array{elt,0}
124126
@test a[] == m[1, 1]
127+
128+
@test_throws ArgumentError unmatricize(m, (), blockedpermvcat((1, 2), (3,)))
129+
@test_throws ArgumentError unmatricize!(m, m, blockedpermvcat((1, 2), (3,)))
125130
end
126131

127132
using TensorOperations: TensorOperations
128133
@testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts
134+
elt_dest = promote_type(elt1, elt2)
135+
a1 = ones(elt1, (1, 1))
136+
a2 = ones(elt2, (1, 1))
137+
a_dest = ones(elt_dest, (1, 1))
138+
@test_throws ArgumentError contract(a1, (1, 2, 4), a2, (2, 3))
139+
@test_throws ArgumentError contract(a1, (1, 2), a2, (2, 3, 4))
140+
@test_throws ArgumentError contract((1, 3, 4), a1, (1, 2), a2, (2, 3))
141+
@test_throws ArgumentError contract((1, 3), a1, (1, 2), a2, (2, 4))
142+
@test_throws ArgumentError contract!(a_dest, (1, 3, 4), a1, (1, 2), a2, (2, 3))
143+
129144
dims = (2, 3, 4, 5, 6, 7, 8, 9, 10)
130145
labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i)
131146
for (d1s, d2s, d_dests) in (
@@ -155,8 +170,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
155170

156171
# Don't specify destination labels
157172
a_dest, labels_dest′ = contract(a1, labels1, a2, labels2)
173+
@test labels_dest′ isa
174+
BlockedTuple{2,(length(setdiff(d1s, d2s)), length(setdiff(d2s, d1s)))}
158175
a_dest_tensoroperations = TensorOperations.tensorcontract(
159-
labels_dest′, a1, labels1, a2, labels2
176+
Tuple(labels_dest′), a1, labels1, a2, labels2
160177
)
161178
@test a_dest a_dest_tensoroperations
162179

@@ -167,8 +184,18 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
167184
)
168185
@test a_dest a_dest_tensoroperations
169186

187+
# Specify with bituple
188+
a_dest = contract(tuplemortar((labels_dest, ())), a1, labels1, a2, labels2)
189+
@test a_dest a_dest_tensoroperations
190+
a_dest = contract(tuplemortar(((), labels_dest)), a1, labels1, a2, labels2)
191+
@test a_dest a_dest_tensoroperations
192+
a_dest = contract(labels_dest′, a1, labels1, a2, labels2)
193+
a_dest_tensoroperations = TensorOperations.tensorcontract(
194+
Tuple(labels_dest′), a1, labels1, a2, labels2
195+
)
196+
@test a_dest a_dest_tensoroperations
197+
170198
# Specify α and β
171-
elt_dest = promote_type(elt1, elt2)
172199
# TODO: Using random `α`, `β` causing
173200
# random test failures, investigate why.
174201
α = elt_dest(1.2) # randn(elt_dest)
@@ -195,7 +222,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
195222
a2 = randn(rng, elt2, 4, 5)
196223

197224
a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l"))
198-
@test labels == ("i", "j", "k", "l")
225+
@test labels == tuplemortar((("i", "j"), ("k", "l")))
199226
@test eltype(a_dest) === elt_dest
200227
@test a_dest reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...))
201228

@@ -225,17 +252,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
225252

226253
# Array-scalar contraction.
227254
a_dest, labels_dest = contract(a, labels_a, s, ())
228-
@test labels_dest == labels_a
255+
@test labels_dest == tuplemortar((labels_a, ()))
229256
@test a_dest a * s[]
230257

231258
# Scalar-array contraction.
232259
a_dest, labels_dest = contract(s, (), a, labels_a)
233-
@test labels_dest == labels_a
260+
@test labels_dest == tuplemortar(((), labels_a))
234261
@test a_dest a * s[]
235262

236263
# Scalar-scalar contraction.
237264
a_dest, labels_dest = contract(s, (), t, ())
238-
@test labels_dest == ()
265+
@test labels_dest == tuplemortar(((), ()))
239266
@test a_dest[] s[] * t[]
240267

241268
# Specify output labels.

0 commit comments

Comments
 (0)