Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

Expand All @@ -23,6 +24,7 @@ BlockArrays = "1.2.0"
EllipsisNotation = "1.8.0"
GradedUnitRanges = "0.1.0"
LinearAlgebra = "1.10"
MatrixAlgebraKit = "0.1.0"
TupleTools = "1.6.0"
TypeParameterAccessors = "0.2.1"
julia = "1.10"
2 changes: 2 additions & 0 deletions src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ include("blockedpermutation.jl")
include("BaseExtensions/BaseExtensions.jl")
include("fusedims.jl")
include("splitdims.jl")

include("contract/contract.jl")
include("contract/output_labels.jl")
include("contract/blockedperms.jl")
include("contract/allocate_output.jl")
include("contract/contract_matricize/contract.jl")

include("factorizations.jl")

end
105 changes: 67 additions & 38 deletions src/factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,74 @@
using ArrayLayouts: LayoutMatrix
using LinearAlgebra: LinearAlgebra, Diagonal

function qr(a::AbstractArray, biperm::BlockedPermutation{2})
a_matricized = fusedims(a, biperm)
# TODO: Make this more generic, allow choosing thin or full,
# make sure this works on GPU.
q_fact, r_matricized = LinearAlgebra.qr(a_matricized)
q_matricized = typeof(a_matricized)(q_fact)
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
axes_q = (axes_codomain..., axes(q_matricized, 2))
axes_r = (axes(r_matricized, 1), axes_domain...)
q = splitdims(q_matricized, axes_q)
r = splitdims(r_matricized, axes_r)
return q, r
using MatrixAlgebraKit: qr_full, qr_compact, svd_full, svd_compact, svd_trunc

# TODO: consider in-place version
# TODO: figure out kwargs and document
#
"""
qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; full=true, kwargs...) -> Q, R
qr(A::AbstractArray, biperm::BlockedPermutation{2}; full=true, kwargs...) -> Q, R

Compute the QR decomposition of a generic N-dimensional array, by interpreting it as
a linear map from the domain to the codomain indices. These can be specified either via
their labels, or directly through a `biperm`.
"""
function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
return qr(A, biperm)
end
function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=true, kwargs...)
# tensor to matrix
A_mat = fusedims(A, biperm)

# factorization
Q, R = full ? qr_full(A_mat; kwargs...) : qr_compact(A_mat; kwargs...)

function qr(a::AbstractArray, labels_a, labels_codomain, labels_domain)
# TODO: Generalize to conversion to `Tuple` isn't needed.
return qr(
a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
)
# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_Q = (axes_codomain..., axes(Q, 2))
axes_R = (axes(R, 1), axes_domain...)
return splitdims(Q, axes_Q), splitdims(R, axes_R)
end

function svd(a::AbstractArray, biperm::BlockedPermutation{2})
a_matricized = fusedims(a, biperm)
usv_matricized = LinearAlgebra.svd(a_matricized)
u_matricized = usv_matricized.U
s_diag = usv_matricized.S
v_matricized = usv_matricized.Vt
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
axes_u = (axes_codomain..., axes(u_matricized, 2))
axes_v = (axes(v_matricized, 1), axes_domain...)
u = splitdims(u_matricized, axes_u)
# TODO: Use `DiagonalArrays.diagonal` to make it more general.
s = Diagonal(s_diag)
v = splitdims(v_matricized, axes_v)
return u, s, v
# TODO: separate out the algorithm selection step from the implementation
"""
svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ
svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> U, S, Vᴴ

Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as
a linear map from the domain to the codomain indices. These can be specified either via
their labels, or directly through a `biperm`.

## Keyword arguments
- `full::Bool=false`: select between a "thick" or a "thin" decomposition, where both `U` and `Vᴴ`
are unitary or isometric.
- `trunc`: Truncation keywords for `svd_trunc`. Not compatible with `full=true`.
- Other keywords are passed on directly to MatrixAlgebraKit
"""
function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
return svd(A, biperm; kwargs...)
end
function svd(
A::AbstractArray,
biperm::BlockedPermutation{2};
full::Bool=false,
trunc=nothing,
kwargs...,
)
# tensor to matrix
A_mat = fusedims(A, biperm)

# factorization
if !isnothing(trunc)
@assert !full "Specified both full and truncation, currently not supported"
U, S, Vᴴ = svd_trunc(A_mat; trunc, kwargs...)
else
U, S, Vᴴ = full ? svd_full(A_mat; kwargs...) : svd_compact(A_mat; kwargs...)
end

