Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

Expand All @@ -23,6 +24,7 @@ BlockArrays = "1.2.0"
EllipsisNotation = "1.8.0"
GradedUnitRanges = "0.1.0"
LinearAlgebra = "1.10"
MatrixAlgebraKit = "0.1.0"
TupleTools = "1.6.0"
TypeParameterAccessors = "0.2.1"
julia = "1.10"
2 changes: 2 additions & 0 deletions src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ include("blockedpermutation.jl")
include("BaseExtensions/BaseExtensions.jl")
include("fusedims.jl")
include("splitdims.jl")

include("contract/contract.jl")
include("contract/output_labels.jl")
include("contract/blockedperms.jl")
include("contract/allocate_output.jl")
include("contract/contract_matricize/contract.jl")

include("factorizations.jl")

end
114 changes: 76 additions & 38 deletions src/factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,83 @@
using ArrayLayouts: LayoutMatrix
using LinearAlgebra: LinearAlgebra, Diagonal

function qr(a::AbstractArray, biperm::BlockedPermutation{2})
a_matricized = fusedims(a, biperm)
# TODO: Make this more generic, allow choosing thin or full,
# make sure this works on GPU.
q_fact, r_matricized = LinearAlgebra.qr(a_matricized)
q_matricized = typeof(a_matricized)(q_fact)
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
axes_q = (axes_codomain..., axes(q_matricized, 2))
axes_r = (axes(r_matricized, 1), axes_domain...)
q = splitdims(q_matricized, axes_q)
r = splitdims(r_matricized, axes_r)
return q, r
using MatrixAlgebraKit:
MatrixAlgebraKit, qr_full, qr_compact, svd_full, svd_compact, svd_trunc
# TODO: consider in-place version
# TODO: figure out kwargs and document
#
"""
qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; full=true, kwargs...)
qr(A::AbstractArray, biperm::BlockedPermutation{2}; full=true, kwargs...)
Compute the QR decomposition of a generic N-dimensional array, by interpreting it as
a linear map from the domain to the codomain indices. These can be specified either via
their labels, or directly through a `biperm`.
"""
function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
return qr(A, biperm)
end
function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=true, kwargs...)
# tensor to matrix
A_mat = fusedims(A, biperm)

# factorization
Q, R = full ? qr_full(A_mat; kwargs...) : qr_compact(A_mat; kwargs...)

# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_Q = (axes_codomain..., axes(Q, 2))
axes_R = (axes(R, 1), axes_domain...)
return splitdims(Q, axes_Q), splitdims(R, axes_R)
end

"""
svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; full=false, kwargs...)
svd(A::AbstractArray, biperm::BlockedPermutation{2}; full=false, kwargs...)
Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as
a linear map from the domain to the codomain indices. These can be specified either via
their labels, or directly through a `biperm`.
"""
function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
return svd(A, biperm; kwargs...)
end
function svd(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...)
# tensor to matrix
A_mat = fusedims(A, biperm)

# factorization
U, S, Vᴴ = full ? svd_full(A_mat; kwargs...) : svd_compact(A_mat; kwargs...)

function qr(a::AbstractArray, labels_a, labels_codomain, labels_domain)
# TODO: Generalize to conversion to `Tuple` isn't needed.
return qr(
a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
)
# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_U = (axes_codomain..., axes(U, 2))
axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...)
return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ)
end

function svd(a::AbstractArray, biperm::BlockedPermutation{2})
a_matricized = fusedims(a, biperm)
usv_matricized = LinearAlgebra.svd(a_matricized)
u_matricized = usv_matricized.U
s_diag = usv_matricized.S
v_matricized = usv_matricized.Vt
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
axes_u = (axes_codomain..., axes(u_matricized, 2))
axes_v = (axes(v_matricized, 1), axes_domain...)
u = splitdims(u_matricized, axes_u)
# TODO: Use `DiagonalArrays.diagonal` to make it more general.
s = Diagonal(s_diag)
v = splitdims(v_matricized, axes_v)
return u, s, v
# TODO: decide on interface
"""
tsvd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
tsvd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
Compute the truncated SVD decomposition of a generic N-dimensional array, by interpreting it
as a linear map from the domain to the codomain indices. These can be specified either via
their labels, or directly through a `biperm`.
"""
function tsvd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
Copy link
Member

