diff --git a/Project.toml b/Project.toml index cf887d5..9fa42f7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.2.9" +version = "0.2.10" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl new file mode 100644 index 0000000..a3796b9 --- /dev/null +++ b/src/MatrixAlgebra.jl @@ -0,0 +1,136 @@ +module MatrixAlgebra + +export eigen, + eigen!, + eigvals, + eigvals!, + factorize, + factorize!, + lq, + lq!, + orth, + orth!, + polar, + polar!, + qr, + qr!, + svd, + svd!, + svdvals, + svdvals! + +using LinearAlgebra: LinearAlgebra +using MatrixAlgebraKit + +for (f, f_full, f_compact) in ( + (:qr, :qr_full, :qr_compact), + (:qr!, :qr_full!, :qr_compact!), + (:lq, :lq_full, :lq_compact), + (:lq!, :lq_full!, :lq_compact!), +) + @eval begin + function $f(A::AbstractMatrix; full::Bool=false, kwargs...) + f = full ? $f_full : $f_compact + return f(A; kwargs...) + end + end +end + +for (eigen, eigh_full, eig_full, eigh_trunc, eig_trunc) in ( + (:eigen, :eigh_full, :eig_full, :eigh_trunc, :eig_trunc), + (:eigen!, :eigh_full!, :eig_full!, :eigh_trunc!, :eig_trunc!), +) + @eval begin + function $eigen(A::AbstractMatrix; trunc=nothing, ishermitian=nothing, kwargs...) + ishermitian = @something ishermitian LinearAlgebra.ishermitian(A) + f = if !isnothing(trunc) + ishermitian ? $eigh_trunc : $eig_trunc + else + ishermitian ? $eigh_full : $eig_full + end + return f(A; kwargs...) + end + end +end + +for (eigvals, eigh_vals, eig_vals) in + ((:eigvals, :eigh_vals, :eig_vals), (:eigvals!, :eigh_vals!, :eig_vals!)) + @eval begin + function $eigvals(A::AbstractMatrix; ishermitian=nothing, kwargs...) + ishermitian = @something ishermitian LinearAlgebra.ishermitian(A) + f = (ishermitian ? $eigh_vals : $eig_vals) + return f(A; kwargs...) + end + end +end + +for (svd, svd_trunc, svd_full, svd_compact) in ( + (:svd, :svd_trunc, :svd_full, :svd_compact), + (:svd!, :svd_trunc!, :svd_full!, :svd_compact!), +) + @eval begin + function $svd(A::AbstractMatrix; full::Bool=false, trunc=nothing, kwargs...) + return if !isnothing(trunc) + @assert !full "Specified both full and truncation, currently not supported" + $svd_trunc(A; trunc, kwargs...) + else + (full ? $svd_full : $svd_compact)(A; kwargs...) + end + end + end +end + +for (svdvals, svd_vals) in ((:svdvals, :svd_vals), (:svdvals!, :svd_vals!)) + @eval begin + function $svdvals(A::AbstractMatrix; ishermitian=nothing, kwargs...) + return $svd_vals(A; kwargs...) + end + end +end + +for (polar, left_polar, right_polar) in + ((:polar, :left_polar, :right_polar), (:polar!, :left_polar!, :right_polar!)) + @eval begin + function $polar(A::AbstractMatrix; side=:left, kwargs...) + f = if side == :left + $left_polar + elseif side == :right + $right_polar + else + throw(ArgumentError("`side=$side` not supported.")) + end + return f(A; kwargs...) + end + end +end + +for (orth, left_orth, right_orth) in + ((:orth, :left_orth, :right_orth), (:orth!, :left_orth!, :right_orth!)) + @eval begin + function $orth(A::AbstractMatrix; side=:left, kwargs...) + f = if side == :left + $left_orth + elseif side == :right + $right_orth + else + throw(ArgumentError("`side=$side` not supported.")) + end + return f(A; kwargs...) + end + end +end + +for (factorize, orth_f) in ((:factorize, :(MatrixAlgebra.orth)), (:factorize!, :orth!)) + @eval begin + function $factorize(A::AbstractMatrix; orth=:left, kwargs...) + f = if orth in (:left, :right) + $orth_f + else + throw(ArgumentError("`orth=$orth` not supported.")) + end + return f(A; side=orth, kwargs...) + end + end +end + +end diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index 591133b..1bc23e4 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -1,7 +1,24 @@ module TensorAlgebra -export contract, contract!, eigen, eigvals, lq, left_null, qr, right_null, svd, svdvals +export contract, + contract!, + eigen, + eigvals, + factorize, + left_null, + left_orth, + left_polar, + lq, + qr, + right_null, + right_orth, + right_polar, + orth, + polar, + svd, + svdvals +include("MatrixAlgebra.jl") include("blockedtuple.jl") include("blockedpermutation.jl") include("BaseExtensions/BaseExtensions.jl") diff --git a/src/factorizations.jl b/src/factorizations.jl index 31aac73..a2fee2d 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -1,25 +1,29 @@ using LinearAlgebra: LinearAlgebra -using MatrixAlgebraKit: - eig_full!, - eig_trunc!, - eig_vals!, - eigh_full!, - eigh_trunc!, - eigh_vals!, - left_null!, - left_orth!, - left_polar!, - lq_full!, - lq_compact!, - qr_full!, - qr_compact!, - right_null!, - right_orth!, - right_polar!, - svd_full!, - svd_compact!, - svd_trunc!, - svd_vals! +using MatrixAlgebraKit: MatrixAlgebraKit + +for f in ( + :qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, :factorize +) + @eval begin + function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return $f(A, biperm; kwargs...) + end + function $f(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) + # tensor to matrix + A_mat = fusedims(A, biperm) + + # factorization + X, Y = MatrixAlgebra.$f(A_mat; kwargs...) + + # matrix to tensor + axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_X = (axes_codomain..., axes(X, 2)) + axes_Y = (axes(Y, 1), axes_domain...) + return splitdims(X, axes_X), splitdims(Y, axes_Y) + end + end +end """ qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R @@ -37,23 +41,7 @@ their labels, or directly through a `biperm`. See also `MatrixAlgebraKit.qr_full!` and `MatrixAlgebraKit.qr_compact!`. """ -function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return qr(A, biperm; kwargs...) -end -function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - Q, R = full ? qr_full!(A_mat; kwargs...) : qr_compact!(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_Q = (axes_codomain..., axes(Q, 2)) - axes_R = (axes(R, 1), axes_domain...) - return splitdims(Q, axes_Q), splitdims(R, axes_R) -end +qr """ lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> L, Q @@ -71,23 +59,87 @@ their labels, or directly through a `biperm`. See also `MatrixAlgebraKit.lq_full!` and `MatrixAlgebraKit.lq_compact!`. """ -function lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return lq(A, biperm; kwargs...) -end -function lq(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) +lq - # factorization - L, Q = (full ? lq_full! : lq_compact!)(A_mat; kwargs...) +""" + left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P + left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> W, P - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_L = (axes_codomain..., axes(L, ndims(L))) - axes_Q = (axes(Q, 1), axes_domain...) - return splitdims(L, axes_L), splitdims(Q, axes_Q) -end +Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.left_polar!`. +""" +left_polar + +""" + right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W + right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> P, W + +Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.right_polar!`. +""" +right_polar + +""" + left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C + left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> V, C + +Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.left_orth!`. +""" +left_orth + +""" + right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V + right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> C, V + +Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.right_orth!`. +""" +right_orth + +""" + factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y + factorize(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> X, Y + +Compute the decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- `orth::Symbol=:left`: specify the orthogonality of the decomposition. + Currently only `:left` and `:right` are supported. +- Other keywords are passed on directly to MatrixAlgebraKit. +""" +factorize """ eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D, V @@ -111,25 +163,12 @@ function eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwarg biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return eigen(A, biperm; kwargs...) end -function eigen( - A::AbstractArray, - biperm::BlockedPermutation{2}; - trunc=nothing, - ishermitian=nothing, - kwargs..., -) +function eigen(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) # tensor to matrix A_mat = fusedims(A, biperm) - ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat) - # factorization - f! = if !isnothing(trunc) - ishermitian ? eigh_trunc! : eig_trunc! - else - ishermitian ? eigh_full! : eig_full! - end - D, V = f!(A_mat; kwargs...) + D, V = MatrixAlgebra.eigen!(A_mat; kwargs...) # matrix to tensor axes_codomain, = blockpermute(axes(A), biperm) @@ -157,15 +196,11 @@ function eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwa biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return eigvals(A, biperm; kwargs...) end -function eigvals( - A::AbstractArray, biperm::BlockedPermutation{2}; ishermitian=nothing, kwargs... -) +function eigvals(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) - ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat) - return (ishermitian ? eigh_vals! : eig_vals!)(A_mat; kwargs...) + return MatrixAlgebra.eigvals!(A_mat; kwargs...) end -# TODO: separate out the algorithm selection step from the implementation """ svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> U, S, Vᴴ @@ -187,23 +222,12 @@ function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs. biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return svd(A, biperm; kwargs...) end -function svd( - A::AbstractArray, - biperm::BlockedPermutation{2}; - full::Bool=false, - trunc=nothing, - kwargs..., -) +function svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) # tensor to matrix A_mat = fusedims(A, biperm) # factorization - if !isnothing(trunc) - @assert !full "Specified both full and truncation, currently not supported" - U, S, Vᴴ = svd_trunc!(A_mat; trunc, kwargs...) - else - U, S, Vᴴ = full ? svd_full!(A_mat; kwargs...) : svd_compact!(A_mat; kwargs...) - end + U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...) # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) @@ -228,7 +252,7 @@ function svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) end function svdvals(A::AbstractArray, biperm::BlockedPermutation{2}) A_mat = fusedims(A, biperm) - return svd_vals!(A_mat) + return MatrixAlgebra.svdvals!(A_mat) end """ @@ -254,7 +278,7 @@ function left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; k end function left_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) - N = left_null!(A_mat; kwargs...) + N = MatrixAlgebraKit.left_null!(A_mat; kwargs...) axes_codomain, _ = blockpermute(axes(A), biperm) axes_N = (axes_codomain..., axes(N, 2)) N_tensor = splitdims(N, axes_N) @@ -284,168 +308,8 @@ function right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; end function right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) - Nᴴ = right_null!(A_mat; kwargs...) + Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...) _, axes_domain = blockpermute(axes(A), biperm) axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...) return splitdims(Nᴴ, axes_Nᴴ) end - -""" - left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P - left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> W, P - -Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via -their labels, or directly through a `biperm`. - -## Keyword arguments - -- Keyword arguments are passed on directly to MatrixAlgebraKit. - -See also `MatrixAlgebraKit.left_polar!`. -""" -function left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return left_polar(A, biperm; kwargs...) -end -function left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - W, P = left_polar!(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_W = (axes_codomain..., axes(W, 2)) - axes_P = (axes(P, 1), axes_domain...) - return splitdims(W, axes_W), splitdims(P, axes_P) -end - -""" - right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W - right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> P, W - -Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via -their labels, or directly through a `biperm`. - -## Keyword arguments - -- Keyword arguments are passed on directly to MatrixAlgebraKit. - -See also `MatrixAlgebraKit.right_polar!`. -""" -function right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return right_polar(A, biperm; kwargs...) -end -function right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - P, W = right_polar!(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_P = (axes_codomain..., axes(P, ndims(P))) - axes_W = (axes(W, 1), axes_domain...) - return splitdims(P, axes_P), splitdims(W, axes_W) -end - -""" - left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C - left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> V, C - -Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via -their labels, or directly through a `biperm`. - -## Keyword arguments - -- Keyword arguments are passed on directly to MatrixAlgebraKit. - -See also `MatrixAlgebraKit.left_orth!`. -""" -function left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return left_orth(A, biperm; kwargs...) -end -function left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - V, C = left_orth!(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_V = (axes_codomain..., axes(V, 2)) - axes_C = (axes(C, 1), axes_domain...) - return splitdims(V, axes_V), splitdims(C, axes_C) -end - -""" - right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V - right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> C, V - -Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via -their labels, or directly through a `biperm`. - -## Keyword arguments - -- Keyword arguments are passed on directly to MatrixAlgebraKit. - -See also `MatrixAlgebraKit.right_orth!`. -""" -function right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return right_orth(A, biperm; kwargs...) -end -function right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - P, W = right_orth!(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_P = (axes_codomain..., axes(P, ndims(P))) - axes_W = (axes(W, 1), axes_domain...) - return splitdims(P, axes_P), splitdims(W, axes_W) -end - -""" - factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y - factorize(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> X, Y - -Compute the decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via -their labels, or directly through a `biperm`. - -## Keyword arguments - -- `orth::Symbol=:left`: specify the orthogonality of the decomposition. - Currently only `:left` and `:right` are supported. -- Other keywords are passed on directly to MatrixAlgebraKit. -""" -function factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return factorize(A, biperm; kwargs...) -end -function factorize(A::AbstractArray, biperm::BlockedPermutation{2}; orth=:left, kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - X, Y = (orth == :left ? left_orth! : right_orth!)(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_X = (axes_codomain..., axes(X, ndims(X))) - axes_Y = (axes(Y, 1), axes_domain...) - return splitdims(X, axes_X), splitdims(Y, axes_Y) -end diff --git a/test/test_exports.jl b/test/test_exports.jl index dbef64f..8dba1bf 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -7,12 +7,42 @@ using Test: @test, @testset :contract!, :eigen, :eigvals, + :factorize, :left_null, + :left_orth, + :left_polar, :lq, + :orth, + :polar, :qr, :right_null, + :right_orth, + :right_polar, :svd, :svdvals, ] @test issetequal(names(TensorAlgebra), exports) + + exports = [ + :MatrixAlgebra, + :eigen, + :eigen!, + :eigvals, + :eigvals!, + :factorize, + :factorize!, + :lq, + :lq!, + :orth, + :orth!, + :polar, + :polar!, + :qr, + :qr!, + :svd, + :svd!, + :svdvals, + :svdvals!, + ] + @test issetequal(names(TensorAlgebra.MatrixAlgebra), exports) end diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 62b1561..0c5ccdf 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -10,6 +10,8 @@ using TensorAlgebra: left_orth, left_polar, lq, + orth, + polar, qr, right_null, right_orth, @@ -216,11 +218,16 @@ end labels_P = (:d, :c) Acopy = deepcopy(A) - W, P = left_polar(A, labels_A, labels_W, labels_P) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) - @test A ≈ A′ - @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + for (W, P) in ( + left_polar(A, labels_A, labels_W, labels_P), + polar(A, labels_A, labels_W, labels_P; side=:left), + polar(A, labels_A, labels_W, labels_P), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) + @test A ≈ A′ + @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "Right polar ($T)" for T in elts @@ -230,11 +237,15 @@ end labels_W = (:d, :c) Acopy = deepcopy(A) - P, W = right_polar(A, labels_A, labels_P, labels_W) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) - @test A ≈ A′ - @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + for (P, W) in ( + right_polar(A, labels_A, labels_P, labels_W), + polar(A, labels_A, labels_P, labels_W; side=:right), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) + @test A ≈ A′ + @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "Left orth ($T)" for T in elts @@ -244,11 +255,16 @@ end labels_P = (:d, :c) Acopy = deepcopy(A) - W, P = left_orth(A, labels_A, labels_W, labels_P) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) - @test A ≈ A′ - @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + for (W, P) in ( + left_orth(A, labels_A, labels_W, labels_P), + orth(A, labels_A, labels_W, labels_P; side=:left), + orth(A, labels_A, labels_W, labels_P), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) + @test A ≈ A′ + @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "Right orth ($T)" for T in elts @@ -258,11 +274,15 @@ end labels_W = (:d, :c) Acopy = deepcopy(A) - P, W = right_orth(A, labels_A, labels_P, labels_W) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) - @test A ≈ A′ - @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + for (P, W) in ( + right_orth(A, labels_A, labels_P, labels_W), + orth(A, labels_A, labels_P, labels_W; side=:right), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) + @test A ≈ A′ + @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "factorize ($T)" for T in elts diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl new file mode 100644 index 0000000..7a391c3 --- /dev/null +++ b/test/test_matrixalgebra.jl @@ -0,0 +1,148 @@ +using LinearAlgebra: I, diag, isposdef +using TensorAlgebra.MatrixAlgebra: MatrixAlgebra +using Test: @test, @testset + +elts = (Float32, Float64, ComplexF32, ComplexF64) + +@testset "TensorAlgebra.MatrixAlgebra (elt=$elt)" for elt in elts + A = randn(elt, 3, 2) + for positive in (false, true) + for (Q, R) in (MatrixAlgebra.qr(A; positive), MatrixAlgebra.qr(A; full=false, positive)) + @test A ≈ Q * R + @test size(Q) == size(A) + @test size(R) == (size(A, 2), size(A, 2)) + @test Q' * Q ≈ I + @test Q * Q' ≉ I + if positive + @test all(≥(0), real(diag(R))) + @test all(≈(0), imag(diag(R))) + end + end + end + + A = randn(elt, 3, 2) + for positive in (false, true) + Q, R = MatrixAlgebra.qr(A; full=true, positive) + @test A ≈ Q * R + @test size(Q) == (size(A, 1), size(A, 1)) + @test size(R) == size(A) + @test Q' * Q ≈ I + @test Q * Q' ≈ I + if positive + @test all(≥(0), real(diag(R))) + @test all(≈(0), imag(diag(R))) + end + end + + A = randn(elt, 2, 3) + for positive in (false, true) + for (L, Q) in (MatrixAlgebra.lq(A; positive), MatrixAlgebra.lq(A; full=false, positive)) + @test A ≈ L * Q + @test size(L) == (size(A, 1), size(A, 1)) + @test size(Q) == size(A) + @test Q * Q' ≈ I + @test Q' * Q ≉ I + if positive + @test all(≥(0), real(diag(L))) + @test all(≈(0), imag(diag(L))) + end + end + end + + A = randn(elt, 3, 2) + for positive in (false, true) + L, Q = MatrixAlgebra.lq(A; full=true, positive) + @test A ≈ L * Q + @test size(L) == size(A) + @test size(Q) == (size(A, 2), size(A, 2)) + @test Q * Q' ≈ I + @test Q' * Q ≈ I + if positive + @test all(≥(0), real(diag(L))) + @test all(≈(0), imag(diag(L))) + end + end + + A = randn(elt, 3, 2) + for (W, C) in (MatrixAlgebra.orth(A), MatrixAlgebra.orth(A; side=:left)) + @test A ≈ W * C + @test size(W) == size(A) + @test size(C) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I + end + + A = randn(elt, 2, 3) + C, W = MatrixAlgebra.orth(A; side=:right) + @test A ≈ C * W + @test size(C) == (size(A, 1), size(A, 1)) + @test size(W) == size(A) + @test W * W' ≈ I + @test W' * W ≉ I + + A = randn(elt, 3, 2) + for (W, P) in (MatrixAlgebra.polar(A), MatrixAlgebra.polar(A; side=:left)) + @test A ≈ W * P + @test size(W) == size(A) + @test size(P) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I + @test isposdef(P) + end + + A = randn(elt, 2, 3) + P, W = MatrixAlgebra.polar(A; side=:right) + @test A ≈ P * W + @test size(P) == (size(A, 1), size(A, 1)) + @test size(W) == size(A) + @test W * W' ≈ I + @test W' * W ≉ I + @test isposdef(P) + + A = randn(elt, 3, 2) + for (W, C) in (MatrixAlgebra.factorize(A), MatrixAlgebra.factorize(A; orth=:left)) + @test A ≈ W * C + @test size(W) == size(A) + @test size(C) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I + end + + A = randn(elt, 2, 3) + C, W = MatrixAlgebra.factorize(A; orth=:right) + @test A ≈ C * W + @test size(C) == (size(A, 1), size(A, 1)) + @test size(W) == size(A) + @test W * W' ≈ I + @test W' * W ≉ I + + A = randn(elt, 3, 3) + D, V = MatrixAlgebra.eigen(A) + @test A * V ≈ V * D + @test MatrixAlgebra.eigvals(A) ≈ diag(D) + + A = randn(elt, 3, 2) + for (U, S, V) in (MatrixAlgebra.svd(A), MatrixAlgebra.svd(A; full=false)) + @test A ≈ U * S * V + @test size(U) == size(A) + @test size(S) == (size(A, 2), size(A, 2)) + @test size(V) == (size(A, 2), size(A, 2)) + @test U' * U ≈ I + @test U * U' ≉ I + @test V * V' ≈ I + @test V' * V ≈ I + @test MatrixAlgebra.svdvals(A) ≈ diag(S) + end + + A = randn(elt, 3, 2) + U, S, V = MatrixAlgebra.svd(A; full=true) + @test A ≈ U * S * V + @test size(U) == (size(A, 1), size(A, 1)) + @test size(S) == size(A) + @test size(V) == (size(A, 2), size(A, 2)) + @test U' * U ≈ I + @test U * U' ≈ I + @test V * V' ≈ I + @test V' * V ≈ I + @test MatrixAlgebra.svdvals(A) ≈ diag(S) +end