function svd(a::AbstractArray, labels_a, labels_codomain, labels_domain)
return svd(
a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
)
# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_U = (axes_codomain..., axes(U, 2))
axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...)
return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ)
end
23 changes: 0 additions & 23 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,26 +212,3 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test a_dest[] ≈ s[] * t[]
end
end
@testset "qr (eltype=$elt)" for elt in elts
a = randn(elt, 5, 4, 3, 2)
labels_a = (:a, :b, :c, :d)
labels_q = (:b, :a)
labels_r = (:d, :c)
q, r = qr(a, labels_a, labels_q, labels_r)
label_qr = :qr
a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...))
@test a ≈ a′
end
@testset "svd (eltype=$elt)" for elt in elts
a = randn(elt, 5, 4, 3, 2)
labels_a = (:a, :b, :c, :d)
labels_u = (:b, :a)
labels_v = (:d, :c)
u, s, v = svd(a, labels_a, labels_u, labels_v)
label_u = :u
label_v = :v
# TODO: Define multi-arg `contract`?
us, labels_us = contract(u, (labels_u..., label_u), s, (label_u, label_v))
a′ = contract(labels_a, us, labels_us, v, (label_v, labels_v...))
@test a ≈ a′
end
91 changes: 91 additions & 0 deletions test/test_factorizations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using Test: @test, @testset, @inferred
using TestExtras: @constinferred
using TensorAlgebra: contract, qr, svd, TensorAlgebra
using TensorAlgebra.MatrixAlgebraKit: truncrank
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just as a style preference, I prefer using MatrixAlgebraKit: truncrank. using TensorAlgebra.MatrixAlgebraKit implicitly assumes the dependency is part of the API of the package, which I think isn't a good habit to get into and also just gets confusing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I added this as a placeholder while we settle on where the keyword arguments are supposed to be handled, since we should either import that into TensorAlgebra, or provide a different interface. I just wanted to run some tests in the meantime, but for sure will adapt this :)


elts = (Float64, ComplexF64)

# QR Decomposition
# ----------------
@testset "Full QR ($T)" for T in elts
A = randn(T, 5, 4, 3, 2)
labels_A = (:a, :b, :c, :d)
labels_Q = (:b, :a)
labels_R = (:d, :c)

Acopy = deepcopy(A)
Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full=true)
@test A == Acopy # should not have altered initial array
A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...))
@test A ≈ A′
@test size(Q, 1) * size(Q, 2) == size(Q, 3) # Q is unitary
end

@testset "Compact QR ($T)" for T in elts
A = randn(T, 2, 3, 4, 5) # compact only makes a difference for less columns
labels_A = (:a, :b, :c, :d)
labels_Q = (:b, :a)
labels_R = (:d, :c)

Acopy = deepcopy(A)
Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full=false)
@test A == Acopy # should not have altered initial array
A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...))
@test A ≈ A′
@test size(Q, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
end

# Singular Value Decomposition
# ----------------------------
@testset "Full SVD ($T)" for T in elts
A = randn(T, 5, 4, 3, 2)
labels_A = (:a, :b, :c, :d)
labels_U = (:b, :a)
labels_Vᴴ = (:d, :c)

Acopy = deepcopy(A)
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=true)
@test A == Acopy # should not have altered initial array
US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
@test A ≈ A′
@test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary
@test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # V is unitary
end

@testset "Compact SVD ($T)" for T in elts
A = randn(T, 5, 4, 3, 2)
labels_A = (:a, :b, :c, :d)
labels_U = (:b, :a)
labels_Vᴴ = (:d, :c)

Acopy = deepcopy(A)
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=false)
@test A == Acopy # should not have altered initial array
US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
@test A ≈ A′
k = min(size(S)...)
@test size(U, 3) == k == size(Vᴴ, 1)
end

@testset "Truncated SVD ($T)" for T in elts
A = randn(T, 5, 4, 3, 2)
labels_A = (:a, :b, :c, :d)
labels_U = (:b, :a)
labels_Vᴴ = (:d, :c)

# test truncated SVD
Acopy = deepcopy(A)
_, S_untrunc, _ = svd(A, labels_A, labels_U, labels_Vᴴ)

trunc = truncrank(size(S_untrunc, 1) - 1)
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; trunc)

@test A == Acopy # should not have altered initial array
US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
@test norm(A - A′) ≈ S_untrunc[end]
@test size(S, 1) == size(S_untrunc, 1) - 1
end

Loading