@mtfishman mtfishman Feb 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a suggestion, but what if truncated SVD is implemented as an algorithm backend of a more general svd function (i.e. we have svd(::Algorithm"truncated", a::AbstractArray, ...) and svd(::Algorithm"untruncated", a::AbstractArray, ...))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm definitely also fine with that, behind the scenes this is already what is going on in MatrixAlgebraKit as well, see the definition here. It's just a matter of deciding on where you want to start intercepting the user-input and what you would like the keywords for that to be, and I'm happy to map it to whatever MatrixAlgebraKit has

Copy link
Member

@mtfishman mtfishman Feb 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that is definitely a bit subtle, since ITensor has its own syntax for truncation keyword arguments.

Maybe a strategy we can take for now is starting out by following whatever MatrixAlgebraKit.jl does (say TensorAlgebra.svd just forwards kwargs to MatrixAlgebraKit.svd), and then at the ITensor level we can map our truncation arguments like cutoff and maxdim to the TensorAlgebra.jl/MatrixAlgebraKit.jl names and conventions, and additionally choose TruncatedAlgorithm as the default SVD algorithm.

biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
return tsvd(A, biperm; kwargs...)
end
function tsvd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
# tensor to matrix
A_mat = fusedims(A, biperm)

# factorization
U, S, Vᴴ = svd_trunc(A_mat; kwargs...)

function svd(a::AbstractArray, labels_a, labels_codomain, labels_domain)
return svd(
a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
)
# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_U = (axes_codomain..., axes(U, 2))
axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...)
return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ)
end
62 changes: 62 additions & 0 deletions test/test_factorizations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using Test: @test, @testset, @inferred
using TestExtras: @constinferred
using TensorAlgebra: contract, svd, tsvd, TensorAlgebra
using TensorAlgebra.MatrixAlgebraKit: truncrank
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just as a style preference, I prefer using MatrixAlgebraKit: truncrank. using TensorAlgebra.MatrixAlgebraKit implicitly assumes the dependency is part of the API of the package, which I think isn't a good habit to get into and also just gets confusing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I added this as a placeholder while we settle on where the keyword arguments are supposed to be handled, since we should either import that into TensorAlgebra, or provide a different interface. I just wanted to run some tests in the meantime, but for sure will adapt this :)

using LinearAlgebra: norm

elts = (Float64, ComplexF64)

@testset "Full SVD ($T)" for T in elts
A = randn(T, 5, 4, 3, 2)
A ./= norm(A)
labels_A = (:a, :b, :c, :d)
labels_U = (:b, :a)
labels_Vᴴ = (:d, :c)

Acopy = deepcopy(A)
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=true)
@test A == Acopy # should not have altered initial array
US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
@test A A′
@test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary
@test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # V is unitary
end

@testset "Compact SVD ($T)" for T in elts
A = randn(T, 5, 4, 3, 2)
A ./= norm(A)
labels_A = (:a, :b, :c, :d)
labels_U = (:b, :a)
labels_Vᴴ = (:d, :c)

Acopy = deepcopy(A)
U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=false)
@test A == Acopy # should not have altered initial array
US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
@test A A′
k = min(size(S)...)
@test size(U, 3) == k == size(Vᴴ, 1)
end

@testset "Truncated SVD ($T)" for T in elts
A = randn(T, 5, 4, 3, 2)
A ./= norm(A)
labels_A = (:a, :b, :c, :d)
labels_U = (:b, :a)
labels_Vᴴ = (:d, :c)

# test truncated SVD
Acopy = deepcopy(A)
_, S_untrunc, _ = svd(A, labels_A, labels_U, labels_Vᴴ)

trunc = truncrank(size(S_untrunc, 1) - 1)
U, S, Vᴴ = @constinferred tsvd(A, labels_A, labels_U, labels_Vᴴ; trunc)

@test A == Acopy # should not have altered initial array
US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v))
A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...))
@test norm(A - A′) S_untrunc[end]
@test size(S, 1) == size(S_untrunc, 1) - 1
end
Loading