Skip to content

Commit 542dfd9

Browse files
authored
Add trivial bipermutation interface to matricized factorizations and functions (#98)
1 parent 097ee39 commit 542dfd9

File tree

8 files changed

+152
-98
lines changed

8 files changed

+152
-98
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.1"
4+
version = "0.5.2"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/MatrixAlgebra.jl

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,34 @@ for (svd, svd_trunc, svd_full, svd_compact) in (
7575
(:svd, :svd_trunc, :svd_full, :svd_compact),
7676
(:svd!, :svd_trunc!, :svd_full!, :svd_compact!),
7777
)
78+
_svd = Symbol(:_, svd)
7879
@eval begin
79-
function $svd(A::AbstractMatrix; full::Bool = false, trunc = nothing, kwargs...)
80-
return if !isnothing(trunc)
81-
@assert !full "Specified both full and truncation, currently not supported"
82-
$svd_trunc(A; trunc, kwargs...)
83-
else
84-
(full ? $svd_full : $svd_compact)(A; kwargs...)
85-
end
80+
function $svd(
81+
A::AbstractMatrix;
82+
full::Union{Bool, Val} = Val(false),
83+
trunc = nothing,
84+
kwargs...,
85+
)
86+
return $_svd(full, trunc, A; kwargs...)
87+
end
88+
function $_svd(full::Bool, trunc, A::AbstractMatrix; kwargs...)
89+
return $_svd(Val(full), trunc, A; kwargs...)
90+
end
91+
function $_svd(full::Val{false}, trunc::Nothing, A::AbstractMatrix; kwargs...)
92+
return $svd_compact(A; kwargs...)
93+
end
94+
function $_svd(full::Val{false}, trunc, A::AbstractMatrix; kwargs...)
95+
return $svd_trunc(A; trunc, kwargs...)
96+
end
97+
function $_svd(full::Val{true}, trunc::Nothing, A::AbstractMatrix; kwargs...)
98+
return $svd_full(A; kwargs...)
99+
end
100+
function $_svd(full::Val{true}, trunc, A::AbstractMatrix; kwargs...)
101+
return throw(
102+
ArgumentError(
103+
"Specified both full and truncation, currently not supported"
104+
)
105+
)
86106
end
87107
end
88108
end

src/factorizations.jl

Lines changed: 48 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,38 @@ for f in (
77
@eval begin
88
function $f(
99
A::AbstractArray,
10-
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
10+
codomain_length::Val, domain_length::Val;
1111
kwargs...,
1212
)
1313
# tensor to matrix
14-
A_mat = matricize(A, codomain_perm, domain_perm)
14+
A_mat = matricize(A, codomain_length, domain_length)
1515

1616
# factorization
1717
X, Y = MatrixAlgebra.$f(A_mat; kwargs...)
1818

1919
# matrix to tensor
20-
biperm = permmortar((codomain_perm, domain_perm))
20+
biperm = blockedtrivialperm((codomain_length, domain_length))
2121
axes_codomain, axes_domain = blocks(axes(A)[biperm])
2222
axes_X = tuplemortar((axes_codomain, (axes(X, 2),)))
2323
axes_Y = tuplemortar(((axes(Y, 1),), axes_domain))
2424
return unmatricize(X, axes_X), unmatricize(Y, axes_Y)
2525
end
26+
end
27+
end
28+
29+
for f in (
30+
:qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, :factorize,
31+
:eigen, :eigvals, :svd, :svdvals, :left_null, :right_null,
32+
)
33+
@eval begin
34+
function $f(
35+
A::AbstractArray,
36+
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
37+
kwargs...,
38+
)
39+
A_perm = bipermutedims(A, codomain_perm, domain_perm)
40+
return $f(A_perm, Val(length(codomain_perm)), Val(length(domain_perm)); kwargs...)
41+
end
2642
function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
2743
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
2844
return $f(A, blocks(biperm)...; kwargs...)
@@ -36,6 +52,7 @@ end
3652
"""
3753
qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R
3854
qr(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> Q, R
55+
qr(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> Q, R
3956
qr(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Q, R
4057
4158
Compute the QR decomposition of a generic N-dimensional array, by interpreting it as
@@ -55,6 +72,7 @@ qr
5572
"""
5673
lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> L, Q
5774
lq(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> L, Q
75+
lq(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> L, Q
5876
lq(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> L, Q
5977
6078
Compute the LQ decomposition of a generic N-dimensional array, by interpreting it as
@@ -74,6 +92,7 @@ lq
7492
"""
7593
left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P
7694
left_polar(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> W, P
95+
left_polar(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> W, P
7796
left_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> W, P
7897
7998
Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as
@@ -91,6 +110,7 @@ left_polar
91110
"""
92111
right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W
93112
right_polar(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> P, W
113+
right_polar(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> P, W
94114
right_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> P, W
95115
96116
Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as
@@ -108,6 +128,7 @@ right_polar
108128
"""
109129
left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C
110130
left_orth(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> V, C
131+
left_orth(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> V, C
111132
left_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> V, C
112133
113134
Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as
@@ -125,6 +146,7 @@ left_orth
125146
"""
126147
right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V
127148
right_orth(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> C, V
149+
right_orth(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> C, V
128150
right_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> C, V
129151
130152
Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as
@@ -142,6 +164,7 @@ right_orth
142164
"""
143165
factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y
144166
factorize(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> X, Y
167+
factorize(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> X, Y
145168
factorize(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y
146169
147170
Compute the decomposition of a generic N-dimensional array, by interpreting it as
@@ -159,6 +182,7 @@ factorize
159182
"""
160183
eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D, V
161184
eigen(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> D, V
185+
eigen(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> D, V
162186
eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D, V
163187
164188
Compute the eigenvalue decomposition of a generic N-dimensional array, by interpreting it as
@@ -175,26 +199,18 @@ their labels or directly through a bi-permutation.
175199
See also `MatrixAlgebraKit.eig_full!`, `MatrixAlgebraKit.eig_trunc!`, `MatrixAlgebraKit.eig_vals!`,
176200
`MatrixAlgebraKit.eigh_full!`, `MatrixAlgebraKit.eigh_trunc!`, and `MatrixAlgebraKit.eigh_vals!`.
177201
"""
178-
function eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
179-
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
180-
return eigen(A, blocks(biperm)...; kwargs...)
181-
end
182-
function eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...)
183-
return eigen(A, blocks(biperm)...; kwargs...)
184-
end
185202
function eigen(
186203
A::AbstractArray,
187-
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
204+
codomain_length::Val, domain_length::Val;
188205
kwargs...,
189206
)
190207
# tensor to matrix
191-
A_mat = matricize(A, codomain_perm, domain_perm)
192-
208+
A_mat = matricize(A, codomain_length, domain_length)
193209
# factorization
194210
D, V = MatrixAlgebra.eigen!(A_mat; kwargs...)
195211

196212
# matrix to tensor
197-
biperm = permmortar((codomain_perm, domain_perm))
213+
biperm = blockedtrivialperm((codomain_length, domain_length))
198214
axes_codomain, = blocks(axes(A)[biperm])
199215
axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),)))
200216
return D, unmatricize(V, axes_V)
@@ -203,6 +219,7 @@ end
203219
"""
204220
eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D
205221
eigvals(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> D
222+
eigvals(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> D
206223
eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D
207224
208225
Compute the eigenvalues of a generic N-dimensional array, by interpreting it as
@@ -217,25 +234,19 @@ their labels or directly through a bi-permutation. The output is a vector of eig
217234
218235
See also `MatrixAlgebraKit.eig_vals!` and `MatrixAlgebraKit.eigh_vals!`.
219236
"""
220-
function eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
221-
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
222-
return eigvals(A, blocks(biperm)...; kwargs...)
223-
end
224-
function eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...)
225-
return eigvals(A, blocks(biperm)...; kwargs...)
226-
end
227237
function eigvals(
228238
A::AbstractArray,
229-
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
239+
codomain_length::Val, domain_length::Val;
230240
kwargs...,
231241
)
232-
A_mat = matricize(A, codomain_perm, domain_perm)
242+
A_mat = matricize(A, codomain_length, domain_length)
233243
return MatrixAlgebra.eigvals!(A_mat; kwargs...)
234244
end
235245

236246
"""
237247
svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ
238248
svd(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> U, S, Vᴴ
249+
svd(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> U, S, Vᴴ
239250
svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> U, S, Vᴴ
240251
241252
Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as
@@ -251,26 +262,18 @@ their labels or directly through a bi-permutation.
251262
252263
See also `MatrixAlgebraKit.svd_full!`, `MatrixAlgebraKit.svd_compact!`, and `MatrixAlgebraKit.svd_trunc!`.
253264
"""
254-
function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
255-
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
256-
return svd(A, blocks(biperm)...; kwargs...)
257-
end
258-
function svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...)
259-
return svd(A, blocks(biperm)...; kwargs...)
260-
end
261265
function svd(
262266
A::AbstractArray,
263-
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
267+
codomain_length::Val, domain_length::Val;
264268
kwargs...,
265269
)
266270
# tensor to matrix
267-
A_mat = matricize(A, codomain_perm, domain_perm)
268-
271+
A_mat = matricize(A, codomain_length, domain_length)
269272
# factorization
270273
U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...)
271274

272275
# matrix to tensor
273-
biperm = permmortar((codomain_perm, domain_perm))
276+
biperm = blockedtrivialperm((codomain_length, domain_length))
274277
axes_codomain, axes_domain = blocks(axes(A)[biperm])
275278
axes_U = tuplemortar((axes_codomain, (axes(U, 2),)))
276279
axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain))
@@ -280,6 +283,7 @@ end
280283
"""
281284
svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) -> S
282285
svdvals(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}) -> S
286+
svdvals(A::AbstractArray, codomain_length::Val, domain_length::Val) -> S
283287
svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) -> S
284288
285289
Compute the singular values of a generic N-dimensional array, by interpreting it as
@@ -288,24 +292,18 @@ their labels or directly through a bi-permutation. The output is a vector of sin
288292
289293
See also `MatrixAlgebraKit.svd_vals!`.
290294
"""
291-
function svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain)
292-
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
293-
return svdvals(A, blocks(biperm)...)
294-
end
295-
function svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2})
296-
return svdvals(A, blocks(biperm)...)
297-
end
298295
function svdvals(
299296
A::AbstractArray,
300-
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}
297+
codomain_length::Val, domain_length::Val
301298
)
302-
A_mat = matricize(A, codomain_perm, domain_perm)
299+
A_mat = matricize(A, codomain_length, domain_length)
303300
return MatrixAlgebra.svdvals!(A_mat)
304301
end
305302

306303
"""
307304
left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> N
308305
left_null(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> N
306+
left_null(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> N
309307
left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> N
310308
311309
Compute the left nullspace of a generic N-dimensional array, by interpreting it as
@@ -321,21 +319,14 @@ The output satisfies `N' * A ≈ 0` and `N' * N ≈ I`.
321319
The options are `:qr`, `:qrpos` and `:svd`. The former two require `0 == atol == rtol`.
322320
The default is `:qrpos` if `atol == rtol == 0`, and `:svd` otherwise.
323321
"""
324-
function left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
325-
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
326-
return left_null(A, blocks(biperm)...; kwargs...)
327-
end
328-
function left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...)
329-
return left_null(A, blocks(biperm)...; kwargs...)
330-
end
331322
function left_null(
332323
A::AbstractArray,
333-
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
324+
codomain_length::Val, domain_length::Val;
334325
kwargs...,
335326
)
336-
A_mat = matricize(A, codomain_perm, domain_perm)
327+
A_mat = matricize(A, codomain_length, domain_length)
337328
N = MatrixAlgebraKit.left_null!(A_mat; kwargs...)
338-
biperm = permmortar((codomain_perm, domain_perm))
329+
biperm = blockedtrivialperm((codomain_length, domain_length))
339330
axes_codomain = first(blocks(axes(A)[biperm]))
340331
axes_N = tuplemortar((axes_codomain, (axes(N, 2),)))
341332
return unmatricize(N, axes_N)
@@ -344,6 +335,7 @@ end
344335
"""
345336
right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Nᴴ
346337
right_null(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> Nᴴ
338+
right_null(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> Nᴴ
347339
right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Nᴴ
348340
349341
Compute the right nullspace of a generic N-dimensional array, by interpreting it as
@@ -359,21 +351,14 @@ The output satisfies `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`.
359351
The options are `:lq`, `:lqpos` and `:svd`. The former two require `0 == atol == rtol`.
360352
The default is `:lqpos` if `atol == rtol == 0`, and `:svd` otherwise.
361353
"""
362-
function right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
363-
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
364-
return right_null(A, blocks(biperm)...; kwargs...)
365-
end
366-
function right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...)
367-
return right_null(A, blocks(biperm)...; kwargs...)
368-
end
369354
function right_null(
370355
A::AbstractArray,
371-
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
356+
codomain_length::Val, domain_length::Val;
372357
kwargs...,
373358
)
374-
A_mat = matricize(A, codomain_perm, domain_perm)
359+
A_mat = matricize(A, codomain_length, domain_length)
375360
Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...)
376-
biperm = permmortar((codomain_perm, domain_perm))
361+
biperm = blockedtrivialperm((codomain_length, domain_length))
377362
axes_domain = last(blocks((axes(A)[biperm])))
378363
axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain))
379364
return unmatricize(Nᴴ, axes_Nᴴ)

0 commit comments

Comments
 (0)