[WIP] Add tensor factorizations through MatrixAlgebraKit #23
Closed
Changes from all commits (10 commits, all by lkdvos):

9561b1d  Migrate factorizations to their own folder
6aeaa5b  Add MatrixAlgebraKit dependency
dc27256  Changed my mind and implemented in factorizations.jl
7428023  Add truncated svd
7cb6538  Fix typos
1deecbf  Add tests svd
14ad07f  Merge `svd` implementations in a single function
125577a  remove unnecessary norm
b550e24  Add qr tests
5e4213c  Remove test duplicates
factorizations.jl:

```diff
@@ -1,45 +1,74 @@
 using ArrayLayouts: LayoutMatrix
 using LinearAlgebra: LinearAlgebra, Diagonal
+using MatrixAlgebraKit: qr_full, qr_compact, svd_full, svd_compact, svd_trunc

-function qr(a::AbstractArray, biperm::BlockedPermutation{2})
-  a_matricized = fusedims(a, biperm)
-  # TODO: Make this more generic, allow choosing thin or full,
-  # make sure this works on GPU.
-  q_fact, r_matricized = LinearAlgebra.qr(a_matricized)
-  q_matricized = typeof(a_matricized)(q_fact)
-  axes_codomain, axes_domain = blockpermute(axes(a), biperm)
-  axes_q = (axes_codomain..., axes(q_matricized, 2))
-  axes_r = (axes(r_matricized, 1), axes_domain...)
-  q = splitdims(q_matricized, axes_q)
-  r = splitdims(r_matricized, axes_r)
-  return q, r
-end
-
-function qr(a::AbstractArray, labels_a, labels_codomain, labels_domain)
-  # TODO: Generalize to conversion to `Tuple` isn't needed.
-  return qr(
-    a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
-  )
-end
+# TODO: consider in-place version
+# TODO: figure out kwargs and document
+"""
+    qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; full=true, kwargs...) -> Q, R
+    qr(A::AbstractArray, biperm::BlockedPermutation{2}; full=true, kwargs...) -> Q, R
+
+Compute the QR decomposition of a generic N-dimensional array by interpreting it as
+a linear map from the domain to the codomain indices. These can be specified either via
+their labels or directly through a `biperm`.
+"""
+function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
+  biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
+  return qr(A, biperm; kwargs...)
+end
+function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=true, kwargs...)
+  # tensor to matrix
+  A_mat = fusedims(A, biperm)
+
+  # factorization
+  Q, R = full ? qr_full(A_mat; kwargs...) : qr_compact(A_mat; kwargs...)
+
+  # matrix to tensor
+  axes_codomain, axes_domain = blockpermute(axes(A), biperm)
+  axes_Q = (axes_codomain..., axes(Q, 2))
+  axes_R = (axes(R, 1), axes_domain...)
+  return splitdims(Q, axes_Q), splitdims(R, axes_R)
+end

-function svd(a::AbstractArray, biperm::BlockedPermutation{2})
-  a_matricized = fusedims(a, biperm)
-  usv_matricized = LinearAlgebra.svd(a_matricized)
-  u_matricized = usv_matricized.U
-  s_diag = usv_matricized.S
-  v_matricized = usv_matricized.Vt
-  axes_codomain, axes_domain = blockpermute(axes(a), biperm)
-  axes_u = (axes_codomain..., axes(u_matricized, 2))
-  axes_v = (axes(v_matricized, 1), axes_domain...)
-  u = splitdims(u_matricized, axes_u)
-  # TODO: Use `DiagonalArrays.diagonal` to make it more general.
-  s = Diagonal(s_diag)
-  v = splitdims(v_matricized, axes_v)
-  return u, s, v
-end
-
-function svd(a::AbstractArray, labels_a, labels_codomain, labels_domain)
-  return svd(
-    a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
-  )
-end
+# TODO: separate out the algorithm selection step from the implementation
+"""
+    svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ
+    svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> U, S, Vᴴ
+
+Compute the SVD of a generic N-dimensional array by interpreting it as
+a linear map from the domain to the codomain indices. These can be specified either via
+their labels or directly through a `biperm`.
+
+## Keyword arguments
+- `full::Bool=false`: select between a "thick" (full) or a "thin" (compact) decomposition,
+  where `U` and `Vᴴ` are either unitary or isometric.
+- `trunc`: truncation keywords for `svd_trunc`. Not compatible with `full=true`.
+- Other keywords are passed on directly to MatrixAlgebraKit.
+"""
+function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
+  biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
+  return svd(A, biperm; kwargs...)
+end
+function svd(
+  A::AbstractArray,
+  biperm::BlockedPermutation{2};
+  full::Bool=false,
+  trunc=nothing,
+  kwargs...,
+)
+  # tensor to matrix
+  A_mat = fusedims(A, biperm)
+
+  # factorization
+  if !isnothing(trunc)
+    @assert !full "Specified both `full` and truncation, currently not supported"
+    U, S, Vᴴ = svd_trunc(A_mat; trunc, kwargs...)
+  else
+    U, S, Vᴴ = full ? svd_full(A_mat; kwargs...) : svd_compact(A_mat; kwargs...)
+  end
+
+  # matrix to tensor
+  axes_codomain, axes_domain = blockpermute(axes(A), biperm)
+  axes_U = (axes_codomain..., axes(U, 2))
+  axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...)
+  return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ)
+end
```
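Both factorizations above follow the same fusedims → factorize → splitdims pattern: permute and fuse the codomain and domain indices into matrix rows and columns, factorize the matrix, then split the factors back into tensors joined by a new bond axis. As an illustrative sketch only (a NumPy analogue, not this package's API; `tensor_qr` is a hypothetical helper name):

```python
import numpy as np

def tensor_qr(A, codomain_axes, domain_axes):
    """QR of a tensor viewed as a linear map from domain to codomain axes.

    Hypothetical illustration of the fusedims -> factorize -> splitdims pattern.
    """
    # "fusedims": permute, then reshape the tensor into a matrix
    perm = (*codomain_axes, *domain_axes)
    rows = int(np.prod([A.shape[i] for i in codomain_axes]))
    cols = int(np.prod([A.shape[i] for i in domain_axes]))
    A_mat = A.transpose(perm).reshape(rows, cols)
    # factorize the matricized tensor (compact/"reduced" QR)
    Q, R = np.linalg.qr(A_mat, mode="reduced")
    # "splitdims": reshape the factors back into tensors, keeping the bond axis
    k = Q.shape[1]
    Q_t = Q.reshape(*[A.shape[i] for i in codomain_axes], k)
    R_t = R.reshape(k, *[A.shape[i] for i in domain_axes])
    return Q_t, R_t

A = np.random.randn(2, 3, 4, 5)
# codomain (b, a), domain (d, c), matching the label order used in the tests
Q, R = tensor_qr(A, (1, 0), (3, 2))
# contracting Q and R over the bond axis recovers the permuted tensor
assert np.allclose(np.einsum("baq,qdc->badc", Q, R), A.transpose(1, 0, 3, 2))
```

The same skeleton applies to the SVD; only the factorization step in the middle changes.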
New test file (91 lines added):

```julia
using Test: @test, @testset, @inferred
using TestExtras: @constinferred
using TensorAlgebra: contract, qr, svd, TensorAlgebra
using TensorAlgebra.MatrixAlgebraKit: truncrank
using LinearAlgebra: norm  # `norm` is used in the truncated-SVD test below

elts = (Float64, ComplexF64)

# QR Decomposition
# ----------------
@testset "Full QR ($T)" for T in elts
  A = randn(T, 5, 4, 3, 2)
  labels_A = (:a, :b, :c, :d)
  labels_Q = (:b, :a)
  labels_R = (:d, :c)

  Acopy = deepcopy(A)
  Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full=true)
  @test A == Acopy # should not have altered the initial array
  A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...))
  @test A ≈ A′
  @test size(Q, 1) * size(Q, 2) == size(Q, 3) # Q is unitary
end

@testset "Compact QR ($T)" for T in elts
  A = randn(T, 2, 3, 4, 5) # compact only makes a difference when there are fewer columns
  labels_A = (:a, :b, :c, :d)
  labels_Q = (:b, :a)
  labels_R = (:d, :c)

  Acopy = deepcopy(A)
  Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full=false)
  @test A == Acopy # should not have altered the initial array
  A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...))
  @test A ≈ A′
  @test size(Q, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
end

# Singular Value Decomposition
# ----------------------------
@testset "Full SVD ($T)" for T in elts
  A = randn(T, 5, 4, 3, 2)
  labels_A = (:a, :b, :c, :d)
  labels_U = (:b, :a)
  labels_Vᴴ = (:d, :c)

  Acopy = deepcopy(A)
  U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=true)
  @test A == Acopy # should not have altered the initial array
  US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
  A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
  @test A ≈ A′
  @test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary
  @test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # Vᴴ is unitary
end

@testset "Compact SVD ($T)" for T in elts
  A = randn(T, 5, 4, 3, 2)
  labels_A = (:a, :b, :c, :d)
  labels_U = (:b, :a)
  labels_Vᴴ = (:d, :c)

  Acopy = deepcopy(A)
  U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=false)
  @test A == Acopy # should not have altered the initial array
  US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
  A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
  @test A ≈ A′
  k = min(size(S)...)
  @test size(U, 3) == k == size(Vᴴ, 1)
end

@testset "Truncated SVD ($T)" for T in elts
  A = randn(T, 5, 4, 3, 2)
  labels_A = (:a, :b, :c, :d)
  labels_U = (:b, :a)
  labels_Vᴴ = (:d, :c)

  # test truncated SVD
  Acopy = deepcopy(A)
  _, S_untrunc, _ = svd(A, labels_A, labels_U, labels_Vᴴ)

  trunc = truncrank(size(S_untrunc, 1) - 1)
  U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; trunc)

  @test A == Acopy # should not have altered the initial array
  US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
  A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
  @test norm(A - A′) ≈ S_untrunc[end]
  @test size(S, 1) == size(S_untrunc, 1) - 1
end
```
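The truncated-SVD test relies on the Eckart–Young theorem: the best rank-k approximation in Frobenius norm has error equal to the norm of the discarded singular values, so dropping only the smallest one costs exactly that value. A NumPy sketch of the same check (illustrative only, not the package's API):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 4, 3, 2))

# matricize with codomain (b, a) and domain (d, c), as in the tests above
A_mat = A.transpose(1, 0, 3, 2).reshape(4 * 5, 2 * 3)

U, s, Vh = np.linalg.svd(A_mat, full_matrices=False)
k = s.size - 1  # analogue of truncrank(size(S_untrunc, 1) - 1): drop one value
A_k = (U[:, :k] * s[:k]) @ Vh[:k, :]

# Frobenius-norm error equals the single discarded singular value
assert np.isclose(np.linalg.norm(A_mat - A_k), s[-1])
```

This is exactly what `@test norm(A - A′) ≈ S_untrunc[end]` verifies on the tensor level, since fusing dimensions leaves the Frobenius norm unchanged.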
Review comment:
Just as a style preference, I prefer `using MatrixAlgebraKit: truncrank`. `using TensorAlgebra.MatrixAlgebraKit` implicitly assumes the dependency is part of the API of the package, which I think isn't a good habit to get into and also just gets confusing.

Reply (lkdvos):
To be honest, I added this as a placeholder while we settle on where the keyword arguments are supposed to be handled, since we should either import that into TensorAlgebra or provide a different interface. I just wanted to run some tests in the meantime, but will for sure adapt this :)