diff --git a/Project.toml b/Project.toml index 90a3c86..bf89042 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.5.1" +version = "0.5.2" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index d3e3658..6fdc30c 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -75,14 +75,34 @@ for (svd, svd_trunc, svd_full, svd_compact) in ( (:svd, :svd_trunc, :svd_full, :svd_compact), (:svd!, :svd_trunc!, :svd_full!, :svd_compact!), ) + _svd = Symbol(:_, svd) @eval begin - function $svd(A::AbstractMatrix; full::Bool = false, trunc = nothing, kwargs...) - return if !isnothing(trunc) - @assert !full "Specified both full and truncation, currently not supported" - $svd_trunc(A; trunc, kwargs...) - else - (full ? $svd_full : $svd_compact)(A; kwargs...) - end + function $svd( + A::AbstractMatrix; + full::Union{Bool, Val} = Val(false), + trunc = nothing, + kwargs..., + ) + return $_svd(full, trunc, A; kwargs...) + end + function $_svd(full::Bool, trunc, A::AbstractMatrix; kwargs...) + return $_svd(Val(full), trunc, A; kwargs...) + end + function $_svd(full::Val{false}, trunc::Nothing, A::AbstractMatrix; kwargs...) + return $svd_compact(A; kwargs...) + end + function $_svd(full::Val{false}, trunc, A::AbstractMatrix; kwargs...) + return $svd_trunc(A; trunc, kwargs...) + end + function $_svd(full::Val{true}, trunc::Nothing, A::AbstractMatrix; kwargs...) + return $svd_full(A; kwargs...) + end + function $_svd(full::Val{true}, trunc, A::AbstractMatrix; kwargs...) + return throw( + ArgumentError( + "Specified both full and truncation, currently not supported" + ) + ) end end end diff --git a/src/factorizations.jl b/src/factorizations.jl index 5136beb..db7862d 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -7,22 +7,38 @@ for f in ( @eval begin function $f( A::AbstractArray, - codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + codomain_length::Val, domain_length::Val; kwargs..., ) # tensor to matrix - A_mat = matricize(A, codomain_perm, domain_perm) + A_mat = matricize(A, codomain_length, domain_length) # factorization X, Y = MatrixAlgebra.$f(A_mat; kwargs...) # matrix to tensor - biperm = permmortar((codomain_perm, domain_perm)) + biperm = blockedtrivialperm((codomain_length, domain_length)) axes_codomain, axes_domain = blocks(axes(A)[biperm]) axes_X = tuplemortar((axes_codomain, (axes(X, 2),))) axes_Y = tuplemortar(((axes(Y, 1),), axes_domain)) return unmatricize(X, axes_X), unmatricize(Y, axes_Y) end + end +end + +for f in ( + :qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, :factorize, + :eigen, :eigvals, :svd, :svdvals, :left_null, :right_null, + ) + @eval begin + function $f( + A::AbstractArray, + codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + kwargs..., + ) + A_perm = bipermutedims(A, codomain_perm, domain_perm) + return $f(A_perm, Val(length(codomain_perm)), Val(length(domain_perm)); kwargs...) + end function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return $f(A, blocks(biperm)...; kwargs...) @@ -36,6 +52,7 @@ end """ qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R qr(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> Q, R + qr(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> Q, R qr(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Q, R Compute the QR decomposition of a generic N-dimensional array, by interpreting it as @@ -55,6 +72,7 @@ qr """ lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> L, Q lq(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> L, Q + lq(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> L, Q lq(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> L, Q Compute the LQ decomposition of a generic N-dimensional array, by interpreting it as @@ -74,6 +92,7 @@ lq """ left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P left_polar(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> W, P + left_polar(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> W, P left_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> W, P Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as @@ -91,6 +110,7 @@ left_polar """ right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W right_polar(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> P, W + right_polar(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> P, W right_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> P, W Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as @@ -108,6 +128,7 @@ right_polar """ left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C left_orth(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> V, C + left_orth(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> V, C left_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> V, C Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as @@ -125,6 +146,7 @@ left_orth """ right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V right_orth(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> C, V + right_orth(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> C, V right_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> C, V Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as @@ -142,6 +164,7 @@ right_orth """ factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y factorize(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> X, Y + factorize(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> X, Y factorize(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y Compute the decomposition of a generic N-dimensional array, by interpreting it as @@ -159,6 +182,7 @@ factorize """ eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D, V eigen(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> D, V + eigen(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> D, V eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D, V Compute the eigenvalue decomposition of a generic N-dimensional array, by interpreting it as @@ -175,26 +199,18 @@ their labels or directly through a bi-permutation. See also `MatrixAlgebraKit.eig_full!`, `MatrixAlgebraKit.eig_trunc!`, `MatrixAlgebraKit.eig_vals!`, `MatrixAlgebraKit.eigh_full!`, `MatrixAlgebraKit.eigh_trunc!`, and `MatrixAlgebraKit.eigh_vals!`. """ -function eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return eigen(A, blocks(biperm)...; kwargs...) -end -function eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - return eigen(A, blocks(biperm)...; kwargs...) -end function eigen( A::AbstractArray, - codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + codomain_length::Val, domain_length::Val; kwargs..., ) # tensor to matrix - A_mat = matricize(A, codomain_perm, domain_perm) - + A_mat = matricize(A, codomain_length, domain_length) # factorization D, V = MatrixAlgebra.eigen!(A_mat; kwargs...) # matrix to tensor - biperm = permmortar((codomain_perm, domain_perm)) + biperm = blockedtrivialperm((codomain_length, domain_length)) axes_codomain, = blocks(axes(A)[biperm]) axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),))) return D, unmatricize(V, axes_V) @@ -203,6 +219,7 @@ end """ eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D eigvals(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> D + eigvals(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> D eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D Compute the eigenvalues of a generic N-dimensional array, by interpreting it as @@ -217,25 +234,19 @@ their labels or directly through a bi-permutation. The output is a vector of eig See also `MatrixAlgebraKit.eig_vals!` and `MatrixAlgebraKit.eigh_vals!`. """ -function eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return eigvals(A, blocks(biperm)...; kwargs...) -end -function eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - return eigvals(A, blocks(biperm)...; kwargs...) -end function eigvals( A::AbstractArray, - codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + codomain_length::Val, domain_length::Val; kwargs..., ) - A_mat = matricize(A, codomain_perm, domain_perm) + A_mat = matricize(A, codomain_length, domain_length) return MatrixAlgebra.eigvals!(A_mat; kwargs...) end """ svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ svd(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> U, S, Vᴴ + svd(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> U, S, Vᴴ svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> U, S, Vᴴ Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as @@ -251,26 +262,18 @@ their labels or directly through a bi-permutation. See also `MatrixAlgebraKit.svd_full!`, `MatrixAlgebraKit.svd_compact!`, and `MatrixAlgebraKit.svd_trunc!`. """ -function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return svd(A, blocks(biperm)...; kwargs...) -end -function svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - return svd(A, blocks(biperm)...; kwargs...) -end function svd( A::AbstractArray, - codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + codomain_length::Val, domain_length::Val; kwargs..., ) # tensor to matrix - A_mat = matricize(A, codomain_perm, domain_perm) - + A_mat = matricize(A, codomain_length, domain_length) # factorization U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...) # matrix to tensor - biperm = permmortar((codomain_perm, domain_perm)) + biperm = blockedtrivialperm((codomain_length, domain_length)) axes_codomain, axes_domain = blocks(axes(A)[biperm]) axes_U = tuplemortar((axes_codomain, (axes(U, 2),))) axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain)) @@ -280,6 +283,7 @@ end """ svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) -> S svdvals(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}) -> S + svdvals(A::AbstractArray, codomain_length::Val, domain_length::Val) -> S svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) -> S Compute the singular values of a generic N-dimensional array, by interpreting it as @@ -288,24 +292,18 @@ their labels or directly through a bi-permutation. The output is a vector of sin See also `MatrixAlgebraKit.svd_vals!`. """ -function svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return svdvals(A, blocks(biperm)...) -end -function svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) - return svdvals(A, blocks(biperm)...) -end function svdvals( A::AbstractArray, - codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}} + codomain_length::Val, domain_length::Val ) - A_mat = matricize(A, codomain_perm, domain_perm) + A_mat = matricize(A, codomain_length, domain_length) return MatrixAlgebra.svdvals!(A_mat) end """ left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> N left_null(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> N + left_null(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> N left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> N Compute the left nullspace of a generic N-dimensional array, by interpreting it as @@ -321,21 +319,14 @@ The output satisfies `N' * A ≈ 0` and `N' * N ≈ I`. The options are `:qr`, `:qrpos` and `:svd`. The former two require `0 == atol == rtol`. The default is `:qrpos` if `atol == rtol == 0`, and `:svd` otherwise. """ -function left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return left_null(A, blocks(biperm)...; kwargs...) -end -function left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - return left_null(A, blocks(biperm)...; kwargs...) -end function left_null( A::AbstractArray, - codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + codomain_length::Val, domain_length::Val; kwargs..., ) - A_mat = matricize(A, codomain_perm, domain_perm) + A_mat = matricize(A, codomain_length, domain_length) N = MatrixAlgebraKit.left_null!(A_mat; kwargs...) - biperm = permmortar((codomain_perm, domain_perm)) + biperm = blockedtrivialperm((codomain_length, domain_length)) axes_codomain = first(blocks(axes(A)[biperm])) axes_N = tuplemortar((axes_codomain, (axes(N, 2),))) return unmatricize(N, axes_N) @@ -344,6 +335,7 @@ end """ right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Nᴴ right_null(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> Nᴴ + right_null(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> Nᴴ right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Nᴴ Compute the right nullspace of a generic N-dimensional array, by interpreting it as @@ -359,21 +351,14 @@ The output satisfies `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`. The options are `:lq`, `:lqpos` and `:svd`. The former two require `0 == atol == rtol`. The default is `:lqpos` if `atol == rtol == 0`, and `:svd` otherwise. """ -function right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return right_null(A, blocks(biperm)...; kwargs...) -end -function right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - return right_null(A, blocks(biperm)...; kwargs...) -end function right_null( A::AbstractArray, - codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + codomain_length::Val, domain_length::Val; kwargs..., ) - A_mat = matricize(A, codomain_perm, domain_perm) + A_mat = matricize(A, codomain_length, domain_length) Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...) - biperm = permmortar((codomain_perm, domain_perm)) + biperm = blockedtrivialperm((codomain_length, domain_length)) axes_domain = last(blocks((axes(A)[biperm]))) axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain)) return unmatricize(Nᴴ, axes_Nᴴ) diff --git a/src/matricize.jl b/src/matricize.jl index 9f395e1..625b508 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -39,15 +39,34 @@ function fuseaxes( end # Inner version takes a list of sub-permutations, overload this one if needed. +# TODO: Remove _permutedims once support for Julia 1.10 is dropped +# define permutedims with a BlockedPermuation. Default is to flatten it. +# TODO: Deprecate `permuteblockeddims` in favor of `bipermutedims`. +# Keeping it here for backwards compatibility. +function bipermutedims(a::AbstractArray, perm1, perm2) + return permuteblockeddims(a, perm1, perm2) +end +function bipermutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2) + return permuteblockeddims!(a_dest, a_src, perm1, perm2) +end +function bipermutedims(a::AbstractArray, biperm::AbstractBlockPermutation{2}) + return permuteblockeddims(a, biperm) +end +function bipermutedims!( + a_dest::AbstractArray, a_src::AbstractArray, biperm::AbstractBlockPermutation{2} + ) + return permuteblockeddims!(a_dest, a_src, biperm) +end + +# Older interface. +# TODO: Deprecate in favor of `bipermutedims` (or decide if we want to keep it +# in case there are applications of more general partitionings). function permuteblockeddims(a::AbstractArray, perm1, perm2) return _permutedims(a, (perm1..., perm2...)) end function permuteblockeddims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2) return _permutedims!(a_dest, a_src, (perm1..., perm2...)) end - -# TODO remove _permutedims once support for Julia 1.10 is dropped -# define permutedims with a BlockedPermuation. Default is to flatten it. function permuteblockeddims(a::AbstractArray, biperm::AbstractBlockPermutation{2}) return permuteblockeddims(a, blocks(biperm)...) end @@ -87,7 +106,7 @@ function matricize( ) where {N1, N2} ndims(a) == length(permblock1) + length(permblock2) || throw(ArgumentError("Invalid bipermutation")) - a_perm = permuteblockeddims(a, permblock1, permblock2) + a_perm = bipermutedims(a, permblock1, permblock2) return matricize(style, a_perm, Val(length(permblock1)), Val(length(permblock2))) end @@ -179,7 +198,7 @@ function unmatricize( blocked_axes = axes_dest[invbiperm] a12 = unmatricize(style, m, blocked_axes) biperm_dest = biperm(invperm(invbiperm), length_codomain(axes_dest)) - return permuteblockeddims(a12, biperm_dest) + return bipermutedims(a12, biperm_dest) end function unmatricize(m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2}) @@ -187,7 +206,7 @@ function unmatricize(m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermu end function unmatricize( style::FusionStyle, m::AbstractMatrix, axes_dest, - invbiperm::AbstractBlockPermutation{2} + invbiperm::AbstractBlockPermutation{2}, ) return unmatricize(style, m, axes_dest, blocks(invbiperm)...) end @@ -208,7 +227,7 @@ function unmatricize!( blocked_axes = axes(a_dest)[invbiperm] a_perm = unmatricize(style, m, blocked_axes) biperm_dest = biperm(invperm(invbiperm), length_codomain(axes(a_dest))) - return permuteblockeddims!(a_dest, a_perm, biperm_dest) + return bipermutedims!(a_dest, a_perm, biperm_dest) end function unmatricize!( diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl index 8838978..883c5be 100644 --- a/src/matrixfunctions.jl +++ b/src/matrixfunctions.jl @@ -35,14 +35,22 @@ for f in MATRIX_FUNCTIONS @eval begin function $f( a::AbstractArray, - codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + codomain_length::Val, domain_length::Val; kwargs..., ) - a_mat = matricize(a, codomain_perm, domain_perm) + a_mat = matricize(a, codomain_length, domain_length) fa_mat = Base.$f(a_mat; kwargs...) - biperm = permmortar((codomain_perm, domain_perm)) + biperm = blockedtrivialperm((codomain_length, domain_length)) return unmatricize(fa_mat, axes(a)[biperm]) end + function $f( + a::AbstractArray, + codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + kwargs..., + ) + a_perm = bipermutedims(a, codomain_perm, domain_perm) + return $f(a_perm, Val(length(codomain_perm)), Val(length(domain_perm)); kwargs...) + end function $f(a::AbstractArray, labels_a, labels_codomain, labels_domain; kwargs...) biperm = blockedperm_indexin(Tuple.((labels_a, labels_codomain, labels_domain))...) return $f(a, blocks(biperm)...; kwargs...) diff --git a/test/test_basics.jl b/test/test_basics.jl index ea7be8d..2b408ae 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -14,6 +14,8 @@ using TensorAlgebra: length_codomain, length_domain, matricize, + bipermutedims, + bipermutedims!, permuteblockeddims, permuteblockeddims!, tuplemortar, @@ -33,15 +35,29 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test length_domain(bt) == 1 end - @testset "permuteblockeddims (eltype=$elt)" for elt in elts - a = randn(elt, 2, 3, 4, 5) - a_perm = permuteblockeddims(a, blockedpermvcat((3, 1), (2, 4))) - @test a_perm == permutedims(a, (3, 1, 2, 4)) - - a = randn(elt, 2, 3, 4, 5) - a_perm = Array{elt}(undef, (4, 2, 3, 5)) - permuteblockeddims!(a_perm, a, blockedpermvcat((3, 1), (2, 4))) - @test a_perm == permutedims(a, (3, 1, 2, 4)) + @testset "bipermutedims/permuteblockeddims (eltype=$elt)" for f in + (:bipermutedims, :permuteblockeddims), + elt in elts + f! = Symbol(f, :!) + @eval begin + a = randn($elt, 2, 3, 4, 5) + a_perm = $f(a, blockedpermvcat((3, 1), (2, 4))) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + + a = randn($elt, 2, 3, 4, 5) + a_perm = $f(a, (3, 1), (2, 4)) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + + a = randn($elt, 2, 3, 4, 5) + a_perm = Array{$elt}(undef, (4, 2, 3, 5)) + $f!(a_perm, a, blockedpermvcat((3, 1), (2, 4))) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + + a = randn($elt, 2, 3, 4, 5) + a_perm = Array{$elt}(undef, (4, 2, 3, 5)) + $f!(a_perm, a, (3, 1), (2, 4)) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + end end @testset "matricize (eltype=$elt)" for elt in elts a = randn(elt, 2, 3, 4, 5) diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 2c00b10..9292f90 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -39,6 +39,9 @@ elts = (Float64, ComplexF64) Q, R = qr(A, (2, 1), (4, 3); full = true) @test A ≈ contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) + + Q, R = qr(A, Val(2), Val(2); full = true) + @test A ≈ contract((:a, :b, :c, :d), Q, (:a, :b, :q), R, (:q, :c, :d)) end @testset "Compact QR ($T)" for T in elts @@ -145,7 +148,7 @@ end labels_Vᴴ = (:d, :c) Acopy = deepcopy(A) - U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = true) + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = Val(true)) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) @@ -153,18 +156,18 @@ end @test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary @test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # V is unitary - U, S, Vᴴ = svd(A, (2, 1), (4, 3); full = true) + U, S, Vᴴ = svd(A, (2, 1), (4, 3); full = Val(true)) US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) @test A ≈ contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) - U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full = true) + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full = Val(true)) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) @test A ≈ A′ @test size(Vᴴ, 1) == 1 - U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full = true) + U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full = Val(true)) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (:u,), S, (:u, :v)) A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) @@ -179,7 +182,7 @@ end labels_Vᴴ = (:d, :c) Acopy = deepcopy(A) - U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = false) + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = Val(false)) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_Vᴴ...)) @@ -190,14 +193,14 @@ end Svals = @constinferred svdvals(A, labels_A, labels_U, labels_Vᴴ) @test Svals ≈ diag(S) - U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full = false) + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full = Val(false)) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) @test A ≈ A′ @test size(U, ndims(U)) == 1 == size(Vᴴ, 1) - U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full = false) + U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full = Val(false)) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (:u,), S, (:u, :v)) A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) diff --git a/test/test_matrixfunctions.jl b/test/test_matrixfunctions.jl index ff3c288..3e68c4c 100644 --- a/test/test_matrixfunctions.jl +++ b/test/test_matrixfunctions.jl @@ -14,9 +14,12 @@ using Test: @test, @testset TensorAlgebra.$f(a, (3, 2), (4, 1)), TensorAlgebra.$f(a, biperm((3, 2, 4, 1), Val(2))), ) - fa′ = reshape($f(reshape(permutedims(a, (3, 2, 4, 1)), (4, 4))), (2, 2, 2, 2)) + local fa′ = reshape($f(reshape(permutedims(a, (3, 2, 4, 1)), (4, 4))), (2, 2, 2, 2)) @test fa ≈ fa′ end + fa = TensorAlgebra.$f(a, Val(2), Val(2)) + fa′ = reshape($f(reshape(a, (4, 4))), (2, 2, 2, 2)) + @test fa ≈ fa′ end end end