Skip to content

Commit 495a4f0

Browse files
committed
Refactor matricize/unmatricize
1 parent 5c7a9bb commit 495a4f0

File tree

6 files changed

+164
-52
lines changed

6 files changed

+164
-52
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
3-
version = "0.4.6"
3+
version = "0.4.7"
44
authors = ["ITensor developers <[email protected]> and contributors"]
55

66
[deps]

src/blockedpermutation.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function istrivialperm(t::Tuple)
88
return t == trivialperm(length(t))
99
end
1010

11-
value(::Val{N}) where {N} = N
11+
unval(::Val{N}) where {N} = N
1212

1313
_flatten_tuples(t::Tuple) = t
1414
function _flatten_tuples(t1::Tuple, t2::Tuple, trest::Tuple...)
@@ -87,7 +87,7 @@ function blockedpermvcat(
8787
end
8888

8989
function blockedpermvcat(len::Val, permblocks::Tuple{Vararg{Int}}...)
90-
value(len) != sum(length.(permblocks); init = 0) &&
90+
unval(len) != sum(length.(permblocks); init = 0) &&
9191
throw(ArgumentError("Invalid total length"))
9292
return permmortar(Tuple(permblocks))
9393
end
@@ -97,7 +97,7 @@ function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}})
9797
end
9898

9999
function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}})
100-
return value(vallength)
100+
return unval(vallength)
101101
end
102102

103103
# blockedpermvcat((4, 3), .., 1) == blockedpermvcat((4, 3), (2,), (1,))
@@ -199,8 +199,11 @@ end
199199

200200
blockedperm(tp::BlockedTrivialPermutation) = tp
201201

202+
function blockedtrivialperm(blocklengths::Tuple{Vararg{Val}})
203+
return BlockedTrivialPermutation{length(blocklengths), unval.(blocklengths)}()
204+
end
202205
function blockedtrivialperm(blocklengths::Tuple{Vararg{Int}})
203-
return BlockedTrivialPermutation{length(blocklengths), blocklengths}()
206+
return blockedtrivialperm(Val.(blocklengths))
204207
end
205208

206209
function trivialperm(blockedperm::AbstractBlockTuple)

src/contract/contract.jl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,28 @@ abstract type Algorithm end
44

55
Algorithm(alg::Algorithm) = alg
66

7-
struct Matricize <: Algorithm end
7+
struct Matricize{Style} <: Algorithm
8+
fusion_style::Style
9+
end
810

9-
default_contract_alg() = Matricize()
11+
function default_contract_alg(a1::AbstractArray, labels1, a2::AbstractArray, labels2)
12+
style1 = FusionStyle(a1)
13+
style2 = FusionStyle(a2)
14+
style1 == style2 || error("Styles must match.")
15+
return Matricize(style1)
16+
end
17+
function default_contractadd!_alg(
18+
a_dest::AbstractArray, labels_dest,
19+
a1::AbstractArray, labels1,
20+
a2::AbstractArray, labels2,
21+
α::Number, β::Number,
22+
)
23+
style_dest = FusionStyle(a_dest)
24+
style1 = FusionStyle(a1)
25+
style2 = FusionStyle(a2)
26+
style_dest == style1 == style2 || error("Styles must match.")
27+
return Matricize(style_dest)
28+
end
1029

