From 9cb1e59ba9b15ae9e5615c66b1a365665a89e89a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 3 Apr 2025 10:31:11 -0400 Subject: [PATCH 01/18] WIP reuse BlockPerm{2} --- Project.toml | 2 +- src/contract/allocate_output.jl | 54 ++++++++++----------- src/contract/blockedperms.jl | 6 +-- src/contract/contract.jl | 12 ++--- src/contract/contract_matricize/contract.jl | 8 +-- src/factorizations.jl | 54 +++++++++++---------- 6 files changed, 70 insertions(+), 66 deletions(-) diff --git a/Project.toml b/Project.toml index 5d3d28f..cf887d5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.2.7" +version = "0.2.9" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/contract/allocate_output.jl b/src/contract/allocate_output.jl index 610ff87..32226cb 100644 --- a/src/contract/allocate_output.jl +++ b/src/contract/allocate_output.jl @@ -4,11 +4,11 @@ using Base.PermutedDimsArrays: genperm # i.e. `ContractAdd`? function output_axes( ::typeof(contract), - biperm_dest::BlockedPermutation{2}, + biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, - biperm1::BlockedPermutation{2}, + biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, - biperm2::BlockedPermutation{2}, + biperm2::AbstractBlockPermutation{2}, α::Number=one(Bool), ) axes_codomain, axes_contracted = blockpermute(axes(a1), biperm1) @@ -22,11 +22,11 @@ end # i.e. `ContractAdd`? function output_axes( ::typeof(contract), - perm_dest::BlockedPermutation{0}, + perm_dest::AbstractBlockPermutation{0}, a1::AbstractArray, - perm1::BlockedPermutation{1}, + perm1::AbstractBlockPermutation{1}, a2::AbstractArray, - perm2::BlockedPermutation{1}, + perm2::AbstractBlockPermutation{1}, α::Number=one(Bool), ) axes_contracted = blockpermute(axes(a1), perm1) @@ -38,11 +38,11 @@ end # Vec-mat. function output_axes( ::typeof(contract), - perm_dest::BlockedPermutation{1}, + perm_dest::AbstractBlockPermutation{1}, a1::AbstractArray, - perm1::BlockedPermutation{1}, + perm1::AbstractBlockPermutation{1}, a2::AbstractArray, - biperm2::BlockedPermutation{2}, + biperm2::AbstractBlockPermutation{2}, α::Number=one(Bool), ) (axes_contracted,) = blockpermute(axes(a1), perm1) @@ -54,11 +54,11 @@ end # Mat-vec. function output_axes( ::typeof(contract), - perm_dest::BlockedPermutation{1}, + perm_dest::AbstractBlockPermutation{1}, a1::AbstractArray, - perm1::BlockedPermutation{2}, + perm1::AbstractBlockPermutation{2}, a2::AbstractArray, - biperm2::BlockedPermutation{1}, + biperm2::AbstractBlockPermutation{1}, α::Number=one(Bool), ) axes_dest, axes_contracted = blockpermute(axes(a1), perm1) @@ -70,11 +70,11 @@ end # Outer product. function output_axes( ::typeof(contract), - biperm_dest::BlockedPermutation{2}, + biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, - perm1::BlockedPermutation{1}, + perm1::AbstractBlockPermutation{1}, a2::AbstractArray, - perm2::BlockedPermutation{1}, + perm2::AbstractBlockPermutation{1}, α::Number=one(Bool), ) @assert istrivialperm(Tuple(perm1)) @@ -86,11 +86,11 @@ end # Array-scalar contraction. function output_axes( ::typeof(contract), - perm_dest::BlockedPermutation{1}, + perm_dest::AbstractBlockPermutation{1}, a1::AbstractArray, - perm1::BlockedPermutation{1}, + perm1::AbstractBlockPermutation{1}, a2::AbstractArray, - perm2::BlockedPermutation{0}, + perm2::AbstractBlockPermutation{0}, α::Number=one(Bool), ) @assert istrivialperm(Tuple(perm1)) @@ -101,11 +101,11 @@ end # Scalar-array contraction. function output_axes( ::typeof(contract), - perm_dest::BlockedPermutation{1}, + perm_dest::AbstractBlockPermutation{1}, a1::AbstractArray, - perm1::BlockedPermutation{0}, + perm1::AbstractBlockPermutation{0}, a2::AbstractArray, - perm2::BlockedPermutation{1}, + perm2::AbstractBlockPermutation{1}, α::Number=one(Bool), ) @assert istrivialperm(Tuple(perm2)) @@ -116,11 +116,11 @@ end # Scalar-scalar contraction. function output_axes( ::typeof(contract), - perm_dest::BlockedPermutation{0}, + perm_dest::AbstractBlockPermutation{0}, a1::AbstractArray, - perm1::BlockedPermutation{0}, + perm1::AbstractBlockPermutation{0}, a2::AbstractArray, - perm2::BlockedPermutation{0}, + perm2::AbstractBlockPermutation{0}, α::Number=one(Bool), ) return () @@ -130,11 +130,11 @@ end # i.e. `ContractAdd`? function allocate_output( ::typeof(contract), - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation, α::Number=one(Bool), ) axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α) diff --git a/src/contract/blockedperms.jl b/src/contract/blockedperms.jl index 6f33134..a41033a 100644 --- a/src/contract/blockedperms.jl +++ b/src/contract/blockedperms.jl @@ -22,10 +22,10 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) perm_domain2 = BaseExtensions.indexin(domain, dimnames2) permblocks_dest = (perm_codomain_dest, perm_domain_dest) - biperm_dest = blockedpermvcat(filter(!isempty, permblocks_dest)...) + biperm_dest = blockedpermvcat(permblocks_dest...) permblocks1 = (perm_codomain1, perm_domain1) - biperm1 = blockedpermvcat(filter(!isempty, permblocks1)...) + biperm1 = blockedpermvcat(permblocks1...) permblocks2 = (perm_codomain2, perm_domain2) - biperm2 = blockedpermvcat(filter(!isempty, permblocks2)...) + biperm2 = blockedpermvcat(permblocks2...) return biperm_dest, biperm1, biperm2 end diff --git a/src/contract/contract.jl b/src/contract/contract.jl index dae9441..e47cc89 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -13,11 +13,11 @@ default_contract_alg() = Matricize() function contract!( alg::Algorithm, a_dest::AbstractArray, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation, α::Number, β::Number, ) @@ -110,11 +110,11 @@ end function contract( alg::Algorithm, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation, α::Number; kwargs..., ) diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index 1750ae2..26a5f30 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -1,13 +1,13 @@ using LinearAlgebra: mul! function contract!( - alg::Matricize, + ::Matricize, a_dest::AbstractArray, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation{2}, α::Number, β::Number, ) diff --git a/src/factorizations.jl b/src/factorizations.jl index 4adcadb..02f0e54 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -19,7 +19,7 @@ using LinearAlgebra: LinearAlgebra """ qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R - qr(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> Q, R + qr(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Q, R Compute the QR decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -37,7 +37,9 @@ function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs.. biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return qr(A, biperm; kwargs...) end -function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...) +function qr( + A::AbstractArray, biperm::AbstractBlockPermutation{2}; full::Bool=false, kwargs... +) # tensor to matrix A_mat = fusedims(A, biperm) @@ -46,14 +48,14 @@ function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, k # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_Q = (axes_codomain..., axes(Q, 2)) - axes_R = (axes(R, 1), axes_domain...) + axes_Q = tuplemortar((axes_codomain, (axes(q_matricized, 2),))) + axes_R = tuplemortar(((axes(r_matricized, 1),), axes_domain)) return splitdims(Q, axes_Q), splitdims(R, axes_R) end """ lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> L, Q - lq(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> L, Q + lq(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> L, Q Compute the LQ decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -71,7 +73,9 @@ function lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs.. biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return lq(A, biperm; kwargs...) end -function lq(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...) +function lq( + A::AbstractArray, biperm::AbstractBlockPermutation{2}; full::Bool=false, kwargs... +) # tensor to matrix A_mat = fusedims(A, biperm) @@ -80,14 +84,14 @@ function lq(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, k # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_L = (axes_codomain..., axes(L, ndims(L))) - axes_Q = (axes(Q, 1), axes_domain...) + axes_L = tuplemortar((axes_codomain, (axes(L, ndims(L)),))) + axes_Q = tuplemortar(((axes(Q, 1),), axes_domain)) return splitdims(L, axes_L), splitdims(Q, axes_Q) end """ eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D, V - eigen(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> D, V + eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D, V Compute the eigenvalue decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -109,7 +113,7 @@ function eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwarg end function eigen( A::AbstractArray, - biperm::BlockedPermutation{2}; + biperm::AbstractBlockPermutation{2}; trunc=nothing, ishermitian=nothing, kwargs..., @@ -128,13 +132,13 @@ function eigen( # matrix to tensor axes_codomain, = blockpermute(axes(A), biperm) - axes_V = (axes_codomain..., axes(V, ndims(V))) + axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),))) return D, splitdims(V, axes_V) end """ eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D - eigvals(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> D + eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D Compute the eigenvalues of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -153,7 +157,7 @@ function eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwa return eigvals(A, biperm; kwargs...) end function eigvals( - A::AbstractArray, biperm::BlockedPermutation{2}; ishermitian=nothing, kwargs... + A::AbstractArray, biperm::AbstractBlockPermutation{2}; ishermitian=nothing, kwargs... ) A_mat = fusedims(A, biperm) ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat) @@ -163,7 +167,7 @@ end # TODO: separate out the algorithm selection step from the implementation """ svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ - svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> U, S, Vᴴ + svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> U, S, Vᴴ Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -184,7 +188,7 @@ function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs. end function svd( A::AbstractArray, - biperm::BlockedPermutation{2}; + biperm::AbstractBlockPermutation{2}; full::Bool=false, trunc=nothing, kwargs..., @@ -202,14 +206,14 @@ function svd( # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_U = (axes_codomain..., axes(U, 2)) - axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...) + axes_U = tuplemortar((axes_codomain, (axes(U, 2),))) + axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain)) return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ) end """ svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) -> S - svdvals(A::AbstractArray, biperm::BlockedPermutation{2}) -> S + svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) -> S Compute the singular values of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -221,14 +225,14 @@ function svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return svdvals(A, biperm) end -function svdvals(A::AbstractArray, biperm::BlockedPermutation{2}) +function svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) A_mat = fusedims(A, biperm) return svd_vals!(A_mat) end """ left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> N - left_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> N + left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> N Compute the left nullspace of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -247,18 +251,18 @@ function left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; k biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return left_null(A, biperm; kwargs...) end -function left_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) +function left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) N = left_null!(A_mat; kwargs...) axes_codomain, _ = blockpermute(axes(A), biperm) - axes_N = (axes_codomain..., axes(N, 2)) + axes_N = tuplemortar((axes_codomain, (axes(N, 2),))) N_tensor = splitdims(N, axes_N) return N_tensor end """ right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Nᴴ - right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> Nᴴ + right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Nᴴ Compute the right nullspace of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -277,10 +281,10 @@ function right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return right_null(A, biperm; kwargs...) end -function right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) +function right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) Nᴴ = right_null!(A_mat; kwargs...) _, axes_domain = blockpermute(axes(A), biperm) - axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...) + axes_Nᴴ = tuplemortar((axes(Nᴴ, 1), (axes_domain,))) return splitdims(Nᴴ, axes_Nᴴ) end From 86164bd125a660369928a3224d769d62e83b8449 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 3 Apr 2025 12:31:32 -0400 Subject: [PATCH 02/18] working matricize --- Project.toml | 2 +- src/BaseExtensions/BaseExtensions.jl | 1 - src/BaseExtensions/permutedims.jl | 20 ----- src/TensorAlgebra.jl | 3 +- src/contract/contract_matricize/contract.jl | 93 ++------------------- test/Project.toml | 2 +- test/test_basics.jl | 60 ++++++++----- 7 files changed, 49 insertions(+), 132 deletions(-) delete mode 100644 src/BaseExtensions/permutedims.jl diff --git a/Project.toml b/Project.toml index cf887d5..7de5232 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.2.9" +version = "0.3.0" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/BaseExtensions/BaseExtensions.jl b/src/BaseExtensions/BaseExtensions.jl index c994fd8..5109cbb 100644 --- a/src/BaseExtensions/BaseExtensions.jl +++ b/src/BaseExtensions/BaseExtensions.jl @@ -1,4 +1,3 @@ module BaseExtensions include("indexin.jl") -include("permutedims.jl") end diff --git a/src/BaseExtensions/permutedims.jl b/src/BaseExtensions/permutedims.jl deleted file mode 100644 index c80e07d..0000000 --- a/src/BaseExtensions/permutedims.jl +++ /dev/null @@ -1,20 +0,0 @@ -# Workaround for https://github.com/JuliaLang/julia/issues/52615. -# Fixed by https://github.com/JuliaLang/julia/pull/52623. -function _permutedims!( - a_dest::AbstractArray{<:Any,N}, a_src::AbstractArray{<:Any,N}, perm::Tuple{Vararg{Int,N}} -) where {N} - permutedims!(a_dest, a_src, perm) - return a_dest -end -function _permutedims!( - a_dest::AbstractArray{<:Any,0}, a_src::AbstractArray{<:Any,0}, perm::Tuple{} -) - a_dest[] = a_src[] - return a_dest -end -function _permutedims(a::AbstractArray{<:Any,N}, perm::Tuple{Vararg{Int,N}}) where {N} - return permutedims(a, perm) -end -function _permutedims(a::AbstractArray{<:Any,0}, perm::Tuple{}) - return copy(a) -end diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index 591133b..3037512 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -5,8 +5,7 @@ export contract, contract!, eigen, eigvals, lq, left_null, qr, right_null, svd, include("blockedtuple.jl") include("blockedpermutation.jl") include("BaseExtensions/BaseExtensions.jl") -include("fusedims.jl") -include("splitdims.jl") +include("matricize.jl") include("contract/contract.jl") include("contract/output_labels.jl") include("contract/blockedperms.jl") diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index 26a5f30..98978bd 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -11,93 +11,10 @@ function contract!( α::Number, β::Number, ) - a_dest_mat = fusedims(a_dest, biperm_dest) - a1_mat = fusedims(a1, biperm1) - a2_mat = fusedims(a2, biperm2) - _mul!(a_dest_mat, a1_mat, a2_mat, α, β) - splitdims!(a_dest, a_dest_mat, biperm_dest) - return a_dest -end - -# Matrix multiplication. -function _mul!( - a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, α::Number, β::Number -) - mul!(a_dest, a1, a2, α, β) - return a_dest -end - -# Inner product. -function _mul!( - a_dest::AbstractArray{<:Any,0}, - a1::AbstractVector, - a2::AbstractVector, - α::Number, - β::Number, -) - a_dest[] = transpose(a1) * a2 * α + a_dest[] * β - return a_dest -end - -# Vec-mat. -function _mul!( - a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number -) - mul!(transpose(a_dest), transpose(a1), a2, α, β) - return a_dest -end - -# Mat-vec. -function _mul!( - a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number -) - mul!(a_dest, a1, a2, α, β) - return a_dest -end - -# Outer product. -function _mul!( - a_dest::AbstractMatrix, a1::AbstractVector, a2::AbstractVector, α::Number, β::Number -) - mul!(a_dest, a1, transpose(a2), α, β) - return a_dest -end - -# Array-scalar contraction. -function _mul!( - a_dest::AbstractVector, - a1::AbstractVector, - a2::AbstractArray{<:Any,0}, - α::Number, - β::Number, -) - α′ = a2[] * α - a_dest .= a1 .* α′ .+ a_dest .* β - return a_dest -end - -# Scalar-array contraction. -function _mul!( - a_dest::AbstractVector, - a1::AbstractArray{<:Any,0}, - a2::AbstractVector, - α::Number, - β::Number, -) - # Preserve the ordering in case of non-commutative algebra. - a_dest .= a1[] .* a2 .* α .+ a_dest .* β - return a_dest -end - -# Scalar-scalar contraction. -function _mul!( - a_dest::AbstractArray{<:Any,0}, - a1::AbstractArray{<:Any,0}, - a2::AbstractArray{<:Any,0}, - α::Number, - β::Number, -) - # Preserve the ordering in case of non-commutative algebra. - a_dest[] = a1[] * a2[] * α + a_dest[] * β + a_dest_mat = matricize(a_dest, biperm_dest) + a1_mat = matricize(a1, biperm1) + a2_mat = matricize(a2, biperm2) + mul!(a_dest_mat, a1_mat, a2_mat, α, β) + unmatricize!(a_dest, a_dest_mat, biperm_dest) return a_dest end diff --git a/test/Project.toml b/test/Project.toml index a780342..c2e3656 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,7 +31,7 @@ SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" SymmetrySectors = "0.1" -TensorAlgebra = "0.2.0" +TensorAlgebra = "0.3.0" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" diff --git a/test/test_basics.jl b/test/test_basics.jl index 3bce1b7..2fba24f 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,47 +1,69 @@ using EllipsisNotation: var".." using LinearAlgebra: norm using StableRNGs: StableRNG -using TensorAlgebra: contract, contract!, fusedims, qr, splitdims, svd +using TensorAlgebra: contract, contract!, matricize, qr, svd, tuplemortar, unmatricize using TensorOperations: TensorOperations -using Test: @test, @test_broken, @testset +using Test: @test, @test_broken, @test_throws, @testset default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra" begin - @testset "fusedims (eltype=$elt)" for elt in elts + @testset "matricize (eltype=$elt)" for elt in elts a = randn(elt, 2, 3, 4, 5) - a_fused = fusedims(a, (1, 2), (3, 4)) + + a_fused = matricize(a, blockedpermvcat((1, 2), (3, 4))) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(a, 6, 20) + + a_fused = matricize(a, tuplemortar(((1, 2), (3, 4)))) @test eltype(a_fused) === elt @test a_fused ≈ reshape(a, 6, 20) - a_fused = fusedims(a, (3, 1), (2, 4)) + + a_fused = matricize(a, (1, 2), (3, 4)) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(a, 6, 20) + a_fused = matricize(a, (3, 1), (2, 4)) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) - a_fused = fusedims(a, (3, 1, 2), 4) + a_fused = matricize(a, (3, 1, 2), 4) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (24, 5)) - a_fused = fusedims(a, .., (3, 1)) + a_fused = matricize(a, (..,), (3, 1)) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (2, 4, 3, 1)), (3, 5, 8)) - a_fused = fusedims(a, (3, 1), ..) + @test a_fused ≈ reshape(permutedims(a, (2, 4, 3, 1)), (15, 8)) + a_fused = matricize(a, (3, 1), (..,)) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 3, 5)) - a_fused = fusedims(a, .., (3, 1), 2) + @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) + + a_fused = matricize(a, (), (..,)) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (4, 3, 1, 2)), (5, 8, 3)) - a_fused = fusedims(a, (3, 1), .., 2) + @test a_fused ≈ reshape(a, (1, 120)) + a_fused = matricize(a, (..,), ()) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 4, 2)), (8, 5, 3)) - a_fused = fusedims(a, (3, 1), (..,)) + @test a_fused ≈ reshape(a, (120, 1)) + + @test_throws ArgumentError matricize(a, (1, 2), (3,), (4,)) + @test_throws ArgumentError matricize(a, (1, 2, 3, 4)) + + v = ones(elt, 2) + a_fused = matricize(v, (1,), ()) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) + @test a_fused ≈ ones(elt, 2, 1) + a_fused = matricize(v, (), (1,)) + @test eltype(a_fused) === elt + @test a_fused ≈ ones(elt, 1, 2) + + a_fused = matricize(ones(elt), (), ()) + @test eltype(a_fused) === elt + @test a_fused ≈ ones(elt, 1, 1) end - @testset "splitdims (eltype=$elt)" for elt in elts + @testset "unmatricize (eltype=$elt)" for elt in elts a = randn(elt, 6, 20) - a_split = splitdims(a, (2, 3), (5, 4)) + a_split = unmatricize(a, (2, 3), (5, 4)) @test eltype(a_split) === elt @test a_split ≈ reshape(a, (2, 3, 5, 4)) - a_split = splitdims(a, (1:2, 1:3), (1:5, 1:4)) + a_split = unmatricize(a, (1:2, 1:3), (1:5, 1:4)) @test eltype(a_split) === elt @test a_split ≈ reshape(a, (2, 3, 5, 4)) a_split = splitdims(a, 2 => (5, 4), 1 => (2, 3)) From 42b5fe0f40620b27e64b5207177880ae1334bc70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 3 Apr 2025 15:42:01 -0400 Subject: [PATCH 03/18] WIP, check invperm --- src/fusedims.jl | 66 ------------------------ src/matricize.jl | 122 ++++++++++++++++++++++++++++++++++++++++++++ src/splitdims.jl | 68 ------------------------ test/test_basics.jl | 63 +++++++++++++---------- 4 files changed, 159 insertions(+), 160 deletions(-) delete mode 100644 src/fusedims.jl create mode 100644 src/matricize.jl delete mode 100644 src/splitdims.jl diff --git a/src/fusedims.jl b/src/fusedims.jl deleted file mode 100644 index 87bab85..0000000 --- a/src/fusedims.jl +++ /dev/null @@ -1,66 +0,0 @@ -using TensorProducts: ⊗ -using .BaseExtensions: _permutedims, _permutedims! - -abstract type FusionStyle end - -struct ReshapeFusion <: FusionStyle end -struct BlockReshapeFusion <: FusionStyle end -struct SectorFusion <: FusionStyle end - -# Defaults to a simple reshape -combine_fusion_styles(style1::Style, style2::Style) where {Style<:FusionStyle} = Style() -combine_fusion_styles(style1::FusionStyle, style2::FusionStyle) = ReshapeFusion() -combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, styles) -FusionStyle(axis::AbstractUnitRange) = ReshapeFusion() -function FusionStyle(axes::Tuple{Vararg{AbstractUnitRange}}) - return combine_fusion_styles(FusionStyle.(axes)...) -end -FusionStyle(a::AbstractArray) = FusionStyle(axes(a)) - -# Overload this version for most arrays -function fusedims(::ReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...) - return reshape(a, axes) -end - -# Overload this version for most arrays -function fusedims(a::AbstractArray, ax::AbstractUnitRange, axes::AbstractUnitRange...) - return fusedims(FusionStyle(a), a, ax, axes...) -end - -# Overload this version for fusion tensors, array maps, etc. -function fusedims( - a::AbstractArray, - axb::Tuple{Vararg{AbstractUnitRange}}, - axesblocks::Tuple{Vararg{AbstractUnitRange}}..., -) - return fusedims(a, flatten_tuples((axb, axesblocks...))...) -end - -# Fix ambiguity issue -fusedims(a::AbstractArray{<:Any,0}, ::Vararg{Tuple{}}) = a - -function fusedims(a::AbstractArray, permblocks...) - return fusedims(a, blockedpermvcat(permblocks...; length=Val(ndims(a)))) -end - -function fuseaxes( - axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation -) - axesblocks = blockpermute(axes, blockedperm) - return map(block -> ⊗(block...), axesblocks) -end - -function fuseaxes(a::AbstractArray, blockedperm::AbstractBlockPermutation) - return fuseaxes(axes(a), blockedperm) -end - -# Fuse adjacent dimensions -function fusedims(a::AbstractArray, blockedperm::BlockedTrivialPermutation) - axes_fused = fuseaxes(a, blockedperm) - return fusedims(a, axes_fused) -end - -function fusedims(a::AbstractArray, blockedperm::BlockedPermutation) - a_perm = _permutedims(a, Tuple(blockedperm)) - return fusedims(a_perm, trivialperm(blockedperm)) -end diff --git a/src/matricize.jl b/src/matricize.jl new file mode 100644 index 0000000..9fdbf8b --- /dev/null +++ b/src/matricize.jl @@ -0,0 +1,122 @@ +using TensorProducts: ⊗ + +# ===================================== FusionStyle ====================================== +abstract type FusionStyle end + +struct ReshapeFusion <: FusionStyle end + +FusionStyle(a::AbstractArray) = FusionStyle(a, axes(a)) +function FusionStyle(a::AbstractArray, t::Tuple{Vararg{AbstractUnitRange}}) + return FusionStyle(a, combine_fusion_styles(FusionStyle.(t)...)) +end + +# Defaults to ReshapeFusion, a simple reshape +FusionStyle(::AbstractArray{<:Any,0}) = ReshapeFusion() # TBD better solution? +FusionStyle(::AbstractUnitRange) = ReshapeFusion() +FusionStyle(::AbstractArray, ::ReshapeFusion) = ReshapeFusion() + +combine_fusion_styles(::Style, ::Style) where {Style<:FusionStyle} = Style() +combine_fusion_styles(::FusionStyle, ::FusionStyle) = ReshapeFusion() +combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, styles) + +# ======================================= misc ======================================== +function fuseaxes( + axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation +) + axesblocks = blockpermute(axes, blockedperm) + return map(block -> ⊗(block...), axesblocks) +end + +Base.permutedims(a::AbstractArray, bp::AbstractBlockPermutation) = permutedims(a, Tuple(bp)) +Base.permutedims(a::StridedArray, bp::AbstractBlockPermutation) = permutedims(a, Tuple(bp)) + +function Base.permutedims!(a::AbstractArray, b::AbstractArray, bp::AbstractBlockPermutation) + return permutedims!(a, b, Tuple(bp)) +end +function Base.permutedims!( + a::Array{T,N}, b::StridedArray{T,N}, bp::AbstractBlockPermutation +) where {T,N} + return permutedims!(a, b, Tuple(bp)) +end + +# ===================================== matricize ======================================== +# TBD settle copy/not copy convention +# matrix factorizations assume copy +# maybe: copy=false kwarg + +# default is reshape +function matricize( + ::ReshapeFusion, + a::AbstractArray, + row_axis::AbstractUnitRange, + col_axis::AbstractUnitRange, +) + return reshape(a, row_axis, col_axis) +end + +function matricize(::ReshapeFusion, a::AbstractArray, bp::AbstractBlockPermutation{2}) + axes_fused = fuseaxes(axes(a), bp) + return matricize(ReshapeFusion(), a, axes_fused...) +end + +function matricize(a::AbstractArray, tp::BlockedTrivialPermutation{2}) + return matricize(FusionStyle(a), a, tp) +end + +function matricize(a::AbstractArray, bp::AbstractBlockPermutation{2}) + a_perm = permutedims(a, bp) # includes copy + return matricize(a_perm, trivialperm(bp)) +end + +function matricize(a::AbstractArray, bt::AbstractBlockTuple{2}) + return matricize(a, blockedperm(bt)) +end + +function matricize(::AbstractArray, ::AbstractBlockTuple) + throw(ArgumentError("Invalid axis permutation")) +end + +function matricize(a::AbstractArray, permblocks...) + return matricize(a, blockedpermvcat(permblocks...; length=Val(ndims(a)))) +end + +# ==================================== unmatricize ======================================= +function unmatricize(::ReshapeFusion, m::AbstractMatrix, axes::AbstractUnitRange...) + return reshape(m, Base.to_shape.(axes)...) +end + +function unmatricize( + ::ReshapeFusion, + m::AbstractMatrix, + blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}}, +) + return unmatricize(ReshapeFusion(), m, blocked_axes...) +end + +function unmatricize( + m::AbstractMatrix, blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}} +) + return unmatricize(FusionStyle(m), m, blocked_axes) +end + +function unmatricize( + m::AbstractMatrix, axes::Tuple{Vararg{AbstractUnitRange}}, bp::AbstractBlockPermutation{2} +) + blocked_axes = tuplemortar(blockpermute(axes, bp)) + a_perm = unmatricize(m, blocked_axes) + return permutedims(a_perm, invperm(bp)) +end + +function unmatricize( + m::AbstractMatrix, + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, + domain_axes::Tuple{Vararg{AbstractUnitRange}}, +) + blocked_axes = tuplemortar((codomain_axes, domain_axes)) + return unmatricize(m, blocked_axes) +end + +function unmatricize!(a::AbstractArray, m::AbstractMatrix, bp::AbstractBlockPermutation{2}) + a_perm = unmatricize(m, axes(a), bp) + return permutedims!(a, a_perm, invperm(bp)) +end diff --git a/src/splitdims.jl b/src/splitdims.jl deleted file mode 100644 index 0554c61..0000000 --- a/src/splitdims.jl +++ /dev/null @@ -1,68 +0,0 @@ -using .BaseExtensions: _permutedims, _permutedims! - -to_axis(a::AbstractUnitRange) = a -to_axis(n::Integer) = Base.OneTo(n) - -function blockedaxes(a::AbstractArray, sizeblocks::Pair...) - axes_a = axes(a) - axes_split = tuple.(axes(a)) - for (dim, sizeblock) in sizeblocks - # TODO: Handle conversion from length to range! - axes_split = Base.setindex(axes_split, to_axis.(sizeblock), dim) - end - return axes_split -end - -# splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2) -function splitdims(::ReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...) - # TODO: Add `uncanonicalizedims`. - # TODO: Need `length` since `reshape` doesn't accept `axes`, - # maybe make a `reshape_axes` function. - return reshape(a, length.(axes)...) -end - -# splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2) -function splitdims(a::AbstractArray, axes::AbstractUnitRange...) - return splitdims(FusionStyle(a), a, axes...) -end - -# splitdims(randn(4, 4), (1:2, 1:2), (1:2, 1:2)) -function splitdims(a::AbstractArray, axesblocks::Tuple{Vararg{AbstractUnitRange}}...) - # TODO: Add `uncanonicalizedims`. - return splitdims(a, flatten_tuples(axesblocks)...) -end - -# Fix ambiguity issue -splitdims(a::AbstractArray) = a - -# splitdims(randn(4, 4), (2, 2), (2, 2)) -function splitdims(a::AbstractArray, sizeblocks::Tuple{Vararg{Integer}}...) - return splitdims(a, map(x -> Base.OneTo.(x), sizeblocks)...) -end - -# splitdims(randn(4, 4), 2 => (1:2, 1:2)) -function splitdims(a::AbstractArray, sizeblocks::Pair...) - return splitdims(a, blockedaxes(a, sizeblocks...)...) -end - -# TODO: Is this needed? -function splitdims( - a::AbstractArray, - axes_dest::Tuple{Vararg{AbstractUnitRange}}, - blockedperm::BlockedPermutation, -) - # TODO: Pass grouped axes. - a_dest_perm = splitdims(a, axes_dest...) - a_dest = _permutedims(a_dest_perm, invperm(Tuple(blockedperm))) - return a_dest -end - -function splitdims!( - a_dest::AbstractArray, a::AbstractArray, blockedperm::BlockedPermutation -) - axes_dest = map(i -> axes(a_dest, i), Tuple(blockedperm)) - # TODO: Pass grouped axes. - a_dest_perm = splitdims(a, axes_dest...) - _permutedims!(a_dest, a_dest_perm, invperm(Tuple(blockedperm))) - return a_dest -end diff --git a/test/test_basics.jl b/test/test_basics.jl index 2fba24f..38ae246 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,7 +1,8 @@ using EllipsisNotation: var".." using LinearAlgebra: norm using StableRNGs: StableRNG -using TensorAlgebra: contract, contract!, matricize, qr, svd, tuplemortar, unmatricize +using TensorAlgebra: + blockedpermvcat, contract, contract!, matricize, qr, svd, tuplemortar, unmatricize using TensorOperations: TensorOperations using Test: @test, @test_broken, @test_throws, @testset @@ -58,33 +59,43 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test eltype(a_fused) === elt @test a_fused ≈ ones(elt, 1, 1) end + @testset "unmatricize (eltype=$elt)" for elt in elts - a = randn(elt, 6, 20) - a_split = unmatricize(a, (2, 3), (5, 4)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 5, 4)) - a_split = unmatricize(a, (1:2, 1:3), (1:5, 1:4)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 5, 4)) - a_split = splitdims(a, 2 => (5, 4), 1 => (2, 3)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 5, 4)) - a_split = splitdims(a, 2 => (1:5, 1:4), 1 => (1:2, 1:3)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 5, 4)) - a_split = splitdims(a, 2 => (5, 4)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (6, 5, 4)) - a_split = splitdims(a, 2 => (1:5, 1:4)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (6, 5, 4)) - a_split = splitdims(a, 1 => (2, 3)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 20)) - a_split = splitdims(a, 1 => (1:2, 1:3)) - @test eltype(a_split) === elt - @test a_split ≈ reshape(a, (2, 3, 20)) + a0 = randn(elt, 2, 3, 4, 5) + axes0 = axes(a0) + m = reshape(a0, 6, 20) + + a = unmatricize(m, tuplemortar((axes0[1:2], axes0[3:4]))) + @test eltype(a) === elt + @test a ≈ a0 + + a = unmatricize(m, axes0[1:2], axes0[3:4]) + @test eltype(a) === elt + @test a ≈ a0 + + a = unmatricize(m, axes0, blockedpermvcat((1, 2), (3, 4))) + @test eltype(a) === elt + @test a ≈ a0 + + bp = blockedpermvcat((4, 2), (1, 3)) + a = unmatricize(m, map(i -> axes0[i], invperm(Tuple(bp))), bp) + @test eltype(a) === elt + @test a ≈ permutedims(a0, invperm(Tuple(bp))) + + a = unmatricize(m, (), axes0) + @test eltype(a) === elt + @test a ≈ a0 + + a = unmatricize(m, axes0, ()) + @test eltype(a) === elt + @test a ≈ a0 + + m = randn(elt, 1, 1) + a = unmatricize(m, (), ()) + @test a isa Array{elt,0} + @test a[] == m[1, 1] end + using TensorOperations: TensorOperations @testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts dims = (2, 3, 4, 5, 6, 7, 8, 9, 10) From dae28076cdbf9c18f5d9316a0ec74a7c3a6ed53d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 3 Apr 2025 18:17:45 -0400 Subject: [PATCH 04/18] passing tests --- src/matricize.jl | 19 ++++++++++--------- test/test_basics.jl | 23 ++++++++++++++++++++++- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index 9fdbf8b..fc06958 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -99,14 +99,6 @@ function unmatricize( return unmatricize(FusionStyle(m), m, blocked_axes) end -function unmatricize( - m::AbstractMatrix, axes::Tuple{Vararg{AbstractUnitRange}}, bp::AbstractBlockPermutation{2} -) - blocked_axes = tuplemortar(blockpermute(axes, bp)) - a_perm = unmatricize(m, blocked_axes) - return permutedims(a_perm, invperm(bp)) -end - function unmatricize( m::AbstractMatrix, codomain_axes::Tuple{Vararg{AbstractUnitRange}}, @@ -116,7 +108,16 @@ function unmatricize( return unmatricize(m, blocked_axes) end +function unmatricize( + m::AbstractMatrix, axes::Tuple{Vararg{AbstractUnitRange}}, bp::AbstractBlockPermutation{2} +) + blocked_axes = tuplemortar(blockpermute(axes, bp)) + a_perm = unmatricize(m, blocked_axes) + return permutedims(a_perm, invperm(bp)) +end + function unmatricize!(a::AbstractArray, m::AbstractMatrix, bp::AbstractBlockPermutation{2}) - a_perm = unmatricize(m, axes(a), bp) + blocked_axes = tuplemortar(blockpermute(axes(a), bp)) + a_perm = unmatricize(m, blocked_axes) return permutedims!(a, a_perm, invperm(bp)) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 38ae246..70cd565 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -2,7 +2,15 @@ using EllipsisNotation: var".." using LinearAlgebra: norm using StableRNGs: StableRNG using TensorAlgebra: - blockedpermvcat, contract, contract!, matricize, qr, svd, tuplemortar, unmatricize + blockedpermvcat, + contract, + contract!, + matricize, + qr, + svd, + tuplemortar, + unmatricize, + unmatricize! using TensorOperations: TensorOperations using Test: @test, @test_broken, @test_throws, @testset @@ -82,6 +90,19 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test eltype(a) === elt @test a ≈ permutedims(a0, invperm(Tuple(bp))) + a = similar(a0) + unmatricize!(a, m, blockedpermvcat((1, 2), (3, 4))) + @test a ≈ a0 + + m1 = matricize(a0, bp) + a = unmatricize(m1, axes0, bp) + @test a ≈ a0 + + a1 = permutedims(a0, Tuple(bp)) + a = similar(a1) + unmatricize!(a, m, invperm(bp)) + @test a ≈ a1 + a = unmatricize(m, (), axes0) @test eltype(a) === elt @test a ≈ a0 From cfa8779e8fdf5d0d022e0b53f881c6f6fada89c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 3 Apr 2025 18:29:34 -0400 Subject: [PATCH 05/18] WIP factorizations --- src/factorizations.jl | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/factorizations.jl b/src/factorizations.jl index 02f0e54..a50b08b 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -41,16 +41,16 @@ function qr( A::AbstractArray, biperm::AbstractBlockPermutation{2}; full::Bool=false, kwargs... ) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) # factorization Q, R = full ? qr_full!(A_mat; kwargs...) : qr_compact!(A_mat; kwargs...) # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_Q = tuplemortar((axes_codomain, (axes(q_matricized, 2),))) - axes_R = tuplemortar(((axes(r_matricized, 1),), axes_domain)) - return splitdims(Q, axes_Q), splitdims(R, axes_R) + axes_Q = tuplemortar((axes_codomain, (axes(Q, 2),))) + axes_R = tuplemortar(((axes(R, 1),), axes_domain)) + return unmatricize(Q, axes_Q), unmatricize(R, axes_R) end """ @@ -77,7 +77,7 @@ function lq( A::AbstractArray, biperm::AbstractBlockPermutation{2}; full::Bool=false, kwargs... ) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) # factorization L, Q = full ? lq_full!(A_mat; kwargs...) : lq_compact!(A_mat; kwargs...) @@ -86,7 +86,7 @@ function lq( axes_codomain, axes_domain = blockpermute(axes(A), biperm) axes_L = tuplemortar((axes_codomain, (axes(L, ndims(L)),))) axes_Q = tuplemortar(((axes(Q, 1),), axes_domain)) - return splitdims(L, axes_L), splitdims(Q, axes_Q) + return unmatricize(L, axes_L), unmatricize(Q, axes_Q) end """ @@ -119,7 +119,7 @@ function eigen( kwargs..., ) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat) @@ -133,7 +133,7 @@ function eigen( # matrix to tensor axes_codomain, = blockpermute(axes(A), biperm) axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),))) - return D, splitdims(V, axes_V) + return D, unmatricize(V, axes_V) end """ @@ -159,7 +159,7 @@ end function eigvals( A::AbstractArray, biperm::AbstractBlockPermutation{2}; ishermitian=nothing, kwargs... ) - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat) return (ishermitian ? eigh_vals! : eig_vals!)(A_mat; kwargs...) end @@ -194,7 +194,7 @@ function svd( kwargs..., ) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) # factorization if !isnothing(trunc) @@ -208,7 +208,7 @@ function svd( axes_codomain, axes_domain = blockpermute(axes(A), biperm) axes_U = tuplemortar((axes_codomain, (axes(U, 2),))) axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain)) - return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ) + return unmatricize(U, axes_U), S, unmatricize(Vᴴ, axes_Vᴴ) end """ @@ -226,7 +226,7 @@ function svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) return svdvals(A, biperm) end function svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) return svd_vals!(A_mat) end @@ -252,11 +252,11 @@ function left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; k return left_null(A, biperm; kwargs...) end function left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) N = left_null!(A_mat; kwargs...) axes_codomain, _ = blockpermute(axes(A), biperm) axes_N = tuplemortar((axes_codomain, (axes(N, 2),))) - N_tensor = splitdims(N, axes_N) + N_tensor = unmatricize(N, axes_N) return N_tensor end @@ -282,9 +282,9 @@ function right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; return right_null(A, biperm; kwargs...) end function right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) Nᴴ = right_null!(A_mat; kwargs...) _, axes_domain = blockpermute(axes(A), biperm) axes_Nᴴ = tuplemortar((axes(Nᴴ, 1), (axes_domain,))) - return splitdims(Nᴴ, axes_Nᴴ) + return unmatricize(Nᴴ, axes_Nᴴ) end From bf468386f5c63e87b747cd959c3a23269a7c9ed8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 4 Apr 2025 10:42:28 -0400 Subject: [PATCH 06/18] fix factorize --- src/factorizations.jl | 2 +- src/matricize.jl | 54 +++++++++++++++++++++++++++++-------------- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/src/factorizations.jl b/src/factorizations.jl index a50b08b..9a56dd2 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -285,6 +285,6 @@ function right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwarg A_mat = matricize(A, biperm) Nᴴ = right_null!(A_mat; kwargs...) _, axes_domain = blockpermute(axes(A), biperm) - axes_Nᴴ = tuplemortar((axes(Nᴴ, 1), (axes_domain,))) + axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain)) return unmatricize(Nᴴ, axes_Nᴴ) end diff --git a/src/matricize.jl b/src/matricize.jl index fc06958..94ddb5a 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -1,3 +1,5 @@ +using LinearAlgebra: Diagonal + using TensorProducts: ⊗ # ===================================== FusionStyle ====================================== @@ -27,16 +29,30 @@ function fuseaxes( return map(block -> ⊗(block...), axesblocks) end -Base.permutedims(a::AbstractArray, bp::AbstractBlockPermutation) = permutedims(a, Tuple(bp)) -Base.permutedims(a::StridedArray, bp::AbstractBlockPermutation) = permutedims(a, Tuple(bp)) +# define permutedims with a BlockedPermuation. Default is to flatten it. +function Base.permutedims(a::AbstractArray, biperm::AbstractBlockPermutation) + return permutedims(a, Tuple(biperm)) +end + +# solve ambiguities +function Base.permutedims(a::StridedArray, biperm::AbstractBlockPermutation) + return permutedims(a, Tuple(biperm)) +end +function Base.permutedims(a::Diagonal, biperm::AbstractBlockPermutation) + return permutedims(a, Tuple(biperm)) +end -function Base.permutedims!(a::AbstractArray, b::AbstractArray, bp::AbstractBlockPermutation) - return permutedims!(a, b, Tuple(bp)) +function Base.permutedims!( + a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation +) + return permutedims!(a, b, Tuple(biperm)) end + +# solve ambiguities function Base.permutedims!( - a::Array{T,N}, b::StridedArray{T,N}, bp::AbstractBlockPermutation + a::Array{T,N}, b::StridedArray{T,N}, biperm::AbstractBlockPermutation ) where {T,N} - return permutedims!(a, b, Tuple(bp)) + return permutedims!(a, b, Tuple(biperm)) end # ===================================== matricize ======================================== @@ -54,8 +70,8 @@ function matricize( return reshape(a, row_axis, col_axis) end -function matricize(::ReshapeFusion, a::AbstractArray, bp::AbstractBlockPermutation{2}) - axes_fused = fuseaxes(axes(a), bp) +function matricize(::ReshapeFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}) + axes_fused = fuseaxes(axes(a), biperm) return matricize(ReshapeFusion(), a, axes_fused...) end @@ -63,9 +79,9 @@ function matricize(a::AbstractArray, tp::BlockedTrivialPermutation{2}) return matricize(FusionStyle(a), a, tp) end -function matricize(a::AbstractArray, bp::AbstractBlockPermutation{2}) - a_perm = permutedims(a, bp) # includes copy - return matricize(a_perm, trivialperm(bp)) +function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2}) + a_perm = permutedims(a, biperm) # includes copy + return matricize(a_perm, trivialperm(biperm)) end function matricize(a::AbstractArray, bt::AbstractBlockTuple{2}) @@ -109,15 +125,19 @@ function unmatricize( end function unmatricize( - m::AbstractMatrix, axes::Tuple{Vararg{AbstractUnitRange}}, bp::AbstractBlockPermutation{2} + m::AbstractMatrix, + axes::Tuple{Vararg{AbstractUnitRange}}, + biperm::AbstractBlockPermutation{2}, ) - blocked_axes = tuplemortar(blockpermute(axes, bp)) + blocked_axes = tuplemortar(blockpermute(axes, biperm)) a_perm = unmatricize(m, blocked_axes) - return permutedims(a_perm, invperm(bp)) + return permutedims(a_perm, invperm(biperm)) end -function unmatricize!(a::AbstractArray, m::AbstractMatrix, bp::AbstractBlockPermutation{2}) - blocked_axes = tuplemortar(blockpermute(axes(a), bp)) +function unmatricize!( + a::AbstractArray, m::AbstractMatrix, biperm::AbstractBlockPermutation{2} +) + blocked_axes = tuplemortar(blockpermute(axes(a), biperm)) a_perm = unmatricize(m, blocked_axes) - return permutedims!(a, a_perm, invperm(bp)) + return permutedims!(a, a_perm, invperm(biperm)) end From 60605fb2abcf79eb63c87589eba14a154412183b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 4 Apr 2025 11:23:08 -0400 Subject: [PATCH 07/18] define getindex(::BlockedPermutation --- src/blockedpermutation.jl | 18 +++++++++++------- src/matricize.jl | 6 +++--- test/test_blockedpermutation.jl | 8 ++++++++ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index 2d29aea..c383587 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -45,6 +45,17 @@ function Base.invperm(bp::AbstractBlockPermutation) return blockedperm(invperm(Tuple(bp)), Val(blocklengths(bp))) end +# interface + +# Bipartition a vector according to the +# bipartitioned permutation. +# Like `Base.permute!` block out-of-place and blocked. +function blockpermute(v, blockedperm::AbstractBlockPermutation) + return tuplemortar(map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm))) +end + +Base.getindex(v, perm::AbstractBlockPermutation) = blockpermute(v, perm) + # # Constructors # @@ -53,13 +64,6 @@ function blockedperm(bt::AbstractBlockTuple) return permmortar(blocks(bt)) end -# Bipartition a vector according to the -# bipartitioned permutation. -# Like `Base.permute!` block out-of-place and blocked. -function blockpermute(v, blockedperm::AbstractBlockPermutation) - return map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm)) -end - # blockedpermvcat((4, 3), (2, 1)) function blockedpermvcat( permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing diff --git a/src/matricize.jl b/src/matricize.jl index 94ddb5a..c79e093 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -25,7 +25,7 @@ combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, sty function fuseaxes( axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation ) - axesblocks = blockpermute(axes, blockedperm) + axesblocks = blocks(axes[blockedperm]) return map(block -> ⊗(block...), axesblocks) end @@ -129,7 +129,7 @@ function unmatricize( axes::Tuple{Vararg{AbstractUnitRange}}, biperm::AbstractBlockPermutation{2}, ) - blocked_axes = tuplemortar(blockpermute(axes, biperm)) + blocked_axes = axes[biperm] a_perm = unmatricize(m, blocked_axes) return permutedims(a_perm, invperm(biperm)) end @@ -137,7 +137,7 @@ end function unmatricize!( a::AbstractArray, m::AbstractMatrix, biperm::AbstractBlockPermutation{2} ) - blocked_axes = tuplemortar(blockpermute(axes(a), biperm)) + blocked_axes = axes[biperm] a_perm = unmatricize(m, blocked_axes) return permutedims!(a, a_perm, invperm(biperm)) end diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index e9198a9..961352c 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -135,6 +135,14 @@ using TensorAlgebra: p = blockedpermvcat((3, 2), (..,), 1) @test p == blockedpermvcat((3, 2), (), (1,)) + + # blockpermute + t = (1, 2, 3, 4) + pblocks = tuplemortar(((4, 3), (), (1, 2))) + p = blockedperm(pblocks) + @test (@constinferred blockpermute(t, p)) isa BlockedTuple{3,(2, 0, 2),NTuple{4,Int64}} + @test blockpermute(t, p) == pblocks + @test t[p] == pblocks end @testset "BlockedTrivialPermutation" begin From 128bfeeed082e2939ab3abdb39ea336058dc32b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 4 Apr 2025 11:54:32 -0400 Subject: [PATCH 08/18] more tests --- test/test_blockedpermutation.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index 961352c..d963cb9 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -143,6 +143,8 @@ using TensorAlgebra: @test (@constinferred blockpermute(t, p)) isa BlockedTuple{3,(2, 0, 2),NTuple{4,Int64}} @test blockpermute(t, p) == pblocks @test t[p] == pblocks + @test pblocks[p] == tuplemortar(((2, 1), (), (4, 3))) + @test p[p] == tuplemortar(((2, 1), (), (4, 3))) end @testset "BlockedTrivialPermutation" begin From 4f3d2a59085a7cc90130bd4a651de0c8971af2c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 4 Apr 2025 14:26:49 -0400 Subject: [PATCH 09/18] fix blockpermute --- src/contract/allocate_output.jl | 113 +------------------------------- src/factorizations.jl | 12 ++-- src/matricize.jl | 63 ++++++++++-------- test/test_blockedpermutation.jl | 1 + 4 files changed, 44 insertions(+), 145 deletions(-) diff --git a/src/contract/allocate_output.jl b/src/contract/allocate_output.jl index 32226cb..3fa1c02 100644 --- a/src/contract/allocate_output.jl +++ b/src/contract/allocate_output.jl @@ -11,121 +11,12 @@ function output_axes( biperm2::AbstractBlockPermutation{2}, α::Number=one(Bool), ) - axes_codomain, axes_contracted = blockpermute(axes(a1), biperm1) - axes_contracted2, axes_domain = blockpermute(axes(a2), biperm2) + axes_codomain, axes_contracted = blocks(axes(a1)[biperm1]) + axes_contracted2, axes_domain = blocks(axes(a2)[biperm2]) @assert axes_contracted == axes_contracted2 return genperm((axes_codomain..., axes_domain...), invperm(Tuple(biperm_dest))) end -# Inner-product contraction. -# TODO: Use `ArrayLayouts`-like `MulAdd` object, -# i.e. `ContractAdd`? -function output_axes( - ::typeof(contract), - perm_dest::AbstractBlockPermutation{0}, - a1::AbstractArray, - perm1::AbstractBlockPermutation{1}, - a2::AbstractArray, - perm2::AbstractBlockPermutation{1}, - α::Number=one(Bool), -) - axes_contracted = blockpermute(axes(a1), perm1) - axes_contracted′ = blockpermute(axes(a2), perm2) - @assert axes_contracted == axes_contracted′ - return () -end - -# Vec-mat. -function output_axes( - ::typeof(contract), - perm_dest::AbstractBlockPermutation{1}, - a1::AbstractArray, - perm1::AbstractBlockPermutation{1}, - a2::AbstractArray, - biperm2::AbstractBlockPermutation{2}, - α::Number=one(Bool), -) - (axes_contracted,) = blockpermute(axes(a1), perm1) - axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2) - @assert axes_contracted == axes_contracted′ - return genperm((axes_dest...,), invperm(Tuple(perm_dest))) -end - -# Mat-vec. -function output_axes( - ::typeof(contract), - perm_dest::AbstractBlockPermutation{1}, - a1::AbstractArray, - perm1::AbstractBlockPermutation{2}, - a2::AbstractArray, - biperm2::AbstractBlockPermutation{1}, - α::Number=one(Bool), -) - axes_dest, axes_contracted = blockpermute(axes(a1), perm1) - (axes_contracted′,) = blockpermute(axes(a2), biperm2) - @assert axes_contracted == axes_contracted′ - return genperm((axes_dest...,), invperm(Tuple(perm_dest))) -end - -# Outer product. -function output_axes( - ::typeof(contract), - biperm_dest::AbstractBlockPermutation{2}, - a1::AbstractArray, - perm1::AbstractBlockPermutation{1}, - a2::AbstractArray, - perm2::AbstractBlockPermutation{1}, - α::Number=one(Bool), -) - @assert istrivialperm(Tuple(perm1)) - @assert istrivialperm(Tuple(perm2)) - axes_dest = (axes(a1)..., axes(a2)...) - return genperm(axes_dest, invperm(Tuple(biperm_dest))) -end - -# Array-scalar contraction. -function output_axes( - ::typeof(contract), - perm_dest::AbstractBlockPermutation{1}, - a1::AbstractArray, - perm1::AbstractBlockPermutation{1}, - a2::AbstractArray, - perm2::AbstractBlockPermutation{0}, - α::Number=one(Bool), -) - @assert istrivialperm(Tuple(perm1)) - axes_dest = axes(a1) - return genperm(axes_dest, invperm(Tuple(perm_dest))) -end - -# Scalar-array contraction. -function output_axes( - ::typeof(contract), - perm_dest::AbstractBlockPermutation{1}, - a1::AbstractArray, - perm1::AbstractBlockPermutation{0}, - a2::AbstractArray, - perm2::AbstractBlockPermutation{1}, - α::Number=one(Bool), -) - @assert istrivialperm(Tuple(perm2)) - axes_dest = axes(a2) - return genperm(axes_dest, invperm(Tuple(perm_dest))) -end - -# Scalar-scalar contraction. -function output_axes( - ::typeof(contract), - perm_dest::AbstractBlockPermutation{0}, - a1::AbstractArray, - perm1::AbstractBlockPermutation{0}, - a2::AbstractArray, - perm2::AbstractBlockPermutation{0}, - α::Number=one(Bool), -) - return () -end - # TODO: Use `ArrayLayouts`-like `MulAdd` object, # i.e. `ContractAdd`? function allocate_output( diff --git a/src/factorizations.jl b/src/factorizations.jl index 9a56dd2..047f379 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -47,7 +47,7 @@ function qr( Q, R = full ? qr_full!(A_mat; kwargs...) : qr_compact!(A_mat; kwargs...) # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_codomain, axes_domain = blocks(axes(A)[biperm]) axes_Q = tuplemortar((axes_codomain, (axes(Q, 2),))) axes_R = tuplemortar(((axes(R, 1),), axes_domain)) return unmatricize(Q, axes_Q), unmatricize(R, axes_R) @@ -83,7 +83,7 @@ function lq( L, Q = full ? lq_full!(A_mat; kwargs...) : lq_compact!(A_mat; kwargs...) # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_codomain, axes_domain = blocks(axes(A)[biperm]) axes_L = tuplemortar((axes_codomain, (axes(L, ndims(L)),))) axes_Q = tuplemortar(((axes(Q, 1),), axes_domain)) return unmatricize(L, axes_L), unmatricize(Q, axes_Q) @@ -131,7 +131,7 @@ function eigen( end # matrix to tensor - axes_codomain, = blockpermute(axes(A), biperm) + axes_codomain, = blocks(axes(A)[biperm]) axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),))) return D, unmatricize(V, axes_V) end @@ -205,7 +205,7 @@ function svd( end # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_codomain, axes_domain = blocks(axes(A)[biperm]) axes_U = tuplemortar((axes_codomain, (axes(U, 2),))) axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain)) return unmatricize(U, axes_U), S, unmatricize(Vᴴ, axes_Vᴴ) @@ -254,7 +254,7 @@ end function left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) A_mat = matricize(A, biperm) N = left_null!(A_mat; kwargs...) - axes_codomain, _ = blockpermute(axes(A), biperm) + axes_codomain = first(blocks(axes(A)[biperm])) axes_N = tuplemortar((axes_codomain, (axes(N, 2),))) N_tensor = unmatricize(N, axes_N) return N_tensor @@ -284,7 +284,7 @@ end function right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) A_mat = matricize(A, biperm) Nᴴ = right_null!(A_mat; kwargs...) - _, axes_domain = blockpermute(axes(A), biperm) + axes_domain = last(blocks(axes(A)[biperm])) axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain)) return unmatricize(Nᴴ, axes_Nᴴ) end diff --git a/src/matricize.jl b/src/matricize.jl index c79e093..5e88a99 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -60,28 +60,26 @@ end # matrix factorizations assume copy # maybe: copy=false kwarg -# default is reshape -function matricize( - ::ReshapeFusion, - a::AbstractArray, - row_axis::AbstractUnitRange, - col_axis::AbstractUnitRange, -) - return reshape(a, row_axis, col_axis) +function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2}) + return matricize(FusionStyle(a), a, biperm) end -function matricize(::ReshapeFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}) - axes_fused = fuseaxes(axes(a), biperm) - return matricize(ReshapeFusion(), a, axes_fused...) +function matricize( + style::FusionStyle, a::AbstractArray, biperm::AbstractBlockPermutation{2} +) + a_perm = permutedims(a, biperm) + return matricize(style, a_perm, trivialperm(biperm)) end -function matricize(a::AbstractArray, tp::BlockedTrivialPermutation{2}) - return matricize(FusionStyle(a), a, tp) +function matricize( + style::FusionStyle, a::AbstractArray, biperm::BlockedTrivialPermutation{2} +) + return throw(MethodError(matricize, Tuple{typeof(style),typeof(a),typeof(biperm)})) end -function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2}) - a_perm = permutedims(a, biperm) # includes copy - return matricize(a_perm, trivialperm(biperm)) +# default is reshape +function matricize(::ReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}) + return reshape(a, fuseaxes(axes(a), biperm)...) end function matricize(a::AbstractArray, bt::AbstractBlockTuple{2}) @@ -97,6 +95,25 @@ function matricize(a::AbstractArray, permblocks...) end # ==================================== unmatricize ======================================= +function unmatricize( + m::AbstractMatrix, + axes::Tuple{Vararg{AbstractUnitRange}}, + biperm::AbstractBlockPermutation{2}, +) + return unmatricize(FusionStyle(m), m, axes, biperm) +end + +function unmatricize( + ::FusionStyle, + m::AbstractMatrix, + axes::Tuple{Vararg{AbstractUnitRange}}, + biperm::AbstractBlockPermutation{2}, +) + blocked_axes = axes[biperm] + a_perm = unmatricize(m, blocked_axes) + return permutedims(a_perm, invperm(biperm)) +end + function unmatricize(::ReshapeFusion, m::AbstractMatrix, axes::AbstractUnitRange...) return reshape(m, Base.to_shape.(axes)...) end @@ -106,7 +123,7 @@ function unmatricize( m::AbstractMatrix, blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}}, ) - return unmatricize(ReshapeFusion(), m, blocked_axes...) + return reshape(m, Base.to_shape.(Tuple(blocked_axes))...) end function unmatricize( @@ -124,20 +141,10 @@ function unmatricize( return unmatricize(m, blocked_axes) end -function unmatricize( - m::AbstractMatrix, - axes::Tuple{Vararg{AbstractUnitRange}}, - biperm::AbstractBlockPermutation{2}, -) - blocked_axes = axes[biperm] - a_perm = unmatricize(m, blocked_axes) - return permutedims(a_perm, invperm(biperm)) -end - function unmatricize!( a::AbstractArray, m::AbstractMatrix, biperm::AbstractBlockPermutation{2} ) - blocked_axes = axes[biperm] + blocked_axes = axes(a)[biperm] a_perm = unmatricize(m, blocked_axes) return permutedims!(a, a_perm, invperm(biperm)) end diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index d963cb9..016011f 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -10,6 +10,7 @@ using TensorAlgebra: BlockedTuple, blockedperm, blockedperm_indexin, + blockpermute, blockedtrivialperm, blockedpermvcat, permmortar, From 1fd5a63f07bc3650faf62ad1014c8bb33a00184a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 4 Apr 2025 16:51:59 -0400 Subject: [PATCH 10/18] impose 2 tuple blocks --- src/matricize.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index 5e88a99..427afd3 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -90,8 +90,8 @@ function matricize(::AbstractArray, ::AbstractBlockTuple) throw(ArgumentError("Invalid axis permutation")) end -function matricize(a::AbstractArray, permblocks...) - return matricize(a, blockedpermvcat(permblocks...; length=Val(ndims(a)))) +function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple) + return matricize(a, blockedpermvcat(permblock1, permblock2; length=Val(ndims(a)))) end # ==================================== unmatricize ======================================= From e1a4b4fe3bd33fe9722f25b2c39a966801192090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 4 Apr 2025 16:56:45 -0400 Subject: [PATCH 11/18] pass tests but blockarrays --- src/factorizations.jl | 72 ++++++++++++++++++++++--------------------- src/matricize.jl | 3 +- test/test_basics.jl | 6 ++-- 3 files changed, 42 insertions(+), 39 deletions(-) diff --git a/src/factorizations.jl b/src/factorizations.jl index d6a8889..f2c9b5b 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -296,7 +296,7 @@ end """ left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P - left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> W, P + left_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> W, P Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -312,23 +312,23 @@ function left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return left_polar(A, biperm; kwargs...) end -function left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) +function left_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) # factorization W, P = left_polar!(A_mat; kwargs...) # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_W = (axes_codomain..., axes(W, 2)) - axes_P = (axes(P, 1), axes_domain...) - return splitdims(W, axes_W), splitdims(P, axes_P) + axes_codomain, axes_domain = blocks(axes(A)[biperm]) + axes_W = tuplemortar((axes_codomain, (axes(W, 2),))) + axes_P = tuplemortar(((axes(P, 1),), axes_domain)) + return unmatricize(W, axes_W), unmatricize(P, axes_P) end """ right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W - right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> P, W + right_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> P, W Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -344,23 +344,23 @@ function right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return right_polar(A, biperm; kwargs...) end -function right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) +function right_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) # factorization P, W = right_polar!(A_mat; kwargs...) # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_P = (axes_codomain..., axes(P, ndims(P))) - axes_W = (axes(W, 1), axes_domain...) - return splitdims(P, axes_P), splitdims(W, axes_W) + axes_codomain, axes_domain = blocks(axes(A)[biperm]) + axes_P = tuplemortar((axes_codomain, (axes(P, ndims(P)),))) + axes_W = tuplemortar(((axes(W, 1),), axes_domain)) + return unmatricize(P, axes_P), unmatricize(W, axes_W) end """ left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C - left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> V, C + left_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> V, C Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -376,23 +376,23 @@ function left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; k biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return left_orth(A, biperm; kwargs...) end -function left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) +function left_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) # factorization V, C = left_orth!(A_mat; kwargs...) # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_V = (axes_codomain..., axes(V, 2)) - axes_C = (axes(C, 1), axes_domain...) - return splitdims(V, axes_V), splitdims(C, axes_C) + axes_codomain, axes_domain = blocks(axes(A)[biperm]) + axes_V = tuplemortar((axes_codomain, (axes(V, 2),))) + axes_C = tuplemortar(((axes(C, 1),), axes_domain)) + return unmatricize(V, axes_V), unmatricize(C, axes_C) end """ right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V - right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> C, V + right_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> C, V Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -408,23 +408,23 @@ function right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return right_orth(A, biperm; kwargs...) end -function right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) +function right_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) # factorization P, W = right_orth!(A_mat; kwargs...) # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_P = (axes_codomain..., axes(P, ndims(P))) - axes_W = (axes(W, 1), axes_domain...) - return splitdims(P, axes_P), splitdims(W, axes_W) + axes_codomain, axes_domain = blocks(axes(A)[biperm]) + axes_P = tuplemortar((axes_codomain, (axes(P, ndims(P)),))) + axes_W = tuplemortar(((axes(W, 1),), axes_domain)) + return unmatricize(P, axes_P), unmatricize(W, axes_W) end """ factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y - factorize(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> X, Y + factorize(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y Compute the decomposition of a generic N-dimensional array, by interpreting it as a linear map from the domain to the codomain indices. These can be specified either via @@ -440,16 +440,18 @@ function factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; k biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return factorize(A, biperm; kwargs...) end -function factorize(A::AbstractArray, biperm::BlockedPermutation{2}; orth=:left, kwargs...) +function factorize( + A::AbstractArray, biperm::AbstractBlockPermutation{2}; orth=:left, kwargs... +) # tensor to matrix - A_mat = fusedims(A, biperm) + A_mat = matricize(A, biperm) # factorization X, Y = (orth == :left ? left_orth! : right_orth!)(A_mat; kwargs...) # matrix to tensor - axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_X = (axes_codomain..., axes(X, ndims(X))) - axes_Y = (axes(Y, 1), axes_domain...) - return splitdims(X, axes_X), splitdims(Y, axes_Y) + axes_codomain, axes_domain = blocks(axes(A)[biperm]) + axes_X = tuplemortar((axes_codomain, (axes(X, ndims(X)),))) + axes_Y = tuplemortar(((axes(Y, 1),), axes_domain)) + return unmatricize(X, axes_X), unmatricize(Y, axes_Y) end diff --git a/src/matricize.jl b/src/matricize.jl index 427afd3..f652bb9 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -79,7 +79,8 @@ end # default is reshape function matricize(::ReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}) - return reshape(a, fuseaxes(axes(a), biperm)...) + new_axes = fuseaxes(axes(a), biperm) + return reshape(a, Base.to_shape.(new_axes)...) end function matricize(a::AbstractArray, bt::AbstractBlockTuple{2}) diff --git a/test/test_basics.jl b/test/test_basics.jl index 70cd565..b1f6b5e 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -35,7 +35,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a_fused = matricize(a, (3, 1), (2, 4)) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) - a_fused = matricize(a, (3, 1, 2), 4) + a_fused = matricize(a, (3, 1, 2), (4,)) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (24, 5)) a_fused = matricize(a, (..,), (3, 1)) @@ -52,8 +52,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test eltype(a_fused) === elt @test a_fused ≈ reshape(a, (120, 1)) - @test_throws ArgumentError matricize(a, (1, 2), (3,), (4,)) - @test_throws ArgumentError matricize(a, (1, 2, 3, 4)) + @test_throws MethodError matricize(a, (1, 2), (3,), (4,)) + @test_throws MethodError matricize(a, (1, 2, 3, 4)) v = ones(elt, 2) a_fused = matricize(v, (1,), ()) From e9321cd0f8ac74f8310f6404907fd39c84e601d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 4 Apr 2025 17:02:46 -0400 Subject: [PATCH 12/18] fix docs version --- docs/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index c029cf2..661cea0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,4 +6,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [compat] Documenter = "1.8.1" Literate = "2.20.1" -TensorAlgebra = "0.2.0" +TensorAlgebra = "0.3.0" From ee446b79fac60724041f3baa98b15088b7a39a8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 7 Apr 2025 12:46:02 -0400 Subject: [PATCH 13/18] add test with zero axis --- test/test_factorizations.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 0c5ccdf..c774c96 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -147,6 +147,20 @@ end @test A ≈ A′ @test size(U, 1) * size(U, 2) == size(U, 3) # U is unitary @test size(Vᴴ, 1) == size(Vᴴ, 2) * size(Vᴴ, 3) # V is unitary + + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full=true) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) + @test A ≈ A′ + @test size(Vᴴ, 1) == 1 + + U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full=true) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (:u,), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) + @test A ≈ A′ + @test size(U, 2) == 1 end @testset "Compact SVD ($T)" for T in elts @@ -166,6 +180,20 @@ end Svals = @constinferred svdvals(A, labels_A, labels_U, labels_Vᴴ) @test Svals ≈ diag(S) + + U, S, Vᴴ = @constinferred svd(A, labels_A, labels_A, (); full=false) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (labels_A..., :u), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v,)) + @test A ≈ A′ + @test size(U, ndims(U)) == 1 == size(Vᴴ, 1) + + U, S, Vᴴ = @constinferred svd(A, labels_A, (), labels_A; full=false) + @test A == Acopy # should not have altered initial array + US, labels_US = contract(U, (:u,), S, (:u, :v)) + A′ = contract(labels_A, US, labels_US, Vᴴ, (:v, labels_A...)) + @test A ≈ A′ + @test size(U, 1) == 1 == size(Vᴴ, 1) end @testset "Truncated SVD ($T)" for T in elts From 877d69505baf0523fea9d446f4c6052d457baea5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 7 Apr 2025 12:48:50 -0400 Subject: [PATCH 14/18] define trivial_axis --- src/matricize.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index f652bb9..fee3b94 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -1,5 +1,7 @@ using LinearAlgebra: Diagonal +using BlockArrays: AbstractBlockedUnitRange, blockedrange + using TensorProducts: ⊗ # ===================================== FusionStyle ====================================== @@ -22,11 +24,15 @@ combine_fusion_styles(::FusionStyle, ::FusionStyle) = ReshapeFusion() combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, styles) # ======================================= misc ======================================== +trivial_axis(::Tuple{}) = Base.OneTo(1) +trivial_axis(::Tuple{Vararg{AbstractUnitRange}}) = Base.OneTo(1) +trivial_axis(::Tuple{Vararg{AbstractBlockedUnitRange}}) = blockedrange([1]) + function fuseaxes( axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation ) axesblocks = blocks(axes[blockedperm]) - return map(block -> ⊗(block...), axesblocks) + return map(block -> isempty(block) ? trivial_axis(axes) : ⊗(block...), axesblocks) end # define permutedims with a BlockedPermuation. Default is to flatten it. @@ -80,7 +86,7 @@ end # default is reshape function matricize(::ReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}) new_axes = fuseaxes(axes(a), biperm) - return reshape(a, Base.to_shape.(new_axes)...) + return reshape(a, new_axes...) end function matricize(a::AbstractArray, bt::AbstractBlockTuple{2}) @@ -116,7 +122,7 @@ function unmatricize( end function unmatricize(::ReshapeFusion, m::AbstractMatrix, axes::AbstractUnitRange...) - return reshape(m, Base.to_shape.(axes)...) + return reshape(m, axes...) end function unmatricize( @@ -124,7 +130,7 @@ function unmatricize( m::AbstractMatrix, blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}}, ) - return reshape(m, Base.to_shape.(Tuple(blocked_axes))...) + return reshape(m, Tuple(blocked_axes)...) end function unmatricize( From b6346fd4d99ceda87e724d1a670319911635fd3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 7 Apr 2025 14:42:15 -0400 Subject: [PATCH 15/18] clean imports --- Project.toml | 2 +- examples/Project.toml | 2 +- test/Project.toml | 8 -------- test/test_aqua.jl | 6 ++++-- test/test_basics.jl | 17 +++++------------ test/test_blockarrays_contract.jl | 6 ++++-- test/test_exports.jl | 4 +++- test/test_factorizations.jl | 9 +++++---- test/test_matrixalgebra.jl | 3 ++- 9 files changed, 25 insertions(+), 32 deletions(-) diff --git a/Project.toml b/Project.toml index 7de5232..288b1e9 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ BlockArrays = "1.5.0" EllipsisNotation = "1.8.0" LinearAlgebra = "1.10" MatrixAlgebraKit = "0.1.1" -TensorProducts = "0.1.0" +TensorProducts = "0.1.5" TupleTools = "1.6.0" TypeParameterAccessors = "0.2.1, 0.3" julia = "1.10" diff --git a/examples/Project.toml b/examples/Project.toml index b12e0b4..0c257a8 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,4 +2,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [compat] -TensorAlgebra = "0.2.0" +TensorAlgebra = "0.3.0" diff --git a/test/Project.toml b/test/Project.toml index c2e3656..0223db5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,16 +2,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" -JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" -LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" -SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -21,16 +17,12 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Aqua = "0.8.9" BlockArrays = "1.4.0" EllipsisNotation = "1.8.0" -JLArrays = "0.2.0" -LabelledNumbers = "0.1.1" LinearAlgebra = "<0.0.1, 1" MatrixAlgebraKit = "0.1" -Pkg = "<0.0.1, 1" Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -SymmetrySectors = "0.1" TensorAlgebra = "0.3.0" TensorOperations = "5.1.4" Test = "1.10" diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 5c6c040..9e2d2ee 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -1,7 +1,9 @@ -using TensorAlgebra: TensorAlgebra -using Aqua: Aqua using Test: @testset +using Aqua: Aqua + +using TensorAlgebra: TensorAlgebra + @testset "Code quality (Aqua.jl)" begin # TODO: fix and re-enable ambiguity checks Aqua.test_all(TensorAlgebra; ambiguities=false, piracies=false) diff --git a/test/test_basics.jl b/test/test_basics.jl index b1f6b5e..4e42d3d 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,18 +1,11 @@ +using Test: @test, @test_broken, @test_throws, @testset + using EllipsisNotation: var".." -using LinearAlgebra: norm using StableRNGs: StableRNG -using TensorAlgebra: - blockedpermvcat, - contract, - contract!, - matricize, - qr, - svd, - tuplemortar, - unmatricize, - unmatricize! using TensorOperations: TensorOperations -using Test: @test, @test_broken, @test_throws, @testset + +using TensorAlgebra: + blockedpermvcat, contract, contract!, matricize, tuplemortar, unmatricize, unmatricize! default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) diff --git a/test/test_blockarrays_contract.jl b/test/test_blockarrays_contract.jl index c3b3af8..f337cc7 100644 --- a/test/test_blockarrays_contract.jl +++ b/test/test_blockarrays_contract.jl @@ -1,8 +1,10 @@ -using BlockArrays: Block, BlockArray, BlockedArray, blockedrange, blocksize -using TensorAlgebra: contract using Random: randn! using Test: @test, @test_broken, @testset +using BlockArrays: Block, BlockArray, BlockedArray, blockedrange, blocksize + +using TensorAlgebra: contract + function randn_blockdiagonal(elt::Type, axes::Tuple) a = zeros(elt, axes) blockdiaglength = minimum(blocksize(a)) diff --git a/test/test_exports.jl b/test/test_exports.jl index 8dba1bf..0c7b00b 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -1,5 +1,7 @@ -using TensorAlgebra: TensorAlgebra using Test: @test, @testset + +using TensorAlgebra: TensorAlgebra + @testset "Test exports" begin exports = [ :TensorAlgebra, diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index c774c96..efb81fc 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,7 +1,10 @@ -using Test: @test, @testset, @inferred +using LinearAlgebra: LinearAlgebra, norm, diag +using Test: @test, @testset + using TestExtras: @constinferred + +using MatrixAlgebraKit: truncrank using TensorAlgebra: - TensorAlgebra, contract, eigen, eigvals, @@ -18,8 +21,6 @@ using TensorAlgebra: right_polar, svd, svdvals -using MatrixAlgebraKit: truncrank -using LinearAlgebra: LinearAlgebra, norm, diag elts = (Float64, ComplexF64) diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 7a391c3..7ee3598 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -1,7 +1,8 @@ using LinearAlgebra: I, diag, isposdef -using TensorAlgebra.MatrixAlgebra: MatrixAlgebra using Test: @test, @testset +using TensorAlgebra.MatrixAlgebra: MatrixAlgebra + elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "TensorAlgebra.MatrixAlgebra (elt=$elt)" for elt in elts From 05070b8b3497401ceb34461ae7e286179d4e6b02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 7 Apr 2025 14:49:43 -0400 Subject: [PATCH 16/18] retrieve _permutedims --- src/BaseExtensions/BaseExtensions.jl | 1 + src/BaseExtensions/permutedims.jl | 21 +++++++++++++++++++++ src/matricize.jl | 12 +++++++----- 3 files changed, 29 insertions(+), 5 deletions(-) create mode 100644 src/BaseExtensions/permutedims.jl diff --git a/src/BaseExtensions/BaseExtensions.jl b/src/BaseExtensions/BaseExtensions.jl index 5109cbb..c994fd8 100644 --- a/src/BaseExtensions/BaseExtensions.jl +++ b/src/BaseExtensions/BaseExtensions.jl @@ -1,3 +1,4 @@ module BaseExtensions include("indexin.jl") +include("permutedims.jl") end diff --git a/src/BaseExtensions/permutedims.jl b/src/BaseExtensions/permutedims.jl new file mode 100644 index 0000000..19f8fd7 --- /dev/null +++ b/src/BaseExtensions/permutedims.jl @@ -0,0 +1,21 @@ +# Workaround for https://github.com/JuliaLang/julia/issues/52615. +# Fixed by https://github.com/JuliaLang/julia/pull/52623. +# TODO remove once support for Julia 1.10 is dropped +function _permutedims!( + a_dest::AbstractArray{<:Any,N}, a_src::AbstractArray{<:Any,N}, perm::Tuple{Vararg{Int,N}} +) where {N} + permutedims!(a_dest, a_src, perm) + return a_dest +end +function _permutedims!( + a_dest::AbstractArray{<:Any,0}, a_src::AbstractArray{<:Any,0}, perm::Tuple{} +) + a_dest[] = a_src[] + return a_dest +end +function _permutedims(a::AbstractArray{<:Any,N}, perm::Tuple{Vararg{Int,N}}) where {N} + return permutedims(a, perm) +end +function _permutedims(a::AbstractArray{<:Any,0}, perm::Tuple{}) + return copy(a) +end diff --git a/src/matricize.jl b/src/matricize.jl index fee3b94..d357c9b 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -3,6 +3,7 @@ using LinearAlgebra: Diagonal using BlockArrays: AbstractBlockedUnitRange, blockedrange using TensorProducts: ⊗ +using .BaseExtensions: _permutedims, _permutedims! # ===================================== FusionStyle ====================================== abstract type FusionStyle end @@ -35,30 +36,31 @@ function fuseaxes( return map(block -> isempty(block) ? trivial_axis(axes) : ⊗(block...), axesblocks) end +# TODO remove _permutedims once support for Julia 1.10 is dropped # define permutedims with a BlockedPermuation. Default is to flatten it. function Base.permutedims(a::AbstractArray, biperm::AbstractBlockPermutation) - return permutedims(a, Tuple(biperm)) + return _permutedims(a, Tuple(biperm)) end # solve ambiguities function Base.permutedims(a::StridedArray, biperm::AbstractBlockPermutation) - return permutedims(a, Tuple(biperm)) + return _permutedims(a, Tuple(biperm)) end function Base.permutedims(a::Diagonal, biperm::AbstractBlockPermutation) - return permutedims(a, Tuple(biperm)) + return _permutedims(a, Tuple(biperm)) end function Base.permutedims!( a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation ) - return permutedims!(a, b, Tuple(biperm)) + return _permutedims!(a, b, Tuple(biperm)) end # solve ambiguities function Base.permutedims!( a::Array{T,N}, b::StridedArray{T,N}, biperm::AbstractBlockPermutation ) where {T,N} - return permutedims!(a, b, Tuple(biperm)) + return _permutedims!(a, b, Tuple(biperm)) end # ===================================== matricize ======================================== From fbc93c4a0e3c5e3cbec5115643df7a447a018459 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 7 Apr 2025 15:29:02 -0400 Subject: [PATCH 17/18] remove FusionStyle(::AbstractArray{<:Any,0}) --- src/matricize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matricize.jl b/src/matricize.jl index d357c9b..a64d3db 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -16,10 +16,10 @@ function FusionStyle(a::AbstractArray, t::Tuple{Vararg{AbstractUnitRange}}) end # Defaults to ReshapeFusion, a simple reshape -FusionStyle(::AbstractArray{<:Any,0}) = ReshapeFusion() # TBD better solution? FusionStyle(::AbstractUnitRange) = ReshapeFusion() FusionStyle(::AbstractArray, ::ReshapeFusion) = ReshapeFusion() +combine_fusion_styles() = ReshapeFusion() combine_fusion_styles(::Style, ::Style) where {Style<:FusionStyle} = Style() combine_fusion_styles(::FusionStyle, ::FusionStyle) = ReshapeFusion() combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, styles) From cc042aea347e7a56b9b6e39c5fc304bc579aa38f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 7 Apr 2025 15:44:36 -0400 Subject: [PATCH 18/18] remove matricize(::AbstractBlockTuple --- src/matricize.jl | 8 -------- test/test_basics.jl | 4 ---- 2 files changed, 12 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index a64d3db..49bc949 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -91,14 +91,6 @@ function matricize(::ReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPerm return reshape(a, new_axes...) end -function matricize(a::AbstractArray, bt::AbstractBlockTuple{2}) - return matricize(a, blockedperm(bt)) -end - -function matricize(::AbstractArray, ::AbstractBlockTuple) - throw(ArgumentError("Invalid axis permutation")) -end - function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple) return matricize(a, blockedpermvcat(permblock1, permblock2; length=Val(ndims(a)))) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 4e42d3d..fc11630 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -18,10 +18,6 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test eltype(a_fused) === elt @test a_fused ≈ reshape(a, 6, 20) - a_fused = matricize(a, tuplemortar(((1, 2), (3, 4)))) - @test eltype(a_fused) === elt - @test a_fused ≈ reshape(a, 6, 20) - a_fused = matricize(a, (1, 2), (3, 4)) @test eltype(a_fused) === elt @test a_fused ≈ reshape(a, 6, 20)