Skip to content

Commit 43553e3

Browse files
authored
More factorizations like polar, orth, etc. (#51)
1 parent 4b0e372 commit 43553e3

File tree

3 files changed

+258
-7
lines changed

3 files changed

+258
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.8"
4+
version = "0.2.9"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/factorizations.jl

Lines changed: 170 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using LinearAlgebra: LinearAlgebra
12
using MatrixAlgebraKit:
23
eig_full!,
34
eig_trunc!,
@@ -6,16 +7,19 @@ using MatrixAlgebraKit:
67
eigh_trunc!,
78
eigh_vals!,
89
left_null!,
10+
left_orth!,
11+
left_polar!,
912
lq_full!,
1013
lq_compact!,
1114
qr_full!,
1215
qr_compact!,
1316
right_null!,
17+
right_orth!,
18+
right_polar!,
1419
svd_full!,
1520
svd_compact!,
1621
svd_trunc!,
1722
svd_vals!
18-
using LinearAlgebra: LinearAlgebra
1923

2024
"""
2125
qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R
@@ -76,7 +80,7 @@ function lq(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, k
7680
A_mat = fusedims(A, biperm)
7781

7882
# factorization
79-
L, Q = full ? lq_full!(A_mat; kwargs...) : lq_compact!(A_mat; kwargs...)
83+
L, Q = (full ? lq_full! : lq_compact!)(A_mat; kwargs...)
8084

8185
# matrix to tensor
8286
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
@@ -120,11 +124,12 @@ function eigen(
120124
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat)
121125

122126
# factorization
123-
if !isnothing(trunc)
124-
D, V = (ishermitian ? eigh_trunc! : eig_trunc!)(A_mat; trunc, kwargs...)
127+
f! = if !isnothing(trunc)
128+
ishermitian ? eigh_trunc! : eig_trunc!
125129
else
126-
D, V = (ishermitian ? eigh_full! : eig_full!)(A_mat; kwargs...)
130+
ishermitian ? eigh_full! : eig_full!
127131
end
132+
D, V = f!(A_mat; kwargs...)
128133

129134
# matrix to tensor
130135
axes_codomain, = blockpermute(axes(A), biperm)
@@ -284,3 +289,163 @@ function right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
284289
axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...)
285290
return splitdims(Nᴴ, axes_Nᴴ)
286291
end
292+
293+
"""
294+
left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P
295+
left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> W, P
296+
297+
Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as
298+
a linear map from the domain to the codomain indices. These can be specified either via
299+
their labels, or directly through a `biperm`.
300+
301+
## Keyword arguments
302+
303+
- Keyword arguments are passed on directly to MatrixAlgebraKit.
304+
305+
See also `MatrixAlgebraKit.left_polar!`.
306+
"""
307+
function left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
308+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
309+
return left_polar(A, biperm; kwargs...)
310+
end
311+
function left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
312+
# tensor to matrix
313+
A_mat = fusedims(A, biperm)
314+
315+
# factorization
316+
W, P = left_polar!(A_mat; kwargs...)
317+
318+
# matrix to tensor
319+
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
320+
axes_W = (axes_codomain..., axes(W, 2))
321+
axes_P = (axes(P, 1), axes_domain...)
322+
return splitdims(W, axes_W), splitdims(P, axes_P)
323+
end
324+
325+
"""
326+
right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W
327+
right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> P, W
328+
329+
Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as
330+
a linear map from the domain to the codomain indices. These can be specified either via
331+
their labels, or directly through a `biperm`.
332+
333+
## Keyword arguments
334+
335+
- Keyword arguments are passed on directly to MatrixAlgebraKit.
336+
337+
See also `MatrixAlgebraKit.right_polar!`.
338+
"""
339+
function right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
340+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
341+
return right_polar(A, biperm; kwargs...)
342+
end
343+
function right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
344+
# tensor to matrix
345+
A_mat = fusedims(A, biperm)
346+
347+
# factorization
348+
P, W = right_polar!(A_mat; kwargs...)
349+
350+
# matrix to tensor
351+
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
352+
axes_P = (axes_codomain..., axes(P, ndims(P)))
353+
axes_W = (axes(W, 1), axes_domain...)
354+
return splitdims(P, axes_P), splitdims(W, axes_W)
355+
end
356+
357+
"""
358+
left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C
359+
left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> V, C
360+
361+
Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as
362+
a linear map from the domain to the codomain indices. These can be specified either via
363+
their labels, or directly through a `biperm`.
364+
365+
## Keyword arguments
366+
367+
- Keyword arguments are passed on directly to MatrixAlgebraKit.
368+
369+
See also `MatrixAlgebraKit.left_orth!`.
370+
"""
371+
function left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
372+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
373+
return left_orth(A, biperm; kwargs...)
374+
end
375+
function left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
376+
# tensor to matrix
377+
A_mat = fusedims(A, biperm)
378+
379+
# factorization
380+
V, C = left_orth!(A_mat; kwargs...)
381+
382+
# matrix to tensor
383+
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
384+
axes_V = (axes_codomain..., axes(V, 2))
385+
axes_C = (axes(C, 1), axes_domain...)
386+
return splitdims(V, axes_V), splitdims(C, axes_C)
387+
end
388+
389+
"""
390+
right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V
391+
right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> C, V
392+
393+
Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as
394+
a linear map from the domain to the codomain indices. These can be specified either via
395+
their labels, or directly through a `biperm`.
396+
397+
## Keyword arguments
398+
399+
- Keyword arguments are passed on directly to MatrixAlgebraKit.
400+
401+
See also `MatrixAlgebraKit.right_orth!`.
402+
"""
403+
function right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
404+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
405+
return right_orth(A, biperm; kwargs...)
406+
end
407+
function right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
408+
# tensor to matrix
409+
A_mat = fusedims(A, biperm)
410+
411+
# factorization
412+
P, W = right_orth!(A_mat; kwargs...)
413+
414+
# matrix to tensor
415+
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
416+
axes_P = (axes_codomain..., axes(P, ndims(P)))
417+
axes_W = (axes(W, 1), axes_domain...)
418+
return splitdims(P, axes_P), splitdims(W, axes_W)
419+
end
420+
421+
"""
422+
factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y
423+
factorize(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> X, Y
424+
425+
Compute the decomposition of a generic N-dimensional array, by interpreting it as
426+
a linear map from the domain to the codomain indices. These can be specified either via
427+
their labels, or directly through a `biperm`.
428+
429+
## Keyword arguments
430+
431+
- `orth::Symbol=:left`: specify the orthogonality of the decomposition.
432+
Currently only `:left` and `:right` are supported.
433+
- Other keywords are passed on directly to MatrixAlgebraKit.
434+
"""
435+
function factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
436+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
437+
return factorize(A, biperm; kwargs...)
438+
end
439+
function factorize(A::AbstractArray, biperm::BlockedPermutation{2}; orth=:left, kwargs...)
440+
# tensor to matrix
441+
A_mat = fusedims(A, biperm)
442+
443+
# factorization
444+
X, Y = (orth == :left ? left_orth! : right_orth!)(A_mat; kwargs...)
445+
446+
# matrix to tensor
447+
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
448+
axes_X = (axes_codomain..., axes(X, ndims(X)))
449+
axes_Y = (axes(Y, 1), axes_domain...)
450+
return splitdims(X, axes_X), splitdims(Y, axes_Y)
451+
end

test/test_factorizations.jl

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,21 @@
11
using Test: @test, @testset, @inferred
22
using TestExtras: @constinferred
33
using TensorAlgebra:
4-
TensorAlgebra, contract, lq, qr, svd, svdvals, eigen, eigvals, left_null, right_null
4+
TensorAlgebra,
5+
contract,
6+
eigen,
7+
eigvals,
8+
factorize,
9+
left_null,
10+
left_orth,
11+
left_polar,
12+
lq,
13+
qr,
14+
right_null,
15+
right_orth,
16+
right_polar,
17+
svd,
18+
svdvals
519
using MatrixAlgebraKit: truncrank
620
using LinearAlgebra: LinearAlgebra, norm, diag
721

@@ -194,3 +208,75 @@ end
194208
@test norm(AN) 0 atol = 1e-14
195209
NN = contract((:n, :n′), Nᴴ, (:n, labels_domain...), Nᴴ, (:n′, labels_domain...))
196210
end
211+
212+
@testset "Left polar ($T)" for T in elts
213+
A = randn(T, 2, 2, 2, 2)
214+
labels_A = (:a, :b, :c, :d)
215+
labels_W = (:b, :a)
216+
labels_P = (:d, :c)
217+
218+
Acopy = deepcopy(A)
219+
W, P = left_polar(A, labels_A, labels_W, labels_P)
220+
@test A == Acopy # should not have altered initial array
221+
A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...))
222+
@test A A′
223+
@test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
224+
end
225+
226+
@testset "Right polar ($T)" for T in elts
227+
A = randn(T, 2, 2, 2, 2)
228+
labels_A = (:a, :b, :c, :d)
229+
labels_P = (:b, :a)
230+
labels_W = (:d, :c)
231+
232+
Acopy = deepcopy(A)
233+
P, W = right_polar(A, labels_A, labels_P, labels_W)
234+
@test A == Acopy # should not have altered initial array
235+
A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...))
236+
@test A A′
237+
@test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
238+
end
239+
240+
@testset "Left orth ($T)" for T in elts
241+
A = randn(T, 2, 2, 2, 2)
242+
labels_A = (:a, :b, :c, :d)
243+
labels_W = (:b, :a)
244+
labels_P = (:d, :c)
245+
246+
Acopy = deepcopy(A)
247+
W, P = left_orth(A, labels_A, labels_W, labels_P)
248+
@test A == Acopy # should not have altered initial array
249+
A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...))
250+
@test A A′
251+
@test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
252+
end
253+
254+
@testset "Right orth ($T)" for T in elts
255+
A = randn(T, 2, 2, 2, 2)
256+
labels_A = (:a, :b, :c, :d)
257+
labels_P = (:b, :a)
258+
labels_W = (:d, :c)
259+
260+
Acopy = deepcopy(A)
261+
P, W = right_orth(A, labels_A, labels_P, labels_W)
262+
@test A == Acopy # should not have altered initial array
263+
A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...))
264+
@test A A′
265+
@test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
266+
end
267+
268+
@testset "factorize ($T)" for T in elts
269+
A = randn(T, 2, 2, 2, 2)
270+
labels_A = (:a, :b, :c, :d)
271+
labels_X = (:b, :a)
272+
labels_Y = (:d, :c)
273+
274+
Acopy = deepcopy(A)
275+
for orth in (:left, :right)
276+
X, Y = factorize(A, labels_A, labels_X, labels_Y; orth)
277+
@test A == Acopy # should not have altered initial array
278+
A′ = contract(labels_A, X, (labels_X..., :x), Y, (:x, labels_Y...))
279+
@test A A′
280+
@test size(X, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
281+
end
282+
end

0 commit comments

Comments
 (0)