1130
# Required interface if not using
1231
# matricized contraction.
@@ -29,7 +48,7 @@ function contract(
2948
labels1,
3049
a2::AbstractArray,
3150
labels2;
32-
alg = default_contract_alg(),
51+
alg = default_contract_alg(a1, labels1, a2, labels2),
3352
kwargs...,
3453
)
3554
return contract(Algorithm(alg), a1, labels1, a2, labels2; kwargs...)
@@ -48,7 +67,7 @@ function contract(
4867
labels1,
4968
a2::AbstractArray,
5069
labels2;
51-
alg = default_contract_alg(),
70+
alg = default_contract_alg(a1, labels1, a2, labels2),
5271
kwargs...,
5372
)
5473
return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2; kwargs...)
@@ -75,7 +94,7 @@ function contractadd!(
7594
labels2,
7695
α::Number,
7796
β::Number;
78-
alg = default_contract_alg(),
97+
alg = default_contractadd!_alg(a_dest, labels_dest, a1, labels1, a2, labels2, α, β),
7998
kwargs...,
8099
)
81100
contractadd!(
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using LinearAlgebra: mul!
22

33
function contractadd!(
4-
::Matricize,
4+
alg::Matricize,
55
a_dest::AbstractArray,
66
biperm_dest::AbstractBlockPermutation{2},
77
a1::AbstractArray,
@@ -12,11 +12,10 @@ function contractadd!(
1212
β::Number,
1313
)
1414
invbiperm = biperm(invperm(biperm_dest), length_codomain(biperm1))
15-
1615
check_input(contract, a_dest, invbiperm, a1, biperm1, a2, biperm2)
17-
a1_mat = matricize(a1, biperm1)
18-
a2_mat = matricize(a2, biperm2)
16+
a1_mat = matricize(alg.fusion_style, a1, biperm1)
17+
a2_mat = matricize(alg.fusion_style, a2, biperm2)
1918
a_dest_mat = a1_mat * a2_mat
20-
unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β)
19+
unmatricizeadd!(alg.fusion_style, a_dest, a_dest_mat, invbiperm, α, β)
2120
return a_dest
2221
end

src/matricize.jl

Lines changed: 125 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,77 +21,146 @@ trivial_axis(::Tuple{}) = Base.OneTo(1)
2121
trivial_axis(::Tuple{Vararg{AbstractUnitRange}}) = Base.OneTo(1)
2222
trivial_axis(::Tuple{Vararg{AbstractBlockedUnitRange}}) = blockedrange([1])
2323

24+
# Inner version takes a list of sub-permutations, overload this one if needed.
2425
function fuseaxes(
25-
axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation
26+
axes::Tuple{Vararg{AbstractUnitRange}}, lengths::Val...
2627
)
27-
axesblocks = blocks(axes[blockedperm])
28+
axesblocks = blocks(axes[blockedtrivialperm(lengths)])
2829
return map(block -> isempty(block) ? trivial_axis(axes) : (block...), axesblocks)
2930
end
3031

32+
# Inner version takes a list of sub-permutations, overload this one if needed.
33+
function fuseaxes(
34+
axes::Tuple{Vararg{AbstractUnitRange}}, permblocks::Tuple{Vararg{Int}}...
35+
)
36+
axes′ = map(d -> axes[d], permmortar(permblocks))
37+
return fuseaxes(axes′, Val.(length.(permblocks))...)
38+
end
39+
40+
function fuseaxes(
41+
axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation
42+
)
43+
return fuseaxes(axes, blocks(blockedperm)...)
44+
end
45+
46+
# Inner version takes a list of sub-permutations, overload this one if needed.
47+
function permuteblockeddims(a::AbstractArray, perm1, perm2)
48+
return _permutedims(a, (perm1..., perm2...))
49+
end
50+
function permuteblockeddims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2)
51+
return _permutedims!(a_dest, a_src, (perm1..., perm2...))
52+
end
53+
3154
# TODO remove _permutedims once support for Julia 1.10 is dropped
3255
# define permutedims with a BlockedPermuation. Default is to flatten it.
33-
function permuteblockeddims(a::AbstractArray, biperm::AbstractBlockPermutation)
34-
return _permutedims(a, Tuple(biperm))
56+
function permuteblockeddims(a::AbstractArray, biperm::AbstractBlockPermutation{2})
57+
return permuteblockeddims(a, blocks(biperm)...)
3558
end
36-
3759
function permuteblockeddims!(
38-
a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation
60+
a_dest::AbstractArray, a_src::AbstractArray, biperm::AbstractBlockPermutation{2}
3961
)
40-
return _permutedims!(a, b, Tuple(biperm))
62+
return permuteblockeddims!(a_dest, a_src, blocks(biperm)...)
4163
end
4264

4365
# ===================================== matricize ========================================
4466
# TBD settle copy/not copy convention
4567
# matrix factorizations assume copy
4668
# maybe: copy=false kwarg
4769

