From 16aca55d6af46448b871b1d40fedd56fffabb9c0 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 5 Apr 2025 17:33:18 -0400 Subject: [PATCH 1/7] More matrix factorizations --- src/MatrixAlgebra.jl | 110 ++++++++++++++++++++++++++++++++++++++++++ src/TensorAlgebra.jl | 2 + src/factorizations.jl | 83 ++++++++----------------------- 3 files changed, 131 insertions(+), 64 deletions(-) create mode 100644 src/MatrixAlgebra.jl diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl new file mode 100644 index 0000000..6f3fbd5 --- /dev/null +++ b/src/MatrixAlgebra.jl @@ -0,0 +1,110 @@ +module MatrixAlgebra + +using LinearAlgebra: LinearAlgebra +using MatrixAlgebraKit: + eig_full, + eig_full!, + eig_trunc, + eig_trunc!, + eig_vals, + eig_vals!, + eigh_full, + eigh_full!, + eigh_trunc, + eigh_trunc!, + eigh_vals, + eigh_vals!, + left_orth, + left_orth!, + lq_full, + lq_full!, + lq_compact, + lq_compact!, + qr_full, + qr_full!, + qr_compact, + qr_compact!, + right_orth, + right_orth!, + svd_full, + svd_full!, + svd_compact, + svd_compact!, + svd_trunc, + svd_trunc! + +for (f, f_full, f_compact) in ( + (:qr, :qr_full, :qr_compact), + (:qr!, :qr_full!, :qr_compact!), + (:lq, :lq_full, :lq_compact), + (:lq!, :lq_full!, :lq_compact!), +) + @eval begin + function $f(A::AbstractMatrix; full::Bool=false, kwargs...) + f = full ? $f_full : $f_compact + return f(A; kwargs...) + end + end +end + +for (eigen, eigh_full, eig_full, eigh_trunc, eig_trunc) in ( + (:eigen, :eigh_full, :eig_full, :eigh_trunc, :eig_trunc), + (:eigen!, :eigh_full!, :eig_full!, :eigh_trunc!, :eig_trunc!), +) + @eval begin + function $eigen(A::AbstractMatrix; trunc=nothing, ishermitian=nothing, kwargs...) + ishermitian = @something ishermitian LinearAlgebra.ishermitian(A) + f = if !isnothing(trunc) + ishermitian ? $eigh_trunc : $eig_trunc + else + ishermitian ? $eigh_full : $eig_full + end + return f(A; kwargs...) + end + end +end + +for (eigvals, eigh_vals, eig_vals) in + ((:eigvals, :eigh_vals, :eig_vals), (:eigvals!, :eigh_vals!, :eig_vals!)) + @eval begin + function $eigvals(A::AbstractMatrix; ishermitian=nothing, kwargs...) + ishermitian = @something ishermitian LinearAlgebra.ishermitian(A) + f = (ishermitian ? $eigh_vals : $eig_vals) + return f(A; kwargs...) + end + end +end + +for (svd, svd_trunc, svd_full, svd_compact) in ( + (:svd, :svd_trunc, :svd_full, :svd_compact), + (:svd!, :svd_trunc!, :svd_full!, :svd_compact!), +) + @eval begin + function $svd(A::AbstractMatrix; full::Bool=false, trunc=nothing, kwargs...) + return if !isnothing(trunc) + @assert !full "Specified both full and truncation, currently not supported" + $svd_trunc(A; trunc, kwargs...) + else + (full ? $svd_full : $svd_compact)(A; kwargs...) + end + end + end +end + +for (factorize, left_orth, right_orth) in + ((:factorize, :left_orth, :right_orth), (:factorize!, :left_orth!, :right_orth!)) + @eval begin + function $factorize(A::AbstractMatrix; orth=:left, kwargs...) + f = if orth == :left + $left_orth + elseif orth == :right + $right_orth + else + throw(ArgumentError("`orth=$orth` not supported.")) + end + return f(A; kwargs...) + end + end +end + +end diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index 591133b..5fdce3f 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -2,6 +2,8 @@ module TensorAlgebra export contract, contract!, eigen, eigvals, lq, left_null, qr, right_null, svd, svdvals +include("MatrixAlgebra.jl") +using .MatrixAlgebra: MatrixAlgebra include("blockedtuple.jl") include("blockedpermutation.jl") include("BaseExtensions/BaseExtensions.jl") diff --git a/src/factorizations.jl b/src/factorizations.jl index 31aac73..1493d40 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -1,25 +1,6 @@ using LinearAlgebra: LinearAlgebra -using MatrixAlgebraKit: - eig_full!, - eig_trunc!, - eig_vals!, - eigh_full!, - eigh_trunc!, - eigh_vals!, - left_null!, - left_orth!, - left_polar!, - lq_full!, - lq_compact!, - qr_full!, - qr_compact!, - right_null!, - right_orth!, - right_polar!, - svd_full!, - svd_compact!, - svd_trunc!, - svd_vals! +using .MatrixAlgebra: MatrixAlgebra +using MatrixAlgebraKit: MatrixAlgebraKit """ qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R @@ -41,12 +22,12 @@ function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs.. biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return qr(A, biperm; kwargs...) end -function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...) +function qr(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) # tensor to matrix A_mat = fusedims(A, biperm) # factorization - Q, R = full ? qr_full!(A_mat; kwargs...) : qr_compact!(A_mat; kwargs...) + Q, R = MatrixAlgebra.qr(A_mat; kwargs...) # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) @@ -75,12 +56,12 @@ function lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs.. biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return lq(A, biperm; kwargs...) end -function lq(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...) +function lq(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) # tensor to matrix A_mat = fusedims(A, biperm) # factorization - L, Q = (full ? lq_full! : lq_compact!)(A_mat; kwargs...) + L, Q = MatrixAlgebra.lq(A_mat; kwargs...) # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) @@ -111,25 +92,12 @@ function eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwarg biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return eigen(A, biperm; kwargs...) end -function eigen( - A::AbstractArray, - biperm::BlockedPermutation{2}; - trunc=nothing, - ishermitian=nothing, - kwargs..., -) +function eigen(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) # tensor to matrix A_mat = fusedims(A, biperm) - ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat) - # factorization - f! = if !isnothing(trunc) - ishermitian ? eigh_trunc! : eig_trunc! - else - ishermitian ? eigh_full! : eig_full! - end - D, V = f!(A_mat; kwargs...) + D, V = MatrixAlgebra.eigen!(A_mat; kwargs...) # matrix to tensor axes_codomain, = blockpermute(axes(A), biperm) @@ -161,11 +129,9 @@ function eigvals( A::AbstractArray, biperm::BlockedPermutation{2}; ishermitian=nothing, kwargs... ) A_mat = fusedims(A, biperm) - ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat) - return (ishermitian ? eigh_vals! : eig_vals!)(A_mat; kwargs...) + return MatrixAlgebra.eigvals!(A_mat; kwargs...) end -# TODO: separate out the algorithm selection step from the implementation """ svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> U, S, Vᴴ @@ -187,23 +153,12 @@ function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs. biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return svd(A, biperm; kwargs...) end -function svd( - A::AbstractArray, - biperm::BlockedPermutation{2}; - full::Bool=false, - trunc=nothing, - kwargs..., -) +function svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) # tensor to matrix A_mat = fusedims(A, biperm) # factorization - if !isnothing(trunc) - @assert !full "Specified both full and truncation, currently not supported" - U, S, Vᴴ = svd_trunc!(A_mat; trunc, kwargs...) - else - U, S, Vᴴ = full ? svd_full!(A_mat; kwargs...) : svd_compact!(A_mat; kwargs...) - end + U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...) # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) @@ -228,7 +183,7 @@ function svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) end function svdvals(A::AbstractArray, biperm::BlockedPermutation{2}) A_mat = fusedims(A, biperm) - return svd_vals!(A_mat) + return MatrixAlgebraKit.svd_vals!(A_mat) end """ @@ -254,7 +209,7 @@ function left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; k end function left_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) - N = left_null!(A_mat; kwargs...) + N = MatrixAlgebraKit.left_null!(A_mat; kwargs...) axes_codomain, _ = blockpermute(axes(A), biperm) axes_N = (axes_codomain..., axes(N, 2)) N_tensor = splitdims(N, axes_N) @@ -284,7 +239,7 @@ function right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; end function right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) - Nᴴ = right_null!(A_mat; kwargs...) + Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...) _, axes_domain = blockpermute(axes(A), biperm) axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...) return splitdims(Nᴴ, axes_Nᴴ) @@ -313,7 +268,7 @@ function left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) # factorization - W, P = left_polar!(A_mat; kwargs...) + W, P = MatrixAlgebraKit.left_polar!(A_mat; kwargs...) # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) @@ -345,7 +300,7 @@ function right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) # factorization - P, W = right_polar!(A_mat; kwargs...) + P, W = MatrixAlgebraKit.right_polar!(A_mat; kwargs...) # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) @@ -377,7 +332,7 @@ function left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) # factorization - V, C = left_orth!(A_mat; kwargs...) + V, C = MatrixAlgebraKit.left_orth!(A_mat; kwargs...) # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) @@ -409,7 +364,7 @@ function right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) # factorization - P, W = right_orth!(A_mat; kwargs...) + P, W = MatrixAlgebraKit.right_orth!(A_mat; kwargs...) # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) @@ -441,7 +396,7 @@ function factorize(A::AbstractArray, biperm::BlockedPermutation{2}; orth=:left, A_mat = fusedims(A, biperm) # factorization - X, Y = (orth == :left ? left_orth! : right_orth!)(A_mat; kwargs...) + X, Y = MatrixAlgebra.factorize!(A_mat; kwargs...) # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) From 14c971d039cb857a79be61c9c6c22b2d8284df38 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 5 Apr 2025 18:06:24 -0400 Subject: [PATCH 2/7] More matrix factorizations --- Project.toml | 2 +- src/MatrixAlgebra.jl | 43 ++++- src/factorizations.jl | 308 ++++++++++++++---------------------- test/test_factorizations.jl | 60 ++++--- 4 files changed, 195 insertions(+), 218 deletions(-) diff --git a/Project.toml b/Project.toml index cf887d5..9fa42f7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.2.9" +version = "0.2.10" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 6f3fbd5..ab2bd58 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -16,6 +16,8 @@ using MatrixAlgebraKit: eigh_vals!, left_orth, left_orth!, + left_polar, + left_polar!, lq_full, lq_full!, lq_compact, @@ -26,6 +28,8 @@ using MatrixAlgebraKit: qr_compact!, right_orth, right_orth!, + right_polar, + right_polar!, svd_full, svd_full!, svd_compact, @@ -91,14 +95,43 @@ for (svd, svd_trunc, svd_full, svd_compact) in ( end end -for (factorize, left_orth, right_orth) in - ((:factorize, :left_orth, :right_orth), (:factorize!, :left_orth!, :right_orth!)) +for (polar, left_polar, right_polar) in + ((:polar, :left_polar, :right_polar), (:polar!, :left_polar!, :right_polar!)) @eval begin - function $factorize(A::AbstractMatrix; orth=:left, kwargs...) - f = if orth == :left + function $polar(A::AbstractMatrix; side=:left, kwargs...) + f = if side == :left + $left_polar + elseif side == :right + $right_polar + else + throw(ArgumentError("`side=$side` not supported.")) + end + return f(A; kwargs...) + end + end +end + +for (orth, left_orth, right_orth) in + ((:orth, :left_orth, :right_orth), (:orth!, :left_orth!, :right_orth!)) + @eval begin + function $orth(A::AbstractMatrix; side=:left, kwargs...) + f = if side == :left $left_orth - elseif orth == :right + elseif side == :right $right_orth + else + throw(ArgumentError("`side=$side` not supported.")) + end + return f(A; kwargs...) + end + end +end + +for (factorize, orth_f) in ((:factorize, :(MatrixAlgebra.orth)), (:factorize!, :orth!)) + @eval begin + function $factorize(A::AbstractMatrix; orth=:left, kwargs...) + f = if orth in (:left, :right) + $orth_f else throw(ArgumentError("`orth=$orth` not supported.")) end diff --git a/src/factorizations.jl b/src/factorizations.jl index 1493d40..2ce171d 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -2,6 +2,42 @@ using LinearAlgebra: LinearAlgebra using .MatrixAlgebra: MatrixAlgebra using MatrixAlgebraKit: MatrixAlgebraKit +function factorize_with(f, A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) + # tensor to matrix + A_mat = fusedims(A, biperm) + + # factorization + X, Y = f(A_mat; kwargs...) + + # matrix to tensor + axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_X = (axes_codomain..., axes(X, 2)) + axes_Y = (axes(Y, 1), axes_domain...) + return splitdims(X, axes_X), splitdims(Y, axes_Y) +end + +for (f, f_mat) in ( + (:qr, :(MatrixAlgebra.qr)), + (:lq, :(MatrixAlgebra.lq)), + (:left_polar, :(MatrixAlgebra.left_polar)), + (:right_polar, :(MatrixAlgebra.right_polar)), + (:polar, :(MatrixAlgebra.polar)), + (:left_orth, :(MatrixAlgebra.left_orth)), + (:right_orth, :(MatrixAlgebra.right_orth)), + (:orth, :(MatrixAlgebra.orth)), + (:factorize, :(MatrixAlgebra.factorize)), +) + @eval begin + function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return $f(A, biperm; kwargs...) + end + function $f(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) + return factorize_with($f_mat, A, biperm; kwargs...) + end + end +end + """ qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R qr(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> Q, R @@ -18,23 +54,7 @@ their labels, or directly through a `biperm`. See also `MatrixAlgebraKit.qr_full!` and `MatrixAlgebraKit.qr_compact!`. """ -function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return qr(A, biperm; kwargs...) -end -function qr(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - Q, R = MatrixAlgebra.qr(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_Q = (axes_codomain..., axes(Q, 2)) - axes_R = (axes(R, 1), axes_domain...) - return splitdims(Q, axes_Q), splitdims(R, axes_R) -end +qr """ lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> L, Q @@ -52,23 +72,87 @@ their labels, or directly through a `biperm`. See also `MatrixAlgebraKit.lq_full!` and `MatrixAlgebraKit.lq_compact!`. """ -function lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return lq(A, biperm; kwargs...) -end -function lq(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) +lq - # factorization - L, Q = MatrixAlgebra.lq(A_mat; kwargs...) +""" + left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P + left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> W, P - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_L = (axes_codomain..., axes(L, ndims(L))) - axes_Q = (axes(Q, 1), axes_domain...) - return splitdims(L, axes_L), splitdims(Q, axes_Q) -end +Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.left_polar!`. +""" +left_polar + +""" + right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W + right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> P, W + +Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.right_polar!`. +""" +right_polar + +""" + left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C + left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> V, C + +Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.left_orth!`. +""" +left_orth + +""" + right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V + right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> C, V + +Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.right_orth!`. +""" +right_orth + +""" + factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y + factorize(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> X, Y + +Compute the decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- `orth::Symbol=:left`: specify the orthogonality of the decomposition. + Currently only `:left` and `:right` are supported. +- Other keywords are passed on directly to MatrixAlgebraKit. +""" +factorize """ eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D, V @@ -244,163 +328,3 @@ function right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...) return splitdims(Nᴴ, axes_Nᴴ) end - -""" - left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P - left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> W, P - -Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via -their labels, or directly through a `biperm`. - -## Keyword arguments - -- Keyword arguments are passed on directly to MatrixAlgebraKit. - -See also `MatrixAlgebraKit.left_polar!`. -""" -function left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return left_polar(A, biperm; kwargs...) -end -function left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - W, P = MatrixAlgebraKit.left_polar!(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_W = (axes_codomain..., axes(W, 2)) - axes_P = (axes(P, 1), axes_domain...) - return splitdims(W, axes_W), splitdims(P, axes_P) -end - -""" - right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W - right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> P, W - -Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via -their labels, or directly through a `biperm`. - -## Keyword arguments - -- Keyword arguments are passed on directly to MatrixAlgebraKit. - -See also `MatrixAlgebraKit.right_polar!`. -""" -function right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return right_polar(A, biperm; kwargs...) -end -function right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - P, W = MatrixAlgebraKit.right_polar!(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_P = (axes_codomain..., axes(P, ndims(P))) - axes_W = (axes(W, 1), axes_domain...) - return splitdims(P, axes_P), splitdims(W, axes_W) -end - -""" - left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C - left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> V, C - -Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via -their labels, or directly through a `biperm`. - -## Keyword arguments - -- Keyword arguments are passed on directly to MatrixAlgebraKit. - -See also `MatrixAlgebraKit.left_orth!`. -""" -function left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return left_orth(A, biperm; kwargs...) -end -function left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - V, C = MatrixAlgebraKit.left_orth!(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_V = (axes_codomain..., axes(V, 2)) - axes_C = (axes(C, 1), axes_domain...) - return splitdims(V, axes_V), splitdims(C, axes_C) -end - -""" - right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V - right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> C, V - -Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via -their labels, or directly through a `biperm`. - -## Keyword arguments - -- Keyword arguments are passed on directly to MatrixAlgebraKit. - -See also `MatrixAlgebraKit.right_orth!`. -""" -function right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return right_orth(A, biperm; kwargs...) -end -function right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - P, W = MatrixAlgebraKit.right_orth!(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_P = (axes_codomain..., axes(P, ndims(P))) - axes_W = (axes(W, 1), axes_domain...) - return splitdims(P, axes_P), splitdims(W, axes_W) -end - -""" - factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y - factorize(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> X, Y - -Compute the decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via -their labels, or directly through a `biperm`. - -## Keyword arguments - -- `orth::Symbol=:left`: specify the orthogonality of the decomposition. - Currently only `:left` and `:right` are supported. -- Other keywords are passed on directly to MatrixAlgebraKit. -""" -function factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return factorize(A, biperm; kwargs...) -end -function factorize(A::AbstractArray, biperm::BlockedPermutation{2}; orth=:left, kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - X, Y = MatrixAlgebra.factorize!(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_X = (axes_codomain..., axes(X, ndims(X))) - axes_Y = (axes(Y, 1), axes_domain...) - return splitdims(X, axes_X), splitdims(Y, axes_Y) -end diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 62b1561..0c5ccdf 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -10,6 +10,8 @@ using TensorAlgebra: left_orth, left_polar, lq, + orth, + polar, qr, right_null, right_orth, @@ -216,11 +218,16 @@ end labels_P = (:d, :c) Acopy = deepcopy(A) - W, P = left_polar(A, labels_A, labels_W, labels_P) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) - @test A ≈ A′ - @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + for (W, P) in ( + left_polar(A, labels_A, labels_W, labels_P), + polar(A, labels_A, labels_W, labels_P; side=:left), + polar(A, labels_A, labels_W, labels_P), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) + @test A ≈ A′ + @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "Right polar ($T)" for T in elts @@ -230,11 +237,15 @@ end labels_W = (:d, :c) Acopy = deepcopy(A) - P, W = right_polar(A, labels_A, labels_P, labels_W) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) - @test A ≈ A′ - @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + for (P, W) in ( + right_polar(A, labels_A, labels_P, labels_W), + polar(A, labels_A, labels_P, labels_W; side=:right), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) + @test A ≈ A′ + @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "Left orth ($T)" for T in elts @@ -244,11 +255,16 @@ end labels_P = (:d, :c) Acopy = deepcopy(A) - W, P = left_orth(A, labels_A, labels_W, labels_P) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) - @test A ≈ A′ - @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + for (W, P) in ( + left_orth(A, labels_A, labels_W, labels_P), + orth(A, labels_A, labels_W, labels_P; side=:left), + orth(A, labels_A, labels_W, labels_P), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) + @test A ≈ A′ + @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "Right orth ($T)" for T in elts @@ -258,11 +274,15 @@ end labels_W = (:d, :c) Acopy = deepcopy(A) - P, W = right_orth(A, labels_A, labels_P, labels_W) - @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) - @test A ≈ A′ - @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + for (P, W) in ( + right_orth(A, labels_A, labels_P, labels_W), + orth(A, labels_A, labels_P, labels_W; side=:right), + ) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) + @test A ≈ A′ + @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end end @testset "factorize ($T)" for T in elts From e1edd712a3f311b9bcb6d02d6c6e48fc899c5e82 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 7 Apr 2025 08:15:58 -0400 Subject: [PATCH 3/7] Simplifications based on Lukas's comments --- src/MatrixAlgebra.jl | 36 +----------------------------------- src/TensorAlgebra.jl | 1 - src/factorizations.jl | 43 ++++++++++++++----------------------------- 3 files changed, 15 insertions(+), 65 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index ab2bd58..6b2582b 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -1,41 +1,7 @@ module MatrixAlgebra using LinearAlgebra: LinearAlgebra -using MatrixAlgebraKit: - eig_full, - eig_full!, - eig_trunc, - eig_trunc!, - eig_vals, - eig_vals!, - eigh_full, - eigh_full!, - eigh_trunc, - eigh_trunc!, - eigh_vals, - eigh_vals!, - left_orth, - left_orth!, - left_polar, - left_polar!, - lq_full, - lq_full!, - lq_compact, - lq_compact!, - qr_full, - qr_full!, - qr_compact, - qr_compact!, - right_orth, - right_orth!, - right_polar, - right_polar!, - svd_full, - svd_full!, - svd_compact, - svd_compact!, - svd_trunc, - svd_trunc! +using MatrixAlgebraKit for (f, f_full, f_compact) in ( (:qr, :qr_full, :qr_compact), diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index 5fdce3f..a2ff401 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -3,7 +3,6 @@ module TensorAlgebra export contract, contract!, eigen, eigvals, lq, left_null, qr, right_null, svd, svdvals include("MatrixAlgebra.jl") -using .MatrixAlgebra: MatrixAlgebra include("blockedtuple.jl") include("blockedpermutation.jl") include("BaseExtensions/BaseExtensions.jl") diff --git a/src/factorizations.jl b/src/factorizations.jl index 2ce171d..57243e8 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -1,31 +1,8 @@ using LinearAlgebra: LinearAlgebra -using .MatrixAlgebra: MatrixAlgebra using MatrixAlgebraKit: MatrixAlgebraKit -function factorize_with(f, A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - # tensor to matrix - A_mat = fusedims(A, biperm) - - # factorization - X, Y = f(A_mat; kwargs...) - - # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_X = (axes_codomain..., axes(X, 2)) - axes_Y = (axes(Y, 1), axes_domain...) - return splitdims(X, axes_X), splitdims(Y, axes_Y) -end - -for (f, f_mat) in ( - (:qr, :(MatrixAlgebra.qr)), - (:lq, :(MatrixAlgebra.lq)), - (:left_polar, :(MatrixAlgebra.left_polar)), - (:right_polar, :(MatrixAlgebra.right_polar)), - (:polar, :(MatrixAlgebra.polar)), - (:left_orth, :(MatrixAlgebra.left_orth)), - (:right_orth, :(MatrixAlgebra.right_orth)), - (:orth, :(MatrixAlgebra.orth)), - (:factorize, :(MatrixAlgebra.factorize)), +for f in ( + :qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, :factorize ) @eval begin function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) @@ -33,7 +10,17 @@ for (f, f_mat) in ( return $f(A, biperm; kwargs...) end function $f(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - return factorize_with($f_mat, A, biperm; kwargs...) + # tensor to matrix + A_mat = fusedims(A, biperm) + + # factorization + X, Y = MatrixAlgebra.$f(A_mat; kwargs...) + + # matrix to tensor + axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_X = (axes_codomain..., axes(X, 2)) + axes_Y = (axes(Y, 1), axes_domain...) + return splitdims(X, axes_X), splitdims(Y, axes_Y) end end end @@ -209,9 +196,7 @@ function eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwa biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return eigvals(A, biperm; kwargs...) end -function eigvals( - A::AbstractArray, biperm::BlockedPermutation{2}; ishermitian=nothing, kwargs... -) +function eigvals(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) return MatrixAlgebra.eigvals!(A_mat; kwargs...) end From 9e02367b08949607afc37ef285c37b6d142feb1f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 7 Apr 2025 08:19:23 -0400 Subject: [PATCH 4/7] Update export list --- src/TensorAlgebra.jl | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index a2ff401..1bc23e4 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -1,6 +1,22 @@ module TensorAlgebra -export contract, contract!, eigen, eigvals, lq, left_null, qr, right_null, svd, svdvals +export contract, + contract!, + eigen, + eigvals, + factorize, + left_null, + left_orth, + left_polar, + lq, + qr, + right_null, + right_orth, + right_polar, + orth, + polar, + svd, + svdvals include("MatrixAlgebra.jl") include("blockedtuple.jl") From e68a1356451a000bd3999be55a2c9563a3e36738 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 7 Apr 2025 08:50:39 -0400 Subject: [PATCH 5/7] Exports --- src/MatrixAlgebra.jl | 27 +++++++++++++++++++++++++++ src/factorizations.jl | 2 +- test/test_exports.jl | 30 ++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 6b2582b..de4463c 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -1,5 +1,24 @@ module MatrixAlgebra +export eigen, + eigen!, + eigvals, + eigvals!, + factorize, + factorize!, + lq, + lq!, + orth, + orth!, + polar, + polar!, + qr, + qr!, + svd, + svd!, + svdvals, + svdvals! + using LinearAlgebra: LinearAlgebra using MatrixAlgebraKit @@ -61,6 +80,14 @@ for (svd, svd_trunc, svd_full, svd_compact) in ( end end +for (svdvals, svd_vals) in ((:svdvals, :svd_vals), (:svdvals!, :svd_vals!)) + @eval begin + function $svdvals(A::AbstractMatrix; ishermitian=nothing, kwargs...) + return $svd_vals(A; kwargs...) + end + end +end + for (polar, left_polar, right_polar) in ((:polar, :left_polar, :right_polar), (:polar!, :left_polar!, :right_polar!)) @eval begin diff --git a/src/factorizations.jl b/src/factorizations.jl index 57243e8..a2fee2d 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -252,7 +252,7 @@ function svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) end function svdvals(A::AbstractArray, biperm::BlockedPermutation{2}) A_mat = fusedims(A, biperm) - return MatrixAlgebraKit.svd_vals!(A_mat) + return MatrixAlgebra.svdvals!(A_mat) end """ diff --git a/test/test_exports.jl b/test/test_exports.jl index dbef64f..8dba1bf 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -7,12 +7,42 @@ using Test: @test, @testset :contract!, :eigen, :eigvals, + :factorize, :left_null, + :left_orth, + :left_polar, :lq, + :orth, + :polar, :qr, :right_null, + :right_orth, + :right_polar, :svd, :svdvals, ] @test issetequal(names(TensorAlgebra), exports) + + exports = [ + :MatrixAlgebra, + :eigen, + :eigen!, + :eigvals, + :eigvals!, + :factorize, + :factorize!, + :lq, + :lq!, + :orth, + :orth!, + :polar, + :polar!, + :qr, + :qr!, + :svd, + :svd!, + :svdvals, + :svdvals!, + ] + @test issetequal(names(TensorAlgebra.MatrixAlgebra), exports) end From 94aad1d487cef93831be896fb62ee60f96c4fb2c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 7 Apr 2025 09:38:08 -0400 Subject: [PATCH 6/7] Add tests for TensorAlgebra.MatrixAlgebra --- src/MatrixAlgebra.jl | 2 +- test/test_matrixalgebra.jl | 148 +++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 test/test_matrixalgebra.jl diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index de4463c..a3796b9 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -128,7 +128,7 @@ for (factorize, orth_f) in ((:factorize, :(MatrixAlgebra.orth)), (:factorize!, : else throw(ArgumentError("`orth=$orth` not supported.")) end - return f(A; kwargs...) + return f(A; side=orth, kwargs...) end end end diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl new file mode 100644 index 0000000..8ee73b8 --- /dev/null +++ b/test/test_matrixalgebra.jl @@ -0,0 +1,148 @@ +using LinearAlgebra: I, diag +using TensorAlgebra.MatrixAlgebra: MatrixAlgebra +using Test: @test, @testset + +elts = (Float32, Float64, ComplexF32, ComplexF64) + +@testset "elt=$elt" for elt in elts + A = randn(elt, 3, 2) + for positive in (false, true) + for (Q, R) in (MatrixAlgebra.qr(A; positive), MatrixAlgebra.qr(A; full=false, positive)) + @test A ≈ Q * R + @test size(Q) == size(A) + @test size(R) == (size(A, 2), size(A, 2)) + @test Q' * Q ≈ I + @test Q * Q' ≉ I + if positive + @test all(≥(0), real(diag(R))) + @test all(≈(0), imag(diag(R))) + end + end + end + + A = randn(elt, 3, 2) + for positive in (false, true) + Q, R = MatrixAlgebra.qr(A; full=true, positive) + @test A ≈ Q * R + @test size(Q) == (size(A, 1), size(A, 1)) + @test size(R) == size(A) + @test Q' * Q ≈ I + @test Q * Q' ≈ I + if positive + @test all(≥(0), real(diag(R))) + @test all(≈(0), imag(diag(R))) + end + end + + A = randn(elt, 2, 3) + for positive in (false, true) + for (L, Q) in (MatrixAlgebra.lq(A; positive), MatrixAlgebra.lq(A; full=false, positive)) + @test A ≈ L * Q + @test size(L) == (size(A, 1), size(A, 1)) + @test size(Q) == size(A) + @test Q * Q' ≈ I + @test Q' * Q ≉ I + if positive + @test all(≥(0), real(diag(L))) + @test all(≈(0), imag(diag(L))) + end + end + end + + A = randn(elt, 3, 2) + for positive in (false, true) + L, Q = MatrixAlgebra.lq(A; full=true, positive) + @test A ≈ L * Q + @test size(L) == size(A) + @test size(Q) == (size(A, 2), size(A, 2)) + @test Q * Q' ≈ I + @test Q' * Q ≈ I + if positive + @test all(≥(0), real(diag(L))) + @test all(≈(0), imag(diag(L))) + end + end + + A = randn(elt, 3, 2) + for (W, C) in (MatrixAlgebra.orth(A), MatrixAlgebra.orth(A; side=:left)) + @test A ≈ W * C + @test size(W) == size(A) + @test size(C) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I + end + + A = randn(elt, 2, 3) + C, W = MatrixAlgebra.orth(A; side=:right) + @test A ≈ C * W + @test size(C) == (size(A, 1), size(A, 1)) + @test size(W) == size(A) + @test W * W' ≈ I + @test W' * W ≉ I + + A = randn(elt, 3, 2) + for (W, P) in (MatrixAlgebra.polar(A), MatrixAlgebra.polar(A; side=:left)) + @test A ≈ W * P + @test size(W) == size(A) + @test size(P) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I + @test isposdef(P) + end + + A = randn(elt, 2, 3) + P, W = MatrixAlgebra.polar(A; side=:right) + @test A ≈ P * W + @test size(P) == (size(A, 1), size(A, 1)) + @test size(W) == size(A) + @test W * W' ≈ I + @test W' * W ≉ I + @test isposdef(P) + + A = randn(elt, 3, 2) + for (W, C) in (MatrixAlgebra.factorize(A), MatrixAlgebra.factorize(A; orth=:left)) + @test A ≈ W * C + @test size(W) == size(A) + @test size(C) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I + end + + A = randn(elt, 2, 3) + C, W = MatrixAlgebra.factorize(A; orth=:right) + @test A ≈ C * W + @test size(C) == (size(A, 1), size(A, 1)) + @test size(W) == size(A) + @test W * W' ≈ I + @test W' * W ≉ I + + A = randn(elt, 3, 3) + D, V = MatrixAlgebra.eigen(A) + @test A * V ≈ V * D + @test MatrixAlgebra.eigvals(A) ≈ diag(D) + + A = randn(elt, 3, 2) + for (U, S, V) in (MatrixAlgebra.svd(A), MatrixAlgebra.svd(A; full=false)) + @test A ≈ U * S * V + @test size(U) == size(A) + @test size(S) == (size(A, 2), size(A, 2)) + @test size(V) == (size(A, 2), size(A, 2)) + @test U' * U ≈ I + @test U * U' ≉ I + @test V * V' ≈ I + @test V' * V ≈ I + @test MatrixAlgebra.svdvals(A) ≈ diag(S) + end + + A = randn(elt, 3, 2) + U, S, V = MatrixAlgebra.svd(A; full=true) + @test A ≈ U * S * V + @test size(U) == (size(A, 1), size(A, 1)) + @test size(S) == size(A) + @test size(V) == (size(A, 2), size(A, 2)) + @test U' * U ≈ I + @test U * U' ≈ I + @test V * V' ≈ I + @test V' * V ≈ I + @test MatrixAlgebra.svdvals(A) ≈ diag(S) +end From 89c8d33da38a4669d12eb8a9030268d0fedcc1dd Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 7 Apr 2025 09:46:08 -0400 Subject: [PATCH 7/7] Add missing import --- test/test_matrixalgebra.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 8ee73b8..7a391c3 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -1,10 +1,10 @@ -using LinearAlgebra: I, diag +using LinearAlgebra: I, diag, isposdef using TensorAlgebra.MatrixAlgebra: MatrixAlgebra using Test: @test, @testset elts = (Float32, Float64, ComplexF32, ComplexF64) -@testset "elt=$elt" for elt in elts +@testset "TensorAlgebra.MatrixAlgebra (elt=$elt)" for elt in elts A = randn(elt, 3, 2) for positive in (false, true) for (Q, R) in (MatrixAlgebra.qr(A; positive), MatrixAlgebra.qr(A; full=false, positive))