[WIP] Add tensor factorizations through MatrixAlgebraKit #23
First changed file:

@@ -1,45 +1,83 @@

Previous implementation (based on `LinearAlgebra`):

```julia
using ArrayLayouts: LayoutMatrix
using LinearAlgebra: LinearAlgebra, Diagonal

function qr(a::AbstractArray, biperm::BlockedPermutation{2})
  a_matricized = fusedims(a, biperm)
  # TODO: Make this more generic, allow choosing thin or full,
  # make sure this works on GPU.
  q_fact, r_matricized = LinearAlgebra.qr(a_matricized)
  q_matricized = typeof(a_matricized)(q_fact)
  axes_codomain, axes_domain = blockpermute(axes(a), biperm)
  axes_q = (axes_codomain..., axes(q_matricized, 2))
  axes_r = (axes(r_matricized, 1), axes_domain...)
  q = splitdims(q_matricized, axes_q)
  r = splitdims(r_matricized, axes_r)
  return q, r
end

function qr(a::AbstractArray, labels_a, labels_codomain, labels_domain)
  # TODO: Generalize so conversion to `Tuple` isn't needed.
  return qr(
    a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
  )
end

function svd(a::AbstractArray, biperm::BlockedPermutation{2})
  a_matricized = fusedims(a, biperm)
  usv_matricized = LinearAlgebra.svd(a_matricized)
  u_matricized = usv_matricized.U
  s_diag = usv_matricized.S
  v_matricized = usv_matricized.Vt
  axes_codomain, axes_domain = blockpermute(axes(a), biperm)
  axes_u = (axes_codomain..., axes(u_matricized, 2))
  axes_v = (axes(v_matricized, 1), axes_domain...)
  u = splitdims(u_matricized, axes_u)
  # TODO: Use `DiagonalArrays.diagonal` to make it more general.
  s = Diagonal(s_diag)
  v = splitdims(v_matricized, axes_v)
  return u, s, v
end

function svd(a::AbstractArray, labels_a, labels_codomain, labels_domain)
  return svd(
    a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
  )
end
```

New implementation (based on `MatrixAlgebraKit`):

```julia
using MatrixAlgebraKit:
  MatrixAlgebraKit, qr_full, qr_compact, svd_full, svd_compact, svd_trunc

# TODO: consider in-place version
# TODO: figure out kwargs and document

"""
    qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; full=true, kwargs...)
    qr(A::AbstractArray, biperm::BlockedPermutation{2}; full=true, kwargs...)

Compute the QR decomposition of a generic N-dimensional array by interpreting it as
a linear map from the domain to the codomain indices. These can be specified either via
their labels or directly through a `biperm`.
"""
function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
  biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
  return qr(A, biperm; kwargs...)
end
function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=true, kwargs...)
  # tensor to matrix
  A_mat = fusedims(A, biperm)

  # factorization
  Q, R = full ? qr_full(A_mat; kwargs...) : qr_compact(A_mat; kwargs...)

  # matrix to tensor
  axes_codomain, axes_domain = blockpermute(axes(A), biperm)
  axes_Q = (axes_codomain..., axes(Q, 2))
  axes_R = (axes(R, 1), axes_domain...)
  return splitdims(Q, axes_Q), splitdims(R, axes_R)
end

"""
    svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; full=false, kwargs...)
    svd(A::AbstractArray, biperm::BlockedPermutation{2}; full=false, kwargs...)

Compute the singular value decomposition (SVD) of a generic N-dimensional array by
interpreting it as a linear map from the domain to the codomain indices. These can be
specified either via their labels or directly through a `biperm`.
"""
function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
  biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
  return svd(A, biperm; kwargs...)
end
function svd(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...)
  # tensor to matrix
  A_mat = fusedims(A, biperm)

  # factorization
  U, S, Vᴴ = full ? svd_full(A_mat; kwargs...) : svd_compact(A_mat; kwargs...)

  # matrix to tensor
  axes_codomain, axes_domain = blockpermute(axes(A), biperm)
  axes_U = (axes_codomain..., axes(U, 2))
  axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...)
  return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ)
end

# TODO: decide on interface
"""
    tsvd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
    tsvd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)

Compute the truncated singular value decomposition (SVD) of a generic N-dimensional array
by interpreting it as a linear map from the domain to the codomain indices. These can be
specified either via their labels or directly through a `biperm`.
"""
function tsvd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
  biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
  return tsvd(A, biperm; kwargs...)
end
function tsvd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
  # tensor to matrix
  A_mat = fusedims(A, biperm)

  # factorization
  U, S, Vᴴ = svd_trunc(A_mat; kwargs...)

  # matrix to tensor
  axes_codomain, axes_domain = blockpermute(axes(A), biperm)
  axes_U = (axes_codomain..., axes(U, 2))
  axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...)
  return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ)
end
```
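For illustration, a minimal usage sketch of the labels-based interface above. It assumes `qr` and `svd` can be imported from `TensorAlgebra` alongside `contract`, in the same way the tests below import `svd` and `tsvd`; the array shape, labels, and reconstruction via `contract` mirror those tests:

```julia
using TensorAlgebra: contract, qr, svd

# A 4-way tensor viewed as a linear map from the domain indices (:c, :d)
# to the codomain indices (:a, :b).
A = randn(5, 4, 3, 2)
labels_A = (:a, :b, :c, :d)

# Compact QR: Q carries (:a, :b) plus a new bond index, R carries the bond plus (:c, :d).
Q, R = qr(A, labels_A, (:a, :b), (:c, :d); full=false)

# Compact SVD, then reconstruct A by contracting U * S * Vᴴ over the bond labels.
U, S, Vᴴ = svd(A, labels_A, (:a, :b), (:c, :d))
US, labels_US = contract(U, (:a, :b, :u), S, (:u, :v))
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, :c, :d))
A ≈ A′  # true up to floating-point error
```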
Second changed file (new test file):

@@ -0,0 +1,62 @@

```julia
using Test: @test, @testset, @inferred
using TestExtras: @constinferred
using TensorAlgebra: contract, svd, tsvd, TensorAlgebra
using TensorAlgebra.MatrixAlgebraKit: truncrank

using LinearAlgebra: norm

elts = (Float64, ComplexF64)

@testset "Full SVD ($T)" for T in elts
  A = randn(T, 5, 4, 3, 2)
  A ./= norm(A)
  labels_A = (:a, :b, :c, :d)
  labels_U = (:b, :a)
  labels_Vᴴ = (:d, :c)

  Acopy = deepcopy(A)
  U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=true)
  @test A == Acopy # should not have altered initial array
  US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
  A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
  @test A ≈ A′
  @test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary
  @test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # V is unitary
end

@testset "Compact SVD ($T)" for T in elts
  A = randn(T, 5, 4, 3, 2)
  A ./= norm(A)
  labels_A = (:a, :b, :c, :d)
  labels_U = (:b, :a)
  labels_Vᴴ = (:d, :c)

  Acopy = deepcopy(A)
  U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=false)
  @test A == Acopy # should not have altered initial array
  US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
  A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
  @test A ≈ A′
  k = min(size(S)...)
  @test size(U, 3) == k == size(Vᴴ, 1)
end

@testset "Truncated SVD ($T)" for T in elts
  A = randn(T, 5, 4, 3, 2)
  A ./= norm(A)
  labels_A = (:a, :b, :c, :d)
  labels_U = (:b, :a)
  labels_Vᴴ = (:d, :c)

  # test truncated SVD
  Acopy = deepcopy(A)
  _, S_untrunc, _ = svd(A, labels_A, labels_U, labels_Vᴴ)

  trunc = truncrank(size(S_untrunc, 1) - 1)
  U, S, Vᴴ = @constinferred tsvd(A, labels_A, labels_U, labels_Vᴴ; trunc)

  @test A == Acopy # should not have altered initial array
  US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
  A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
  @test norm(A - A′) ≈ S_untrunc[end]
  @test size(S, 1) == size(S_untrunc, 1) - 1
end
```

Review comment on the line `using TensorAlgebra.MatrixAlgebraKit: truncrank`:

Comment: Just as a style preference, I prefer …

Reply: To be honest, I added this as a placeholder while we settle on where the keyword arguments are supposed to be handled, since we should either import that into TensorAlgebra or provide a different interface. I just wanted to run some tests in the meantime, but for sure will adapt this :)
Review discussion

Comment: Just a suggestion, but what if truncated SVD is implemented as an algorithm backend of a more general `svd` function (i.e. we have `svd(::Algorithm"truncated", a::AbstractArray, ...)` and `svd(::Algorithm"untruncated", a::AbstractArray, ...)`)?

Reply: I'm definitely also fine with that; behind the scenes this is already what is going on in MatrixAlgebraKit as well, see the definition here. It's just a matter of deciding where you want to start intercepting the user input and what you would like the keywords for that to be, and I'm happy to map it to whatever MatrixAlgebraKit has.
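For concreteness, a rough sketch of the dispatch pattern being proposed (illustrative only, not code from this PR; the `Algorithm` singleton type defined here is a placeholder for whatever tag type the package settles on):

```julia
# Illustrative sketch: select truncated vs. untruncated SVD by dispatching on
# an algorithm tag, forwarding to the labels-based methods defined in this PR.
import TensorAlgebra: svd, tsvd

struct Algorithm{name} end
Algorithm(name::Symbol) = Algorithm{name}()

# Untruncated: forward to the existing labels-based `svd`.
function svd(::Algorithm{:untruncated}, A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
  return svd(A, labels_A, labels_codomain, labels_domain; kwargs...)
end

# Truncated: forward to `tsvd`, which wraps `MatrixAlgebraKit.svd_trunc`.
function svd(::Algorithm{:truncated}, A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
  return tsvd(A, labels_A, labels_codomain, labels_domain; kwargs...)
end
```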
Reply: Yeah, that is definitely a bit subtle, since ITensor has its own syntax for truncation keyword arguments. Maybe a strategy we can take for now is to start out by following whatever MatrixAlgebraKit.jl does (say, `TensorAlgebra.svd` just forwards kwargs to `MatrixAlgebraKit.svd`), and then at the ITensor level we can map our truncation arguments like `cutoff` and `maxdim` to the TensorAlgebra.jl/MatrixAlgebraKit.jl names and conventions, and additionally choose `TruncatedAlgorithm` as the default SVD algorithm.
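As a rough illustration of that mapping (the wrapper name `itensor_svd` is hypothetical and not part of this PR; only `truncrank` is used here since that is the truncation strategy exercised in the tests, and a `cutoff`-style keyword would map to a tolerance-based strategy in the same way):

```julia
# Hypothetical sketch: map an ITensor-style `maxdim` keyword onto the `trunc`
# keyword that `tsvd` forwards to `MatrixAlgebraKit.svd_trunc`.
using TensorAlgebra: tsvd
using TensorAlgebra.MatrixAlgebraKit: truncrank

function itensor_svd(A, labels_A, labels_codomain, labels_domain; maxdim::Int, kwargs...)
  # `truncrank(maxdim)` keeps at most `maxdim` singular values.
  return tsvd(A, labels_A, labels_codomain, labels_domain; trunc=truncrank(maxdim), kwargs...)
end
```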