48-
function matricize(a::AbstractArray, biperm_dest::AbstractBlockPermutation{2})
49-
ndims(a) == length(biperm_dest) || throw(ArgumentError("Invalid bipermutation"))
50-
return matricize(FusionStyle(a), a, biperm_dest)
70+
function matricize(a::AbstractArray, length1::Val, length2::Val)
71+
return matricize(FusionStyle(a), a, length1, length2)
72+
end
73+
# This is the primary function that should be overloaded for new fusion styles.
74+
# This assumes the permutation was already performed.
75+
function matricize(style::FusionStyle, a::AbstractArray, length1::Val, length2::Val)
76+
return throw(
77+
MethodError(
78+
matricize, Tuple{typeof(style), typeof(a), typeof(length1), typeof(length2)}
79+
)
80+
)
5181
end
5282

5383
function matricize(
54-
style::FusionStyle, a::AbstractArray, biperm_dest::AbstractBlockPermutation{2}
84+
a::AbstractArray, permblock1::Tuple{Vararg{Int}}, permblock2::Tuple{Vararg{Int}}
5585
)
56-
a_perm = permuteblockeddims(a, biperm_dest)
57-
return matricize(style, a_perm, trivialperm(biperm_dest))
86+
return matricize(FusionStyle(a), a, permblock1, permblock2)
5887
end
59-
88+
# This is a more advanced version to overload where the permutation is actually performed.
6089
function matricize(
61-
style::FusionStyle, a::AbstractArray, biperm_dest::BlockedTrivialPermutation{2}
90+
style::FusionStyle, a::AbstractArray,
91+
permblock1::NTuple{N1, Int}, permblock2::NTuple{N2, Int}
92+
) where {N1, N2}
93+
ndims(a) == length(permblock1) + length(permblock2) ||
94+
throw(ArgumentError("Invalid bipermutation"))
95+
a_perm = permuteblockeddims(a, permblock1, permblock2)
96+
return matricize(style, a_perm, Val(length(permblock1)), Val(length(permblock2)))
97+
end
98+
99+
# Process inputs such as `EllipsisNotation.Ellipsis`.
100+
function to_permblocks(a::AbstractArray, permblocks::NTuple{2, Tuple{Vararg{Int}}})
101+
isperm((permblocks[1]..., permblocks[2]...)) ||
102+
throw(ArgumentError("Invalid bipermutation"))
103+
return permblocks
104+
end
105+
# Like `setcomplement` is like `setdiff` but assumes t2 ⊆ t1.
106+
function tuplesetcomplement(t1::NTuple{N1}, t2::NTuple{N2}) where {N1, N2}
107+
t2 t1 || throw(ArgumentError("t2 must be a subset of t1"))
108+
return NTuple{N1 - N2}(setdiff(t1, t2))
109+
end
110+
function to_permblocks(
111+
a::AbstractArray, permblocks::Tuple{Tuple{Ellipsis}, Tuple{Vararg{Int}}}
112+
)
113+
permblocks1 = tuplesetcomplement(ntuple(identity, ndims(a)), permblocks[2])
114+
return (permblocks1, permblocks[2])
115+
end
116+
function to_permblocks(
117+
a::AbstractArray, permblocks::Tuple{Tuple{Vararg{Int}}, Tuple{Ellipsis}}
62118
)
63-
return throw(MethodError(matricize, Tuple{typeof(style), typeof(a), typeof(biperm_dest)}))
119+
permblocks2 = tuplesetcomplement(ntuple(identity, ndims(a)), permblocks[1])
120+
return (permblocks[1], permblocks2)
121+
end
122+
function matricize(a::AbstractArray, permblock1, permblock2)
123+
return matricize(FusionStyle(a), a, permblock1, permblock2)
124+
end
125+
function matricize(style::FusionStyle, a::AbstractArray, permblock1, permblock2)
126+
return matricize(style, a, to_permblocks(a, (permblock1, permblock2))...)
64127
end
65128

66-
# default is reshape
129+
function matricize(a::AbstractArray, biperm_dest::AbstractBlockPermutation{2})
130+
return matricize(FusionStyle(a), a, biperm_dest)
131+
end
67132
function matricize(
68-
::ReshapeFusion, a::AbstractArray, biperm_dest::BlockedTrivialPermutation{2}
133+
style::FusionStyle, a::AbstractArray, biperm_dest::AbstractBlockPermutation{2}
69134
)
70-
new_axes = fuseaxes(axes(a), biperm_dest)
71-
return reshape(a, new_axes...)
135+
return matricize(style, a, blocks(biperm_dest)...)
72136
end
73137

