From 35bfe3d37d97c96c0ce45a9e4ca77e9b2665ef1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 20 Feb 2025 11:11:36 -0500 Subject: [PATCH 1/5] contract with AbstractBlockedPermutation{2} --- Project.toml | 2 +- src/contract/blockedperms.jl | 6 +- src/contract/contract.jl | 12 +-- src/contract/contract_matricize/contract.jl | 95 ++------------------- src/fusedims.jl | 1 + src/splitdims.jl | 6 +- 6 files changed, 21 insertions(+), 101 deletions(-) diff --git a/Project.toml b/Project.toml index d343c94..a967c86 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.1.10" +version = "0.1.11" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/contract/blockedperms.jl b/src/contract/blockedperms.jl index bc0679e..28b0bce 100644 --- a/src/contract/blockedperms.jl +++ b/src/contract/blockedperms.jl @@ -22,10 +22,10 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) perm_domain2 = BaseExtensions.indexin(domain, dimnames2) permblocks_dest = (perm_codomain_dest, perm_domain_dest) - biperm_dest = blockedperm(filter(!isempty, permblocks_dest)...) + biperm_dest = blockedperm(permblocks_dest...) permblocks1 = (perm_codomain1, perm_domain1) - biperm1 = blockedperm(filter(!isempty, permblocks1)...) + biperm1 = blockedperm(permblocks1...) permblocks2 = (perm_codomain2, perm_domain2) - biperm2 = blockedperm(filter(!isempty, permblocks2)...) + biperm2 = blockedperm(permblocks2...) return biperm_dest, biperm1, biperm2 end diff --git a/src/contract/contract.jl b/src/contract/contract.jl index dae9441..1cfb061 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -13,11 +13,11 @@ default_contract_alg() = Matricize() function contract!( alg::Algorithm, a_dest::AbstractArray, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation{2}, α::Number, β::Number, ) @@ -110,11 +110,11 @@ end function contract( alg::Algorithm, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation{2}, α::Number; kwargs..., ) diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index 1750ae2..02d08c9 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -1,103 +1,22 @@ using LinearAlgebra: mul! function contract!( - alg::Matricize, + ::Matricize, a_dest::AbstractArray, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation{2}, α::Number, β::Number, ) a_dest_mat = fusedims(a_dest, biperm_dest) a1_mat = fusedims(a1, biperm1) a2_mat = fusedims(a2, biperm2) - _mul!(a_dest_mat, a1_mat, a2_mat, α, β) + @assert ndims(a1_mat) == 2 + @assert ndims(a2_mat) == 2 + mul!(a_dest_mat, a1_mat, a2_mat, α, β) splitdims!(a_dest, a_dest_mat, biperm_dest) return a_dest end - -# Matrix multiplication. -function _mul!( - a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, α::Number, β::Number -) - mul!(a_dest, a1, a2, α, β) - return a_dest -end - -# Inner product. -function _mul!( - a_dest::AbstractArray{<:Any,0}, - a1::AbstractVector, - a2::AbstractVector, - α::Number, - β::Number, -) - a_dest[] = transpose(a1) * a2 * α + a_dest[] * β - return a_dest -end - -# Vec-mat. -function _mul!( - a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number -) - mul!(transpose(a_dest), transpose(a1), a2, α, β) - return a_dest -end - -# Mat-vec. -function _mul!( - a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number -) - mul!(a_dest, a1, a2, α, β) - return a_dest -end - -# Outer product. -function _mul!( - a_dest::AbstractMatrix, a1::AbstractVector, a2::AbstractVector, α::Number, β::Number -) - mul!(a_dest, a1, transpose(a2), α, β) - return a_dest -end - -# Array-scalar contraction. -function _mul!( - a_dest::AbstractVector, - a1::AbstractVector, - a2::AbstractArray{<:Any,0}, - α::Number, - β::Number, -) - α′ = a2[] * α - a_dest .= a1 .* α′ .+ a_dest .* β - return a_dest -end - -# Scalar-array contraction. -function _mul!( - a_dest::AbstractVector, - a1::AbstractArray{<:Any,0}, - a2::AbstractVector, - α::Number, - β::Number, -) - # Preserve the ordering in case of non-commutative algebra. - a_dest .= a1[] .* a2 .* α .+ a_dest .* β - return a_dest -end - -# Scalar-scalar contraction. -function _mul!( - a_dest::AbstractArray{<:Any,0}, - a1::AbstractArray{<:Any,0}, - a2::AbstractArray{<:Any,0}, - α::Number, - β::Number, -) - # Preserve the ordering in case of non-commutative algebra. - a_dest[] = a1[] * a2[] * α + a_dest[] * β - return a_dest -end diff --git a/src/fusedims.jl b/src/fusedims.jl index 2e87346..aa9a7c0 100644 --- a/src/fusedims.jl +++ b/src/fusedims.jl @@ -11,6 +11,7 @@ combine_fusion_styles(style1::Style, style2::Style) where {Style<:FusionStyle} = combine_fusion_styles(style1::FusionStyle, style2::FusionStyle) = ReshapeFusion() combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, styles) FusionStyle(axis::AbstractUnitRange) = ReshapeFusion() +FusionStyle(::Tuple{}) = ReshapeFusion() function FusionStyle(axes::Tuple{Vararg{AbstractUnitRange}}) return combine_fusion_styles(FusionStyle.(axes)...) end diff --git a/src/splitdims.jl b/src/splitdims.jl index 0554c61..3c71bbf 100644 --- a/src/splitdims.jl +++ b/src/splitdims.jl @@ -49,7 +49,7 @@ end function splitdims( a::AbstractArray, axes_dest::Tuple{Vararg{AbstractUnitRange}}, - blockedperm::BlockedPermutation, + blockedperm::AbstractBlockPermutation, ) # TODO: Pass grouped axes. a_dest_perm = splitdims(a, axes_dest...) @@ -58,9 +58,9 @@ function splitdims( end function splitdims!( - a_dest::AbstractArray, a::AbstractArray, blockedperm::BlockedPermutation + a_dest::AbstractArray, a::AbstractArray, blockedperm::AbstractBlockPermutation ) - axes_dest = map(i -> axes(a_dest, i), Tuple(blockedperm)) + axes_dest = map(i -> axes(a_dest, i), blockedperm) # TODO: Pass grouped axes. a_dest_perm = splitdims(a, axes_dest...) _permutedims!(a_dest, a_dest_perm, invperm(Tuple(blockedperm))) From b47c0a354756427ba70ed0cdadb9267580098f2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 20 Feb 2025 17:54:42 -0500 Subject: [PATCH 2/5] fusedims with BlockedPermutation --- src/blockedpermutation.jl | 4 ++-- src/fusedims.jl | 22 ++++++++++++---------- test/test_basics.jl | 25 ++++++++++++++++++++++++- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index 663f480..2340b26 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -74,8 +74,8 @@ function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int,Ellipsis}...; kwar return blockedperm(collect_tuple.(permblocks)...; kwargs...) end -function blockedperm(bt::AbstractBlockTuple) - return blockedperm(Val(length(bt)), blocks(bt)...) +function blockedperm(bt::AbstractBlockTuple; length::Union{Val,Nothing}=nothing) + return blockedperm(Val(Base.length(bt)), blocks(bt)...) end function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}}) diff --git a/src/fusedims.jl b/src/fusedims.jl index aa9a7c0..9882c11 100644 --- a/src/fusedims.jl +++ b/src/fusedims.jl @@ -34,7 +34,6 @@ function fusedims(a::AbstractArray, ax::AbstractUnitRange, axes::AbstractUnitRan return fusedims(FusionStyle(a), a, ax, axes...) end -# Overload this version for fusion tensors, array maps, etc. function fusedims( a::AbstractArray, axb::Tuple{Vararg{AbstractUnitRange}}, @@ -43,14 +42,6 @@ function fusedims( return fusedims(a, flatten_tuples((axb, axesblocks...))...) end -# Fix ambiguity issue -fusedims(a::AbstractArray{<:Any,0}, ::Vararg{Tuple{}}) = a - -# TODO: Is this needed? Maybe delete. -function fusedims(a::AbstractArray, permblocks...) - return fusedims(a, blockedperm(permblocks...; length=Val(ndims(a)))) -end - function fuseaxes( axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation ) @@ -68,7 +59,18 @@ function fusedims(a::AbstractArray, blockedperm::BlockedTrivialPermutation) return fusedims(a, axes_fused) end -function fusedims(a::AbstractArray, blockedperm::BlockedPermutation) +# deal with zero-dim case +fusedims(a::AbstractArray{<:Any,0}, t::Tuple{}...) = reshape(a, ntuple(_ -> 1, length(t))) + +function fusedims(a::AbstractArray, blockedperm::AbstractBlockPermutation) + # TBD define permutedims(::AbstractArray, ::AbstractBlockPermutation) + # TBD remove call to BlockedTrivialPermutation? a_perm = _permutedims(a, Tuple(blockedperm)) return fusedims(a_perm, trivialperm(blockedperm)) end + +# fusedims(ones((2,2,2,2)), (3, 1, 2), (4,)) +# fusedims(ones((2,2,2,2)), (3, 1, 2), 4) +function fusedims(a::AbstractArray, permblocks...) + return fusedims(a, blockedperm(permblocks...; length=Val(ndims(a)))) +end diff --git a/test/test_basics.jl b/test/test_basics.jl index 221274b..a1d4864 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,7 +1,8 @@ using EllipsisNotation: var".." using LinearAlgebra: norm using StableRNGs: StableRNG -using TensorAlgebra: contract, contract!, fusedims, qr, splitdims, svd +using TensorAlgebra: + blockedperm, contract, contract!, fusedims, qr, splitdims, svd, tuplemortar using TensorOperations: TensorOperations using Test: @test, @test_broken, @testset @@ -11,6 +12,16 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra" begin @testset "fusedims (eltype=$elt)" for elt in elts a = randn(elt, 2, 3, 4, 5) + + bt = tuplemortar(((3, 2), (4, 1))) + p = blockedperm(bt) + a_fused = fusedims(a, p) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(permutedims(a, (3, 2, 4, 1)), (12, 10)) + a_fused = fusedims(a, bt) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(permutedims(a, (3, 2, 4, 1)), (12, 10)) + a_fused = fusedims(a, (1, 2), (3, 4)) @test eltype(a_fused) === elt @test a_fused ≈ reshape(a, 6, 20) @@ -35,7 +46,19 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a_fused = fusedims(a, (3, 1), ..) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 3, 5)) + + a = randn(elt, ()) + a_fused = fusedims(a) + @test a_fused isa Array{elt,0} + @test a_fused ≈ a + a_fused = fusedims(a, ()) + @test a_fused isa Array{elt,1} + a_fused = fusedims(a, (), ()) + @test a_fused isa Array{elt,2} + a_fused = fusedims(a, tuplemortar(((), ()))) + @test a_fused isa Array{elt,2} end + @testset "splitdims (eltype=$elt)" for elt in elts a = randn(elt, 6, 20) a_split = splitdims(a, (2, 3), (5, 4)) From ac3065d8f0e7a4fda1625a938d0e44f3f7c8cd62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 20 Feb 2025 19:08:58 -0500 Subject: [PATCH 3/5] splitdims with BlockedTuple interface --- src/splitdims.jl | 49 ++++++++++++++++++++++++--------------------- test/test_basics.jl | 13 ++++++++++++ 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/src/splitdims.jl b/src/splitdims.jl index 3c71bbf..01f90eb 100644 --- a/src/splitdims.jl +++ b/src/splitdims.jl @@ -4,32 +4,41 @@ to_axis(a::AbstractUnitRange) = a to_axis(n::Integer) = Base.OneTo(n) function blockedaxes(a::AbstractArray, sizeblocks::Pair...) - axes_a = axes(a) axes_split = tuple.(axes(a)) for (dim, sizeblock) in sizeblocks # TODO: Handle conversion from length to range! axes_split = Base.setindex(axes_split, to_axis.(sizeblock), dim) end - return axes_split + return tuplemortar(axes_split) end -# splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2) -function splitdims(::ReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...) +function splitdims(::ReshapeFusion, a::AbstractArray, abt::BlockedTuple) # TODO: Add `uncanonicalizedims`. # TODO: Need `length` since `reshape` doesn't accept `axes`, # maybe make a `reshape_axes` function. - return reshape(a, length.(axes)...) + return reshape(a, Tuple(length.(abt))) +end + +# ambiguity for zero-dim +function splitdims(a::AbstractArray{<:Any,N}, abt::BlockedTuple{N,<:Any,Tuple{}}) where {N} + return splitdims(FusionStyle(a), a, abt) +end + +function splitdims( + a::AbstractArray{<:Any,N}, abt::BlockedTuple{N,<:Any,<:Tuple{Vararg{AbstractUnitRange}}} +) where {N} + return splitdims(FusionStyle(a), a, abt) end # splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2) function splitdims(a::AbstractArray, axes::AbstractUnitRange...) - return splitdims(FusionStyle(a), a, axes...) + return splitdims(a, tuple.(axes)...) end # splitdims(randn(4, 4), (1:2, 1:2), (1:2, 1:2)) function splitdims(a::AbstractArray, axesblocks::Tuple{Vararg{AbstractUnitRange}}...) # TODO: Add `uncanonicalizedims`. - return splitdims(a, flatten_tuples(axesblocks)...) + return splitdims(a, tuplemortar(axesblocks)) end # Fix ambiguity issue @@ -37,32 +46,26 @@ splitdims(a::AbstractArray) = a # splitdims(randn(4, 4), (2, 2), (2, 2)) function splitdims(a::AbstractArray, sizeblocks::Tuple{Vararg{Integer}}...) - return splitdims(a, map(x -> Base.OneTo.(x), sizeblocks)...) + return splitdims(a, tuplemortar(sizeblocks)) end -# splitdims(randn(4, 4), 2 => (1:2, 1:2)) -function splitdims(a::AbstractArray, sizeblocks::Pair...) - return splitdims(a, blockedaxes(a, sizeblocks...)...) +# splitdims(randn(4, 4), tuplemortar(((2, 2), (2, 2)))) +function splitdims( + a::AbstractArray{<:Any,N}, bt::BlockedTuple{N,<:Any,<:Tuple{Vararg{Integer}}} +) where {N} + return splitdims(a, to_axis.(bt)) end -# TODO: Is this needed? -function splitdims( - a::AbstractArray, - axes_dest::Tuple{Vararg{AbstractUnitRange}}, - blockedperm::AbstractBlockPermutation, -) - # TODO: Pass grouped axes. - a_dest_perm = splitdims(a, axes_dest...) - a_dest = _permutedims(a_dest_perm, invperm(Tuple(blockedperm))) - return a_dest +# splitdims(randn(4, 4), 2 => (1:2, 1:2)) +function splitdims(a::AbstractArray, sizeblocks::Pair...) + return splitdims(a, blockedaxes(a, sizeblocks...)) end function splitdims!( a_dest::AbstractArray, a::AbstractArray, blockedperm::AbstractBlockPermutation ) axes_dest = map(i -> axes(a_dest, i), blockedperm) - # TODO: Pass grouped axes. - a_dest_perm = splitdims(a, axes_dest...) + a_dest_perm = splitdims(a, axes_dest) _permutedims!(a_dest, a_dest_perm, invperm(Tuple(blockedperm))) return a_dest end diff --git a/test/test_basics.jl b/test/test_basics.jl index a1d4864..2ec6ee3 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -61,6 +61,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "splitdims (eltype=$elt)" for elt in elts a = randn(elt, 6, 20) + + a_split = splitdims(a, tuplemortar(((1:2, 1:3), (1:5, 1:4)))) + @test a_split isa Array{elt,4} + @test a_split ≈ reshape(a, (2, 3, 5, 4)) + + a_split = splitdims(a, tuplemortar(((2, 3), (5, 4)))) + @test a_split isa Array{elt,4} + @test a_split ≈ reshape(a, (2, 3, 5, 4)) + a_split = splitdims(a, (2, 3), (5, 4)) @test eltype(a_split) === elt @test a_split ≈ reshape(a, (2, 3, 5, 4)) @@ -85,6 +94,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a_split = splitdims(a, 1 => (1:2, 1:3)) @test eltype(a_split) === elt @test a_split ≈ reshape(a, (2, 3, 20)) + + a_split = splitdims(a) + @test a_split isa Array{elt,2} + @test a_split ≈ a end using TensorOperations: TensorOperations @testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts From 8b3ebbd4be10e4bb203b4edb89e0eaef35dcb9c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 20 Feb 2025 19:28:34 -0500 Subject: [PATCH 4/5] write factorizations with AbstractBlockedPermutation --- src/factorizations.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/factorizations.jl b/src/factorizations.jl index a017ca1..1ebf5e0 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -1,15 +1,15 @@ using ArrayLayouts: LayoutMatrix using LinearAlgebra: LinearAlgebra, Diagonal -function qr(a::AbstractArray, biperm::BlockedPermutation{2}) +function qr(a::AbstractArray, biperm::AbstractBlockPermutation{2}) a_matricized = fusedims(a, biperm) # TODO: Make this more generic, allow choosing thin or full, # make sure this works on GPU. q_fact, r_matricized = LinearAlgebra.qr(a_matricized) q_matricized = typeof(a_matricized)(q_fact) axes_codomain, axes_domain = blockpermute(axes(a), biperm) - axes_q = (axes_codomain..., axes(q_matricized, 2)) - axes_r = (axes(r_matricized, 1), axes_domain...) + axes_q = tuplemortar((axes_codomain, (axes(q_matricized, 2),))) + axes_r = tuplemortar(((axes(r_matricized, 1),), axes_domain)) q = splitdims(q_matricized, axes_q) r = splitdims(r_matricized, axes_r) return q, r @@ -22,15 +22,15 @@ function qr(a::AbstractArray, labels_a, labels_codomain, labels_domain) ) end -function svd(a::AbstractArray, biperm::BlockedPermutation{2}) +function svd(a::AbstractArray, biperm::AbstractBlockPermutation{2}) a_matricized = fusedims(a, biperm) usv_matricized = LinearAlgebra.svd(a_matricized) u_matricized = usv_matricized.U s_diag = usv_matricized.S v_matricized = usv_matricized.Vt axes_codomain, axes_domain = blockpermute(axes(a), biperm) - axes_u = (axes_codomain..., axes(u_matricized, 2)) - axes_v = (axes(v_matricized, 1), axes_domain...) + axes_u = tuplemortar((axes_codomain, (axes(u_matricized, 2),))) + axes_v = tuplemortar(((axes(v_matricized, 1),), axes_domain)) u = splitdims(u_matricized, axes_u) # TODO: Use `DiagonalArrays.diagonal` to make it more general. s = Diagonal(s_diag) From a8d078ba7c871fc8ea6c31365a20e93465acc4a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 20 Feb 2025 20:05:32 -0500 Subject: [PATCH 5/5] cleaning --- src/splitdims.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/splitdims.jl b/src/splitdims.jl index 01f90eb..78684b4 100644 --- a/src/splitdims.jl +++ b/src/splitdims.jl @@ -25,9 +25,9 @@ function splitdims(a::AbstractArray{<:Any,N}, abt::BlockedTuple{N,<:Any,Tuple{}} end function splitdims( - a::AbstractArray{<:Any,N}, abt::BlockedTuple{N,<:Any,<:Tuple{Vararg{AbstractUnitRange}}} + a::AbstractArray{<:Any,N}, bt::BlockedTuple{N,<:Any,<:Tuple{Vararg{AbstractUnitRange}}} ) where {N} - return splitdims(FusionStyle(a), a, abt) + return splitdims(FusionStyle(a), a, bt) end # splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2)