diff --git a/Project.toml b/Project.toml index f8700a8..a28b391 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" @@ -23,6 +24,7 @@ BlockArrays = "1.2.0" EllipsisNotation = "1.8.0" GradedUnitRanges = "0.1.0" LinearAlgebra = "1.10" +MatrixAlgebraKit = "0.1.0" TupleTools = "1.6.0" TypeParameterAccessors = "0.2.1" julia = "1.10" diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index caa2cc5..70a4772 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -7,11 +7,13 @@ include("blockedpermutation.jl") include("BaseExtensions/BaseExtensions.jl") include("fusedims.jl") include("splitdims.jl") + include("contract/contract.jl") include("contract/output_labels.jl") include("contract/blockedperms.jl") include("contract/allocate_output.jl") include("contract/contract_matricize/contract.jl") + include("factorizations.jl") end diff --git a/src/factorizations.jl b/src/factorizations.jl index a017ca1..3869ebf 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -1,45 +1,74 @@ -using ArrayLayouts: LayoutMatrix -using LinearAlgebra: LinearAlgebra, Diagonal - -function qr(a::AbstractArray, biperm::BlockedPermutation{2}) - a_matricized = fusedims(a, biperm) - # TODO: Make this more generic, allow choosing thin or full, - # make sure this works on GPU. - q_fact, r_matricized = LinearAlgebra.qr(a_matricized) - q_matricized = typeof(a_matricized)(q_fact) - axes_codomain, axes_domain = blockpermute(axes(a), biperm) - axes_q = (axes_codomain..., axes(q_matricized, 2)) - axes_r = (axes(r_matricized, 1), axes_domain...) - q = splitdims(q_matricized, axes_q) - r = splitdims(r_matricized, axes_r) - return q, r +using MatrixAlgebraKit: qr_full, qr_compact, svd_full, svd_compact, svd_trunc + +# TODO: consider in-place version +# TODO: figure out kwargs and document +# +""" + qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; full=true, kwargs...) -> Q, R + qr(A::AbstractArray, biperm::BlockedPermutation{2}; full=true, kwargs...) -> Q, R + +Compute the QR decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. +""" +function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return qr(A, biperm) end +function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=true, kwargs...) + # tensor to matrix + A_mat = fusedims(A, biperm) + + # factorization + Q, R = full ? qr_full(A_mat; kwargs...) : qr_compact(A_mat; kwargs...) -function qr(a::AbstractArray, labels_a, labels_codomain, labels_domain) - # TODO: Generalize to conversion to `Tuple` isn't needed. - return qr( - a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain)) - ) + # matrix to tensor + axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_Q = (axes_codomain..., axes(Q, 2)) + axes_R = (axes(R, 1), axes_domain...) + return splitdims(Q, axes_Q), splitdims(R, axes_R) end -function svd(a::AbstractArray, biperm::BlockedPermutation{2}) - a_matricized = fusedims(a, biperm) - usv_matricized = LinearAlgebra.svd(a_matricized) - u_matricized = usv_matricized.U - s_diag = usv_matricized.S - v_matricized = usv_matricized.Vt - axes_codomain, axes_domain = blockpermute(axes(a), biperm) - axes_u = (axes_codomain..., axes(u_matricized, 2)) - axes_v = (axes(v_matricized, 1), axes_domain...) - u = splitdims(u_matricized, axes_u) - # TODO: Use `DiagonalArrays.diagonal` to make it more general. - s = Diagonal(s_diag) - v = splitdims(v_matricized, axes_v) - return u, s, v +# TODO: separate out the algorithm selection step from the implementation +""" + svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ + svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> U, S, Vᴴ + +Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments +- `full::Bool=false`: select between a "thick" or a "thin" decomposition, where both `U` and `Vᴴ` + are unitary or isometric. +- `trunc`: Truncation keywords for `svd_trunc`. Not compatible with `full=true`. +- Other keywords are passed on directly to MatrixAlgebraKit +""" +function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return svd(A, biperm; kwargs...) end +function svd( + A::AbstractArray, + biperm::BlockedPermutation{2}; + full::Bool=false, + trunc=nothing, + kwargs..., +) + # tensor to matrix + A_mat = fusedims(A, biperm) + + # factorization + if !isnothing(trunc) + @assert !full "Specified both full and truncation, currently not supported" + U, S, Vᴴ = svd_trunc(A_mat; trunc, kwargs...) + else + U, S, Vᴴ = full ? svd_full(A_mat; kwargs...) : svd_compact(A_mat; kwargs...) + end -function svd(a::AbstractArray, labels_a, labels_codomain, labels_domain) - return svd( - a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain)) - ) + # matrix to tensor + axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_U = (axes_codomain..., axes(U, 2)) + axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...) + return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 221274b..b070e8a 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -212,26 +212,3 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a_dest[] ≈ s[] * t[] end end -@testset "qr (eltype=$elt)" for elt in elts - a = randn(elt, 5, 4, 3, 2) - labels_a = (:a, :b, :c, :d) - labels_q = (:b, :a) - labels_r = (:d, :c) - q, r = qr(a, labels_a, labels_q, labels_r) - label_qr = :qr - a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...)) - @test a ≈ a′ -end -@testset "svd (eltype=$elt)" for elt in elts - a = randn(elt, 5, 4, 3, 2) - labels_a = (:a, :b, :c, :d) - labels_u = (:b, :a) - labels_v = (:d, :c) - u, s, v = svd(a, labels_a, labels_u, labels_v) - label_u = :u - label_v = :v - # TODO: Define multi-arg `contract`? - us, labels_us = contract(u, (labels_u..., label_u), s, (label_u, label_v)) - a′ = contract(labels_a, us, labels_us, v, (label_v, labels_v...)) - @test a ≈ a′ -end diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl new file mode 100644 index 0000000..5dc3e48 --- /dev/null +++ b/test/test_factorizations.jl @@ -0,0 +1,91 @@ +using Test: @test, @testset, @inferred +using TestExtras: @constinferred +using TensorAlgebra: contract, qr, svd, TensorAlgebra +using TensorAlgebra.MatrixAlgebraKit: truncrank + +elts = (Float64, ComplexF64) + +# QR Decomposition +# ---------------- +@testset "Full QR ($T)" for T in elts + A = randn(T, 5, 4, 3, 2) + labels_A = (:a, :b, :c, :d) + labels_Q = (:b, :a) + labels_R = (:d, :c) + + Acopy = deepcopy(A) + Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full=true) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) + @test A ≈ A′ + @test size(Q, 1) * size(Q, 2) == size(Q, 3) # Q is unitary +end + +@testset "Compact QR ($T)" for T in elts + A = randn(T, 2, 3, 4, 5) # compact only makes a difference for less columns + labels_A = (:a, :b, :c, :d) + labels_Q = (:b, :a) + labels_R = (:d, :c) + + Acopy = deepcopy(A) + Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full=false) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) + @test A ≈ A′ + @test size(Q, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) +end + +# Singular Value Decomposition +# ---------------------------- +@testset "Full SVD ($T)" for T in elts + A = randn(T, 5, 4, 3, 2) + labels_A = (:a, :b, :c, :d) + labels_U = (:b, :a) + labels_Vᴴ = (:d, :c) + + Acopy = deepcopy(A) + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=true) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) + @test A ≈ A′ + @test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary + @test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # V is unitary +end + +@testset "Compact SVD ($T)" for T in elts + A = randn(T, 5, 4, 3, 2) + labels_A = (:a, :b, :c, :d) + labels_U = (:b, :a) + labels_Vᴴ = (:d, :c) + + Acopy = deepcopy(A) + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full=false) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) + @test A ≈ A′ + k = min(size(S)...) + @test size(U, 3) == k == size(Vᴴ, 1) +end + +@testset "Truncated SVD ($T)" for T in elts + A = randn(T, 5, 4, 3, 2) + labels_A = (:a, :b, :c, :d) + labels_U = (:b, :a) + labels_Vᴴ = (:d, :c) + + # test truncated SVD + Acopy = deepcopy(A) + _, S_untrunc, _ = svd(A, labels_A, labels_U, labels_Vᴴ) + + trunc = truncrank(size(S_untrunc, 1) - 1) + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; trunc) + + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) + @test norm(A - A′) ≈ S_untrunc[end] + @test size(S, 1) == size(S_untrunc, 1) - 1 +end +