74-
function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple)
75-
return matricize(a, blockedpermvcat(permblock1, permblock2; length = Val(ndims(a))))
138+
# default is reshape
139+
function matricize(::ReshapeFusion, a::AbstractArray, length1::Val, length2::Val)
140+
return reshape(a, fuseaxes(axes(a), length1, length2)...)
76141
end
77142

78143
# ==================================== unmatricize =======================================
79144
function unmatricize(m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2})
80-
length(axes_dest) == length(invbiperm) ||
81-
throw(ArgumentError("axes do not match permutation"))
82145
return unmatricize(FusionStyle(m), m, axes_dest, invbiperm)
83146
end
84-
85147
function unmatricize(
86-
::FusionStyle, m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2}
148+
style::FusionStyle, m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2}
87149
)
150+
length(axes_dest) == length(invbiperm) ||
151+
throw(ArgumentError("axes do not match permutation"))
88152
blocked_axes = axes_dest[invbiperm]
89-
a12 = unmatricize(m, blocked_axes)
153+
a12 = unmatricize(style, m, blocked_axes)
90154
biperm_dest = biperm(invperm(invbiperm), length_codomain(axes_dest))
91-
92155
return permuteblockeddims(a12, biperm_dest)
93156
end
94157

158+
function unmatricize(
159+
m::AbstractMatrix,
160+
blocked_axes::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}},
161+
)
162+
return unmatricize(FusionStyle(m), m, blocked_axes)
163+
end
95164
function unmatricize(
96165
::ReshapeFusion,
97166
m::AbstractMatrix,
@@ -100,30 +169,49 @@ function unmatricize(
100169
return reshape(m, Tuple(blocked_axes)...)
101170
end
102171

103-
function unmatricize(m::AbstractMatrix, blocked_axes)
104-
return unmatricize(FusionStyle(m), m, blocked_axes)
105-
end
106-
107172
function unmatricize(
108173
m::AbstractMatrix,
109174
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
110175
domain_axes::Tuple{Vararg{AbstractUnitRange}},
111176
)
177+
return unmatricize(FusionStyle(m), m, codomain_axes, domain_axes)
178+
end
179+
function unmatricize(
180+
style::FusionStyle, m::AbstractMatrix,
181+
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
182+
domain_axes::Tuple{Vararg{AbstractUnitRange}},
183+
)
112184
blocked_axes = tuplemortar((codomain_axes, domain_axes))
113-
return unmatricize(m, blocked_axes)
185+
return unmatricize(style, m, blocked_axes)
114186
end
115187

116-
function unmatricize!(a_dest, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2})
188+
function unmatricize!(
189+
a_dest::AbstractArray, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2}
190+
)
191+
return unmatricize!(FusionStyle(m), a_dest, m, invbiperm)
192+
end
193+
function unmatricize!(
194+
style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix,
195+
invbiperm::AbstractBlockPermutation{2},
196+
)
117197
ndims(a_dest) == length(invbiperm) ||
118198
throw(ArgumentError("destination does not match permutation"))
119199
blocked_axes = axes(a_dest)[invbiperm]
120-
a_perm = unmatricize(m, blocked_axes)
200+
a_perm = unmatricize(style, m, blocked_axes)
121201
biperm_dest = biperm(invperm(invbiperm), length_codomain(axes(a_dest)))
122202
return permuteblockeddims!(a_dest, a_perm, biperm_dest)
123203
end
124204

125-
function unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β)
126-
a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm)
205+
function unmatricizeadd!(
206+
a_dest::AbstractArray, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2},
207+
α::Number, β::Number
208+
)
209+
return unmatricizeadd!(FusionStyle(a_dest), a_dest, m, invbiperm, α, β)
210+
end
211+
function unmatricizeadd!(
212+
style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix, invbiperm, α, β
213+
)
214+
a12 = unmatricize(style, m, axes(a_dest), invbiperm)
127215
a_dest .= α .* a12 .+ β .* a_dest
128216
return a_dest
129217
end

test/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1414
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1515
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1616

17+
[sources]
18+
TensorAlgebra = {path = ".."}
19+
1720
[compat]
1821
Aqua = "0.8.9"
1922
BlockArrays = "1.6.1"

0 commit comments

Comments
 (0)