diff --git a/Project.toml b/Project.toml index 9fa42f7..288b1e9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.2.10" +version = "0.3.0" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -19,7 +19,7 @@ BlockArrays = "1.5.0" EllipsisNotation = "1.8.0" LinearAlgebra = "1.10" MatrixAlgebraKit = "0.1.1" -TensorProducts = "0.1.0" +TensorProducts = "0.1.5" TupleTools = "1.6.0" TypeParameterAccessors = "0.2.1, 0.3" julia = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index c029cf2..661cea0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,4 +6,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [compat] Documenter = "1.8.1" Literate = "2.20.1" -TensorAlgebra = "0.2.0" +TensorAlgebra = "0.3.0" diff --git a/examples/Project.toml b/examples/Project.toml index b12e0b4..0c257a8 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,4 +2,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [compat] -TensorAlgebra = "0.2.0" +TensorAlgebra = "0.3.0" diff --git a/src/BaseExtensions/permutedims.jl b/src/BaseExtensions/permutedims.jl index c80e07d..19f8fd7 100644 --- a/src/BaseExtensions/permutedims.jl +++ b/src/BaseExtensions/permutedims.jl @@ -1,5 +1,6 @@ # Workaround for https://github.com/JuliaLang/julia/issues/52615. # Fixed by https://github.com/JuliaLang/julia/pull/52623. +# TODO remove once support for Julia 1.10 is dropped function _permutedims!( a_dest::AbstractArray{<:Any,N}, a_src::AbstractArray{<:Any,N}, perm::Tuple{Vararg{Int,N}} ) where {N} diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index 1bc23e4..ee32607 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -22,8 +22,7 @@ include("MatrixAlgebra.jl") include("blockedtuple.jl") include("blockedpermutation.jl") include("BaseExtensions/BaseExtensions.jl") -include("fusedims.jl") -include("splitdims.jl") +include("matricize.jl") include("contract/contract.jl") include("contract/output_labels.jl") include("contract/blockedperms.jl") diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index 2d29aea..c383587 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -45,6 +45,17 @@ function Base.invperm(bp::AbstractBlockPermutation) return blockedperm(invperm(Tuple(bp)), Val(blocklengths(bp))) end +# interface + +# Bipartition a vector according to the +# bipartitioned permutation. +# Like `Base.permute!` block out-of-place and blocked. +function blockpermute(v, blockedperm::AbstractBlockPermutation) + return tuplemortar(map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm))) +end + +Base.getindex(v, perm::AbstractBlockPermutation) = blockpermute(v, perm) + # # Constructors # @@ -53,13 +64,6 @@ function blockedperm(bt::AbstractBlockTuple) return permmortar(blocks(bt)) end -# Bipartition a vector according to the -# bipartitioned permutation. -# Like `Base.permute!` block out-of-place and blocked. -function blockpermute(v, blockedperm::AbstractBlockPermutation) - return map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm)) -end - # blockedpermvcat((4, 3), (2, 1)) function blockedpermvcat( permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing diff --git a/src/contract/allocate_output.jl b/src/contract/allocate_output.jl index 610ff87..3fa1c02 100644 --- a/src/contract/allocate_output.jl +++ b/src/contract/allocate_output.jl @@ -4,137 +4,28 @@ using Base.PermutedDimsArrays: genperm # i.e. `ContractAdd`? function output_axes( ::typeof(contract), - biperm_dest::BlockedPermutation{2}, + biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, - biperm1::BlockedPermutation{2}, + biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, - biperm2::BlockedPermutation{2}, + biperm2::AbstractBlockPermutation{2}, α::Number=one(Bool), ) - axes_codomain, axes_contracted = blockpermute(axes(a1), biperm1) - axes_contracted2, axes_domain = blockpermute(axes(a2), biperm2) + axes_codomain, axes_contracted = blocks(axes(a1)[biperm1]) + axes_contracted2, axes_domain = blocks(axes(a2)[biperm2]) @assert axes_contracted == axes_contracted2 return genperm((axes_codomain..., axes_domain...), invperm(Tuple(biperm_dest))) end -# Inner-product contraction. -# TODO: Use `ArrayLayouts`-like `MulAdd` object, -# i.e. `ContractAdd`? -function output_axes( - ::typeof(contract), - perm_dest::BlockedPermutation{0}, - a1::AbstractArray, - perm1::BlockedPermutation{1}, - a2::AbstractArray, - perm2::BlockedPermutation{1}, - α::Number=one(Bool), -) - axes_contracted = blockpermute(axes(a1), perm1) - axes_contracted′ = blockpermute(axes(a2), perm2) - @assert axes_contracted == axes_contracted′ - return () -end - -# Vec-mat. -function output_axes( - ::typeof(contract), - perm_dest::BlockedPermutation{1}, - a1::AbstractArray, - perm1::BlockedPermutation{1}, - a2::AbstractArray, - biperm2::BlockedPermutation{2}, - α::Number=one(Bool), -) - (axes_contracted,) = blockpermute(axes(a1), perm1) - axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2) - @assert axes_contracted == axes_contracted′ - return genperm((axes_dest...,), invperm(Tuple(perm_dest))) -end - -# Mat-vec. -function output_axes( - ::typeof(contract), - perm_dest::BlockedPermutation{1}, - a1::AbstractArray, - perm1::BlockedPermutation{2}, - a2::AbstractArray, - biperm2::BlockedPermutation{1}, - α::Number=one(Bool), -) - axes_dest, axes_contracted = blockpermute(axes(a1), perm1) - (axes_contracted′,) = blockpermute(axes(a2), biperm2) - @assert axes_contracted == axes_contracted′ - return genperm((axes_dest...,), invperm(Tuple(perm_dest))) -end - -# Outer product. -function output_axes( - ::typeof(contract), - biperm_dest::BlockedPermutation{2}, - a1::AbstractArray, - perm1::BlockedPermutation{1}, - a2::AbstractArray, - perm2::BlockedPermutation{1}, - α::Number=one(Bool), -) - @assert istrivialperm(Tuple(perm1)) - @assert istrivialperm(Tuple(perm2)) - axes_dest = (axes(a1)..., axes(a2)...) - return genperm(axes_dest, invperm(Tuple(biperm_dest))) -end - -# Array-scalar contraction. -function output_axes( - ::typeof(contract), - perm_dest::BlockedPermutation{1}, - a1::AbstractArray, - perm1::BlockedPermutation{1}, - a2::AbstractArray, - perm2::BlockedPermutation{0}, - α::Number=one(Bool), -) - @assert istrivialperm(Tuple(perm1)) - axes_dest = axes(a1) - return genperm(axes_dest, invperm(Tuple(perm_dest))) -end - -# Scalar-array contraction. -function output_axes( - ::typeof(contract), - perm_dest::BlockedPermutation{1}, - a1::AbstractArray, - perm1::BlockedPermutation{0}, - a2::AbstractArray, - perm2::BlockedPermutation{1}, - α::Number=one(Bool), -) - @assert istrivialperm(Tuple(perm2)) - axes_dest = axes(a2) - return genperm(axes_dest, invperm(Tuple(perm_dest))) -end - -# Scalar-scalar contraction. -function output_axes( - ::typeof(contract), - perm_dest::BlockedPermutation{0}, - a1::AbstractArray, - perm1::BlockedPermutation{0}, - a2::AbstractArray, - perm2::BlockedPermutation{0}, - α::Number=one(Bool), -) - return () -end - # TODO: Use `ArrayLayouts`-like `MulAdd` object, # i.e. `ContractAdd`? function allocate_output( ::typeof(contract), - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation, α::Number=one(Bool), ) axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α) diff --git a/src/contract/blockedperms.jl b/src/contract/blockedperms.jl index 6f33134..a41033a 100644 --- a/src/contract/blockedperms.jl +++ b/src/contract/blockedperms.jl @@ -22,10 +22,10 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) perm_domain2 = BaseExtensions.indexin(domain, dimnames2) permblocks_dest = (perm_codomain_dest, perm_domain_dest) - biperm_dest = blockedpermvcat(filter(!isempty, permblocks_dest)...) + biperm_dest = blockedpermvcat(permblocks_dest...) permblocks1 = (perm_codomain1, perm_domain1) - biperm1 = blockedpermvcat(filter(!isempty, permblocks1)...) + biperm1 = blockedpermvcat(permblocks1...) permblocks2 = (perm_codomain2, perm_domain2) - biperm2 = blockedpermvcat(filter(!isempty, permblocks2)...) + biperm2 = blockedpermvcat(permblocks2...) return biperm_dest, biperm1, biperm2 end diff --git a/src/contract/contract.jl b/src/contract/contract.jl index dae9441..e47cc89 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -13,11 +13,11 @@ default_contract_alg() = Matricize() function contract!( alg::Algorithm, a_dest::AbstractArray, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation, α::Number, β::Number, ) @@ -110,11 +110,11 @@ end function contract( alg::Algorithm, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation, α::Number; kwargs..., ) diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index 1750ae2..98978bd 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -1,103 +1,20 @@ using LinearAlgebra: mul! function contract!( - alg::Matricize, + ::Matricize, a_dest::AbstractArray, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation{2}, α::Number, β::Number, ) - a_dest_mat = fusedims(a_dest, biperm_dest) - a1_mat = fusedims(a1, biperm1) - a2_mat = fusedims(a2, biperm2) - _mul!(a_dest_mat, a1_mat, a2_mat, α, β) - splitdims!(a_dest, a_dest_mat, biperm_dest) - return a_dest -end - -# Matrix multiplication. -function _mul!( - a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, α::Number, β::Number -) - mul!(a_dest, a1, a2, α, β) - return a_dest -end - -# Inner product. -function _mul!( - a_dest::AbstractArray{<:Any,0}, - a1::AbstractVector, - a2::AbstractVector, - α::Number, - β::Number, -) - a_dest[] = transpose(a1) * a2 * α + a_dest[] * β - return a_dest -end - -# Vec-mat. -function _mul!( - a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number -) - mul!(transpose(a_dest), transpose(a1), a2, α, β) - return a_dest -end - -# Mat-vec. -function _mul!( - a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number -) - mul!(a_dest, a1, a2, α, β) - return a_dest -end - -# Outer product. -function _mul!( - a_dest::AbstractMatrix, a1::AbstractVector, a2::AbstractVector, α::Number, β::Number -) - mul!(a_dest, a1, transpose(a2), α, β) - return a_dest -end - -# Array-scalar contraction. -function _mul!( - a_dest::AbstractVector, - a1::AbstractVector, - a2::AbstractArray{<:Any,0}, - α::Number, - β::Number, -) - α′ = a2[] * α - a_dest .= a1 .* α′ .+ a_dest .* β - return a_dest -end - -# Scalar-array contraction. -function _mul!( - a_dest::AbstractVector, - a1::AbstractArray{<:Any,0}, - a2::AbstractVector, - α::Number, - β::Number, -) - # Preserve the ordering in case of non-commutative algebra. - a_dest .= a1[] .* a2 .* α .+ a_dest .* β - return a_dest -end - -# Scalar-scalar contraction. -function _mul!( - a_dest::AbstractArray{<:Any,0}, - a1::AbstractArray{<:Any,0}, - a2::AbstractArray{<:Any,0}, - α::Number, - β::Number, -) - # Preserve the ordering in case of non-commutative algebra. - a_dest[] = a1[] * a2[] * α + a_dest[] * β + a_dest_mat = matricize(a_dest, biperm_dest) + a1_mat = matricize(a1, biperm1) + a2_mat = matricize(a2, biperm2) + mul!(a_dest_mat, a1_mat, a2_mat, α, β) + unmatricize!(a_dest, a_dest_mat, biperm_dest) return a_dest end diff --git a/src/factorizations.jl b/src/factorizations.jl index a2fee2d..6312564 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -9,25 +9,25 @@ for f in ( biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return $f(A, biperm; kwargs...) end - function $f(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) + function $f(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) # factorization X, Y = MatrixAlgebra.$f(A_mat; kwargs...) # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_X = (axes_codomain..., axes(X, 2)) - axes_Y = (axes(Y, 1), axes_domain...) - return splitdims(X, axes_X), splitdims(Y, axes_Y) + axes_codomain, axes_domain = blocks(axes(A)[biperm]) + axes_X = tuplemortar((axes_codomain, (axes(X, 2),))) + axes_Y = tuplemortar(((axes(Y, 1),), axes_domain)) + return unmatricize(X, axes_X), unmatricize(Y, axes_Y) end end end """ qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R - qr(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> Q, R + qr(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Q, R Compute the QR decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -45,7 +45,7 @@ qr """ lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> L, Q - lq(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> L, Q + lq(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> L, Q Compute the LQ decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -63,7 +63,7 @@ lq """ left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P - left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> W, P + left_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> W, P Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -79,7 +79,7 @@ left_polar """ right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W - right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> P, W + right_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> P, W Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -95,7 +95,7 @@ right_polar """ left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C - left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> V, C + left_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> V, C Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -111,7 +111,7 @@ left_orth """ right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V - right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> C, V + right_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> C, V Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -127,7 +127,7 @@ right_orth """ factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y - factorize(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> X, Y + factorize(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y Compute the decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -143,7 +143,7 @@ factorize """ eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D, V - eigen(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> D, V + eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D, V Compute the eigenvalue decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -163,22 +163,22 @@ function eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwarg biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return eigen(A, biperm; kwargs...) end -function eigen(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) +function eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) # factorization D, V = MatrixAlgebra.eigen!(A_mat; kwargs...) # matrix to tensor - axes_codomain, = blockpermute(axes(A), biperm) - axes_V = (axes_codomain..., axes(V, ndims(V))) - return D, splitdims(V, axes_V) + axes_codomain, = blocks(axes(A)[biperm]) + axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),))) + return D, unmatricize(V, axes_V) end """ eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D - eigvals(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> D + eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D Compute the eigenvalues of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -196,14 +196,14 @@ function eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwa biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return eigvals(A, biperm; kwargs...) end -function eigvals(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - A_mat = fusedims(A, biperm) +function eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) + A_mat = matricize(A, biperm) return MatrixAlgebra.eigvals!(A_mat; kwargs...) end """ svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ - svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> U, S, Vᴴ + svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> U, S, Vᴴ Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -222,23 +222,23 @@ function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs. biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return svd(A, biperm; kwargs...) end -function svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) +function svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) # factorization U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...) # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_U = (axes_codomain..., axes(U, 2)) - axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...) - return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ) + axes_codomain, axes_domain = blocks(axes(A)[biperm]) + axes_U = tuplemortar((axes_codomain, (axes(U, 2),))) + axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain)) + return unmatricize(U, axes_U), S, unmatricize(Vᴴ, axes_Vᴴ) end """ svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) -> S - svdvals(A::AbstractArray, biperm::BlockedPermutation{2}) -> S + svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) -> S Compute the singular values of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -250,14 +250,14 @@ function svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return svdvals(A, biperm) end -function svdvals(A::AbstractArray, biperm::BlockedPermutation{2}) - A_mat = fusedims(A, biperm) +function svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) + A_mat = matricize(A, biperm) return MatrixAlgebra.svdvals!(A_mat) end """ left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> N - left_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> N + left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> N Compute the left nullspace of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -276,18 +276,17 @@ function left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; k biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return left_null(A, biperm; kwargs...) end -function left_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - A_mat = fusedims(A, biperm) +function left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) + A_mat = matricize(A, biperm) N = MatrixAlgebraKit.left_null!(A_mat; kwargs...) - axes_codomain, _ = blockpermute(axes(A), biperm) - axes_N = (axes_codomain..., axes(N, 2)) - N_tensor = splitdims(N, axes_N) - return N_tensor + axes_codomain = first(blocks(axes(A)[biperm])) + axes_N = tuplemortar((axes_codomain, (axes(N, 2),))) + return unmatricize(N, axes_N) end """ right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Nᴴ - right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> Nᴴ + right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Nᴴ Compute the right nullspace of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -306,10 +305,10 @@ function right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return right_null(A, biperm; kwargs...) end -function right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) - A_mat = fusedims(A, biperm) +function right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) + A_mat = matricize(A, biperm) Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...) - _, axes_domain = blockpermute(axes(A), biperm) - axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...) - return splitdims(Nᴴ, axes_Nᴴ) + axes_domain = last(blocks((axes(A)[biperm]))) + axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain)) + return unmatricize(Nᴴ, axes_Nᴴ) end diff --git a/src/fusedims.jl b/src/fusedims.jl deleted file mode 100644 index 87bab85..0000000 --- a/src/fusedims.jl +++ /dev/null @@ -1,66 +0,0 @@ -using TensorProducts: ⊗ -using .BaseExtensions: _permutedims, _permutedims! - -abstract type FusionStyle end - -struct ReshapeFusion <: FusionStyle end -struct BlockReshapeFusion <: FusionStyle end -struct SectorFusion <: FusionStyle end - -# Defaults to a simple reshape -combine_fusion_styles(style1::Style, style2::Style) where {Style<:FusionStyle} = Style() -combine_fusion_styles(style1::FusionStyle, style2::FusionStyle) = ReshapeFusion() -combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, styles) -FusionStyle(axis::AbstractUnitRange) = ReshapeFusion() -function FusionStyle(axes::Tuple{Vararg{AbstractUnitRange}}) - return combine_fusion_styles(FusionStyle.(axes)...) -end -FusionStyle(a::AbstractArray) = FusionStyle(axes(a)) - -# Overload this version for most arrays -function fusedims(::ReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...) - return reshape(a, axes) -end - -# Overload this version for most arrays -function fusedims(a::AbstractArray, ax::AbstractUnitRange, axes::AbstractUnitRange...) - return fusedims(FusionStyle(a), a, ax, axes...) -end - -# Overload this version for fusion tensors, array maps, etc. -function fusedims( - a::AbstractArray, - axb::Tuple{Vararg{AbstractUnitRange}}, - axesblocks::Tuple{Vararg{AbstractUnitRange}}..., -) - return fusedims(a, flatten_tuples((axb, axesblocks...))...) -end - -# Fix ambiguity issue -fusedims(a::AbstractArray{<:Any,0}, ::Vararg{Tuple{}}) = a - -function fusedims(a::AbstractArray, permblocks...) - return fusedims(a, blockedpermvcat(permblocks...; length=Val(ndims(a)))) -end - -function fuseaxes( - axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation -) - axesblocks = blockpermute(axes, blockedperm) - return map(block -> ⊗(block...), axesblocks) -end - -function fuseaxes(a::AbstractArray, blockedperm::AbstractBlockPermutation) - return fuseaxes(axes(a), blockedperm) -end - -# Fuse adjacent dimensions -function fusedims(a::AbstractArray, blockedperm::BlockedTrivialPermutation) - axes_fused = fuseaxes(a, blockedperm) - return fusedims(a, axes_fused) -end - -function fusedims(a::AbstractArray, blockedperm::BlockedPermutation) - a_perm = _permutedims(a, Tuple(blockedperm)) - return fusedims(a_perm, trivialperm(blockedperm)) -end diff --git a/src/matricize.jl b/src/matricize.jl new file mode 100644 index 0000000..49bc949 --- /dev/null +++ b/src/matricize.jl @@ -0,0 +1,151 @@ +using LinearAlgebra: Diagonal + +using BlockArrays: AbstractBlockedUnitRange, blockedrange + +using TensorProducts: ⊗ +using .BaseExtensions: _permutedims, _permutedims! + +# ===================================== FusionStyle ====================================== +abstract type FusionStyle end + +struct ReshapeFusion <: FusionStyle end + +FusionStyle(a::AbstractArray) = FusionStyle(a, axes(a)) +function FusionStyle(a::AbstractArray, t::Tuple{Vararg{AbstractUnitRange}}) + return FusionStyle(a, combine_fusion_styles(FusionStyle.(t)...)) +end + +# Defaults to ReshapeFusion, a simple reshape +FusionStyle(::AbstractUnitRange) = ReshapeFusion() +FusionStyle(::AbstractArray, ::ReshapeFusion) = ReshapeFusion() + +combine_fusion_styles() = ReshapeFusion() +combine_fusion_styles(::Style, ::Style) where {Style<:FusionStyle} = Style() +combine_fusion_styles(::FusionStyle, ::FusionStyle) = ReshapeFusion() +combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, styles) + +# ======================================= misc ======================================== +trivial_axis(::Tuple{}) = Base.OneTo(1) +trivial_axis(::Tuple{Vararg{AbstractUnitRange}}) = Base.OneTo(1) +trivial_axis(::Tuple{Vararg{AbstractBlockedUnitRange}}) = blockedrange([1]) + +function fuseaxes( + axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation +) + axesblocks = blocks(axes[blockedperm]) + return map(block -> isempty(block) ? trivial_axis(axes) : ⊗(block...), axesblocks) +end + +# TODO remove _permutedims once support for Julia 1.10 is dropped +# define permutedims with a BlockedPermuation. Default is to flatten it. +function Base.permutedims(a::AbstractArray, biperm::AbstractBlockPermutation) + return _permutedims(a, Tuple(biperm)) +end + +# solve ambiguities +function Base.permutedims(a::StridedArray, biperm::AbstractBlockPermutation) + return _permutedims(a, Tuple(biperm)) +end +function Base.permutedims(a::Diagonal, biperm::AbstractBlockPermutation) + return _permutedims(a, Tuple(biperm)) +end + +function Base.permutedims!( + a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation +) + return _permutedims!(a, b, Tuple(biperm)) +end + +# solve ambiguities +function Base.permutedims!( + a::Array{T,N}, b::StridedArray{T,N}, biperm::AbstractBlockPermutation +) where {T,N} + return _permutedims!(a, b, Tuple(biperm)) +end + +# ===================================== matricize ======================================== +# TBD settle copy/not copy convention +# matrix factorizations assume copy +# maybe: copy=false kwarg + +function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2}) + return matricize(FusionStyle(a), a, biperm) +end + +function matricize( + style::FusionStyle, a::AbstractArray, biperm::AbstractBlockPermutation{2} +) + a_perm = permutedims(a, biperm) + return matricize(style, a_perm, trivialperm(biperm)) +end + +function matricize( + style::FusionStyle, a::AbstractArray, biperm::BlockedTrivialPermutation{2} +) + return throw(MethodError(matricize, Tuple{typeof(style),typeof(a),typeof(biperm)})) +end + +# default is reshape +function matricize(::ReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}) + new_axes = fuseaxes(axes(a), biperm) + return reshape(a, new_axes...) +end + +function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple) + return matricize(a, blockedpermvcat(permblock1, permblock2; length=Val(ndims(a)))) +end + +# ==================================== unmatricize ======================================= +function unmatricize( + m::AbstractMatrix, + axes::Tuple{Vararg{AbstractUnitRange}}, + biperm::AbstractBlockPermutation{2}, +) + return unmatricize(FusionStyle(m), m, axes, biperm) +end + +function unmatricize( + ::FusionStyle, + m::AbstractMatrix, + axes::Tuple{Vararg{AbstractUnitRange}}, + biperm::AbstractBlockPermutation{2}, +) + blocked_axes = axes[biperm] + a_perm = unmatricize(m, blocked_axes) + return permutedims(a_perm, invperm(biperm)) +end + +function unmatricize(::ReshapeFusion, m::AbstractMatrix, axes::AbstractUnitRange...) + return reshape(m, axes...) +end + +function unmatricize( + ::ReshapeFusion, + m::AbstractMatrix, + blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}}, +) + return reshape(m, Tuple(blocked_axes)...) +end + +function unmatricize( + m::AbstractMatrix, blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}} +) + return unmatricize(FusionStyle(m), m, blocked_axes) +end + +function unmatricize( + m::AbstractMatrix, + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, + domain_axes::Tuple{Vararg{AbstractUnitRange}}, +) + blocked_axes = tuplemortar((codomain_axes, domain_axes)) + return unmatricize(m, blocked_axes) +end + +function unmatricize!( + a::AbstractArray, m::AbstractMatrix, biperm::AbstractBlockPermutation{2} +) + blocked_axes = axes(a)[biperm] + a_perm = unmatricize(m, blocked_axes) + return permutedims!(a, a_perm, invperm(biperm)) +end diff --git a/src/splitdims.jl b/src/splitdims.jl deleted file mode 100644 index 0554c61..0000000 --- a/src/splitdims.jl +++ /dev/null @@ -1,68 +0,0 @@ -using .BaseExtensions: _permutedims, _permutedims! - -to_axis(a::AbstractUnitRange) = a -to_axis(n::Integer) = Base.OneTo(n) - -function blockedaxes(a::AbstractArray, sizeblocks::Pair...) - axes_a = axes(a) - axes_split = tuple.(axes(a)) - for (dim, sizeblock) in sizeblocks - # TODO: Handle conversion from length to range! - axes_split = Base.setindex(axes_split, to_axis.(sizeblock), dim) - end - return axes_split -end - -# splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2) -function splitdims(::ReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...) - # TODO: Add `uncanonicalizedims`. - # TODO: Need `length` since `reshape` doesn't accept `axes`, - # maybe make a `reshape_axes` function. - return reshape(a, length.(axes)...) -end - -# splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2) -function splitdims(a::AbstractArray, axes::AbstractUnitRange...) - return splitdims(FusionStyle(a), a, axes...) -end - -# splitdims(randn(4, 4), (1:2, 1:2), (1:2, 1:2)) -function splitdims(a::AbstractArray, axesblocks::Tuple{Vararg{AbstractUnitRange}}...) - # TODO: Add `uncanonicalizedims`. - return splitdims(a, flatten_tuples(axesblocks)...) -end - -# Fix ambiguity issue -splitdims(a::AbstractArray) = a - -# splitdims(randn(4, 4), (2, 2), (2, 2)) -function splitdims(a::AbstractArray, sizeblocks::Tuple{Vararg{Integer}}...) - return splitdims(a, map(x -> Base.OneTo.(x), sizeblocks)...) -end - -# splitdims(randn(4, 4), 2 => (1:2, 1:2)) -function splitdims(a::AbstractArray, sizeblocks::Pair...) - return splitdims(a, blockedaxes(a, sizeblocks...)...) -end - -# TODO: Is this needed? -function splitdims( - a::AbstractArray, - axes_dest::Tuple{Vararg{AbstractUnitRange}}, - blockedperm::BlockedPermutation, -) - # TODO: Pass grouped axes. - a_dest_perm = splitdims(a, axes_dest...) - a_dest = _permutedims(a_dest_perm, invperm(Tuple(blockedperm))) - return a_dest -end - -function splitdims!( - a_dest::AbstractArray, a::AbstractArray, blockedperm::BlockedPermutation -) - axes_dest = map(i -> axes(a_dest, i), Tuple(blockedperm)) - # TODO: Pass grouped axes. - a_dest_perm = splitdims(a, axes_dest...) - _permutedims!(a_dest, a_dest_perm, invperm(Tuple(blockedperm))) - return a_dest -end diff --git a/test/Project.toml b/test/Project.toml index a780342..0223db5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,16 +2,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" -JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" -LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" -SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -21,17 +17,13 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Aqua = "0.8.9" BlockArrays = "1.4.0" EllipsisNotation = "1.8.0" -JLArrays = "0.2.0" -LabelledNumbers = "0.1.1" LinearAlgebra = "<0.0.1, 1" MatrixAlgebraKit = "0.1" -Pkg = "<0.0.1, 1" Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -SymmetrySectors = "0.1" -TensorAlgebra = "0.2.0" +TensorAlgebra = "0.3.0" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 5c6c040..9e2d2ee 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -1,7 +1,9 @@ -using TensorAlgebra: TensorAlgebra -using Aqua: Aqua using Test: @testset +using Aqua: Aqua + +using TensorAlgebra: TensorAlgebra + @testset "Code quality (Aqua.jl)" begin # TODO: fix and re-enable ambiguity checks Aqua.test_all(TensorAlgebra; ambiguities=false, piracies=false) diff --git a/test/test_basics.jl b/test/test_basics.jl index 3bce1b7..fc11630 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,68 +1,111 @@ +using Test: @test, @test_broken, @test_throws, @testset + using EllipsisNotation: var".." -using LinearAlgebra: norm using StableRNGs: StableRNG -using TensorAlgebra: contract, contract!, fusedims, qr, splitdims, svd using TensorOperations: TensorOperations -using Test: @test, @test_broken, @testset + +using TensorAlgebra: + blockedpermvcat, contract, contract!, matricize, tuplemortar, unmatricize, unmatricize! default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra" begin - @testset "fusedims (eltype=$elt)" for elt in elts + @testset "matricize (eltype=$elt)" for elt in elts a = randn(elt, 2, 3, 4, 5) - a_fused = fusedims(a, (1, 2), (3, 4)) + + a_fused = matricize(a, blockedpermvcat((1, 2), (3, 4))) @test eltype(a_fused) === elt @test a_fused ≈ reshape(a, 6, 20) - a_fused = fusedims(a, (3, 1), (2, 4)) + + a_fused = matricize(a, (1, 2), (3, 4)) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(a, 6, 20) + a_fused = matricize(a, (3, 1), (2, 4)) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) - a_fused = fusedims(a, (3, 1, 2), 4) + a_fused = matricize(a, (3, 1, 2), (4,)) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (24, 5)) - a_fused = fusedims(a, .., (3, 1)) + a_fused = matricize(a, (..,), (3, 1)) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (2, 4, 3, 1)), (3, 5, 8)) - a_fused = fusedims(a, (3, 1), ..) + @test a_fused ≈ reshape(permutedims(a, (2, 4, 3, 1)), (15, 8)) + a_fused = matricize(a, (3, 1), (..,)) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 3, 5)) - a_fused = fusedims(a, .., (3, 1), 2) + @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) + + a_fused = matricize(a, (), (..,)) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(a, (1, 120)) + a_fused = matricize(a, (..,), ()) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (4, 3, 1, 2)), (5, 8, 3)) - a_fused = fusedims(a, (3, 1), .., 2) + @test a_fused ≈ reshape(a, (120, 1)) + + @test_throws MethodError matricize(a, (1, 2), (3,), (4,)) + @test_throws MethodError matricize(a, (1, 2, 3, 4)) + + v = ones(elt, 2) + a_fused = matricize(v, (1,), ()) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 4, 2)), (8, 5, 3)) - a_fused = fusedims(a, (3, 1), (..,)) + @test a_fused ≈ ones(elt, 2, 1) + a_fused = matricize(v, (), (1,)) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) + @test a_fused ≈ ones(elt, 1, 2) + + a_fused = matricize(ones(elt), (), ()) + @test eltype(a_fused) === elt + @test a_fused ≈ ones(elt, 1, 1) end - @testset "splitdims (eltype=$elt)" for elt in elts - a = randn(elt, 6, 20) - a_split = splitdims(a, (2, 3), (5, 4)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 5, 4)) - a_split = splitdims(a, (1:2, 1:3), (1:5, 1:4)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 5, 4)) - a_split = splitdims(a, 2 => (5, 4), 1 => (2, 3)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 5, 4)) - a_split = splitdims(a, 2 => (1:5, 1:4), 1 => (1:2, 1:3)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 5, 4)) - a_split = splitdims(a, 2 => (5, 4)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (6, 5, 4)) - a_split = splitdims(a, 2 => (1:5, 1:4)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (6, 5, 4)) - a_split = splitdims(a, 1 => (2, 3)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 20)) - a_split = splitdims(a, 1 => (1:2, 1:3)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 20)) + + @testset "unmatricize (eltype=$elt)" for elt in elts + a0 = randn(elt, 2, 3, 4, 5) + axes0 = axes(a0) + m = reshape(a0, 6, 20) + + a = unmatricize(m, tuplemortar((axes0[1:2], axes0[3:4]))) + @test eltype(a) === elt + @test a ≈ a0 + + a = unmatricize(m, axes0[1:2], axes0[3:4]) + @test eltype(a) === elt + @test a ≈ a0 + + a = unmatricize(m, axes0, blockedpermvcat((1, 2), (3, 4))) + @test eltype(a) === elt + @test a ≈ a0 + + bp = blockedpermvcat((4, 2), (1, 3)) + a = unmatricize(m, map(i -> axes0[i], invperm(Tuple(bp))), bp) + @test eltype(a) === elt + @test a ≈ permutedims(a0, invperm(Tuple(bp))) + + a = similar(a0) + unmatricize!(a, m, blockedpermvcat((1, 2), (3, 4))) + @test a ≈ a0 + + m1 = matricize(a0, bp) + a = unmatricize(m1, axes0, bp) + @test a ≈ a0 + + a1 = permutedims(a0, Tuple(bp)) + a = similar(a1) + unmatricize!(a, m, invperm(bp)) + @test a ≈ a1 + + a = unmatricize(m, (), axes0) + @test eltype(a) === elt + @test a ≈ a0 + + a = unmatricize(m, axes0, ()) + @test eltype(a) === elt + @test a ≈ a0 + + m = randn(elt, 1, 1) + a = unmatricize(m, (), ()) + @test a isa Array{elt,0} + @test a[] == m[1, 1] end + using TensorOperations: TensorOperations @testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts dims = (2, 3, 4, 5, 6, 7, 8, 9, 10) diff --git a/test/test_blockarrays_contract.jl b/test/test_blockarrays_contract.jl index c3b3af8..f337cc7 100644 --- a/test/test_blockarrays_contract.jl +++ b/test/test_blockarrays_contract.jl @@ -1,8 +1,10 @@ -using BlockArrays: Block, BlockArray, BlockedArray, blockedrange, blocksize -using TensorAlgebra: contract using Random: randn! using Test: @test, @test_broken, @testset +using BlockArrays: Block, BlockArray, BlockedArray, blockedrange, blocksize + +using TensorAlgebra: contract + function randn_blockdiagonal(elt::Type, axes::Tuple) a = zeros(elt, axes) blockdiaglength = minimum(blocksize(a)) diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index e9198a9..016011f 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -10,6 +10,7 @@ using TensorAlgebra: BlockedTuple, blockedperm, blockedperm_indexin, + blockpermute, blockedtrivialperm, blockedpermvcat, permmortar, @@ -135,6 +136,16 @@ using TensorAlgebra: p = blockedpermvcat((3, 2), (..,), 1) @test p == blockedpermvcat((3, 2), (), (1,)) + + # blockpermute + t = (1, 2, 3, 4) + pblocks = tuplemortar(((4, 3), (), (1, 2))) + p = blockedperm(pblocks) + @test (@constinferred blockpermute(t, p)) isa BlockedTuple{3,(2, 0, 2),NTuple{4,Int64}} + @test blockpermute(t, p) == pblocks + @test t[p] == pblocks + @test pblocks[p] == tuplemortar(((2, 1), (), (4, 3))) + @test p[p] == tuplemortar(((2, 1), (), (4, 3))) end @testset "BlockedTrivialPermutation" begin diff --git a/test/test_exports.jl b/test/test_exports.jl index 8dba1bf..0c7b00b 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -1,5 +1,7 @@ -using TensorAlgebra: TensorAlgebra using Test: @test, @testset + +using TensorAlgebra: TensorAlgebra + @testset "Test exports" begin exports = [ :TensorAlgebra, diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 0c5ccdf..efb81fc 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,7 +1,10 @@ -using Test: @test, @testset, @inferred +using LinearAlgebra: LinearAlgebra, norm, diag +using Test: @test, @testset + using TestExtras: @constinferred + +using MatrixAlgebraKit: truncrank using TensorAlgebra: - TensorAlgebra, contract, eigen, eigvals, @@ -18,8 +21,6 @@ using TensorAlgebra: right_polar, svd, svdvals -using MatrixAlgebraKit: truncrank -using LinearAlgebra: LinearAlgebra, norm, diag elts = (Float64, ComplexF64) @@ -147,6 +148,20 @@ end @test A ≈ A′ @test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary @test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # V is unitary + + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full=true) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) + @test A ≈ A′ + @test size(Vᴴ, 1) == 1 + + U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full=true) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (:u,), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) + @test A ≈ A′ + @test size(U, 2) == 1 end @testset "Compact SVD ($T)" for T in elts @@ -166,6 +181,20 @@ end Svals = @constinferred svdvals(A, labels_A, labels_U, labels_Vᴴ) @test Svals ≈ diag(S) + + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full=false) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) + @test A ≈ A′ + @test size(U, ndims(U)) == 1 == size(Vᴴ, 1) + + U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full=false) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (:u,), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) + @test A ≈ A′ + @test size(U, 1) == 1 == size(Vᴴ, 1) end @testset "Truncated SVD ($T)" for T in elts diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 7a391c3..7ee3598 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -1,7 +1,8 @@ using LinearAlgebra: I, diag, isposdef -using TensorAlgebra.MatrixAlgebra: MatrixAlgebra using Test: @test, @testset +using TensorAlgebra.MatrixAlgebra: MatrixAlgebra + elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "TensorAlgebra.MatrixAlgebra (elt=$elt)" for elt in elts