diff --git a/Project.toml b/Project.toml index 56df2a7..c0e814d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.1.2" +version = "0.1.3" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index 1dd261a..bc5771a 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -2,9 +2,9 @@ module TensorAlgebra export contract, contract! +include("blockedtuple.jl") include("blockedpermutation.jl") include("BaseExtensions/BaseExtensions.jl") -include("blockedtuple.jl") include("fusedims.jl") include("splitdims.jl") include("contract/contract.jl") diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index c13b68a..380eccc 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -12,82 +12,32 @@ end _flatten_tuples() = () flatten_tuples(ts::Tuple) = _flatten_tuples(ts...) -_blocklength(blocklengths::Tuple{Vararg{Int}}) = length(blocklengths) -function _blockfirsts(blocklengths::Tuple{Vararg{Int}}) - return ntuple(_blocklength(blocklengths)) do i - prev_blocklast = - isone(i) ? zero(eltype(blocklengths)) : _blocklasts(blocklengths)[i - 1] - return prev_blocklast + 1 - end -end -_blocklasts(blocklengths::Tuple{Vararg{Int}}) = cumsum(blocklengths) - collect_tuple(x) = (x,) collect_tuple(x::Ellipsis) = x collect_tuple(t::Tuple) = t -const TupleOfTuples{N} = Tuple{Vararg{Tuple{Vararg{Int}},N}} - -abstract type AbstractBlockedPermutation{BlockLength,Length} end - -BlockArrays.blocks(blockedperm::AbstractBlockedPermutation) = error("Not implemented") - -function Base.Tuple(blockedperm::AbstractBlockedPermutation) - return flatten_tuples(blocks(blockedperm)) -end - -function BlockArrays.blocklengths(blockedperm::AbstractBlockedPermutation) - return length.(blocks(blockedperm)) -end - -function BlockArrays.blockfirsts(blockedperm::AbstractBlockedPermutation) - return _blockfirsts(blocklengths(blockedperm)) -end - -function BlockArrays.blocklasts(blockedperm::AbstractBlockedPermutation) - return _blocklasts(blocklengths(blockedperm)) -end +# +# =============================== AbstractBlockPermutation =============================== +# +abstract type AbstractBlockPermutation{BlockLength} <: AbstractBlockTuple{BlockLength} end -Base.iterate(permblocks::AbstractBlockedPermutation) = iterate(Tuple(permblocks)) -function Base.iterate(permblocks::AbstractBlockedPermutation, state) - return iterate(Tuple(permblocks), state) -end +widened_constructorof(::Type{<:AbstractBlockPermutation}) = BlockedTuple # Block a permutation based on the specified lengths. # blockperm((4, 3, 2, 1), (2, 2)) == blockedperm((4, 3), (2, 1)) # TODO: Optimize with StaticNumbers.jl or generated functions, see: # https://discourse.julialang.org/t/avoiding-type-instability-when-slicing-a-tuple/38567 function blockperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}}) - starts = _blockfirsts(blocklengths) - stops = _blocklasts(blocklengths) - return blockedperm(ntuple(i -> perm[starts[i]:stops[i]], length(blocklengths))...) -end - -function Base.invperm(blockedperm::AbstractBlockedPermutation) - return blockperm(invperm(Tuple(blockedperm)), blocklengths(blockedperm)) + return blockedperm(BlockedTuple(perm, blocklengths)) end -Base.length(blockedperm::AbstractBlockedPermutation) = length(Tuple(blockedperm)) -function BlockArrays.blocklength(blockedperm::AbstractBlockedPermutation) - return length(blocks(blockedperm)) +function blockperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val) + return blockedperm(BlockedTuple(perm, BlockLengths)) end -function Base.getindex(blockedperm::AbstractBlockedPermutation, i::Int) - return Tuple(blockedperm)[i] -end - -function Base.getindex(blockedperm::AbstractBlockedPermutation, I::AbstractUnitRange) - perm = Tuple(blockedperm) - return [perm[i] for i in I] -end - -function Base.getindex(blockedperm::AbstractBlockedPermutation, b::Block) - return blocks(blockedperm)[Int(b)] -end - -# Like `BlockRange`. -function blockeachindex(blockedperm::AbstractBlockedPermutation) - return ntuple(i -> Block(i), blocklength(blockedperm)) +function Base.invperm(blockedperm::AbstractBlockPermutation) + # use Val to preserve compile time info + return blockperm(invperm(Tuple(blockedperm)), Val(blocklengths(blockedperm))) end # @@ -97,7 +47,7 @@ end # Bipartition a vector according to the # bipartitioned permutation. # Like `Base.permute!` block out-of-place and blocked. -function blockpermute(v, blockedperm::AbstractBlockedPermutation) +function blockpermute(v, blockedperm::AbstractBlockPermutation) return map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm)) end @@ -106,8 +56,8 @@ function blockedperm(permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothin return blockedperm(length, permblocks...) end -function blockedperm(length::Nothing, permblocks::Tuple{Vararg{Int}}...) - return blockedperm(Val(sum(Base.length, permblocks; init=zero(Bool))), permblocks...) +function blockedperm(::Nothing, permblocks::Tuple{Vararg{Int}}...) + return blockedperm(Val(sum(length, permblocks; init=zero(Bool))), permblocks...) end # blockedperm((3, 2), 1) == blockedperm((3, 2), (1,)) @@ -119,11 +69,15 @@ function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int,Ellipsis}...; kwar return blockedperm(collect_tuple.(permblocks)...; kwargs...) end +function blockedperm(bt::AbstractBlockTuple) + return blockedperm(Val(length(bt)), blocks(bt)...) +end + function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}}) return maximum(specified_perm) end -function _blockedperm_length(vallength::Val, specified_perm::Tuple{Vararg{Int}}) +function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}}) return value(vallength) end @@ -148,45 +102,69 @@ function blockedperm_indexin(collection, subs...) return blockedperm(map(sub -> BaseExtensions.indexin(sub, collection), subs)...) end -struct BlockedPermutation{BlockLength,Length,Blocks<:TupleOfTuples{BlockLength}} <: - AbstractBlockedPermutation{BlockLength,Length} - blocks::Blocks - global function _BlockedPermutation(blocks::TupleOfTuples) - len = sum(length, blocks; init=zero(Bool)) - blocklength = length(blocks) - return new{blocklength,len,typeof(blocks)}(blocks) +# +# ================================== BlockedPermutation ================================== +# + +# for dispatch reason, it is convenient to have BlockLength as the first parameter +struct BlockedPermutation{BlockLength,BlockLengths,Flat} <: + AbstractBlockPermutation{BlockLength} + flat::Flat + + function BlockedPermutation{BlockLength,BlockLengths}( + flat::Tuple + ) where {BlockLength,BlockLengths} + length(flat) != sum(BlockLengths; init=0) && + throw(DimensionMismatch("Invalid total length")) + length(BlockLengths) != BlockLength && + throw(DimensionMismatch("Invalid total blocklength")) + any(BlockLengths .< 0) && throw(DimensionMismatch("Invalid block length")) + return new{BlockLength,BlockLengths,typeof(flat)}(flat) end end -BlockArrays.blocks(blockedperm::BlockedPermutation) = getfield(blockedperm, :blocks) +# Base interface +Base.Tuple(blockedperm::BlockedPermutation) = getfield(blockedperm, :flat) -function blockedperm(length::Val, permblocks::Tuple{Vararg{Int}}...) - @assert value(length) == sum(Base.length, permblocks; init=zero(Bool)) - blockedperm = _BlockedPermutation(permblocks) +# BlockArrays interface +function BlockArrays.blocklengths( + ::Type{<:BlockedPermutation{<:Any,BlockLengths}} +) where {BlockLengths} + return BlockLengths +end + +function blockedperm(::Val, permblocks::Tuple{Vararg{Int}}...) + blockedperm = BlockedPermutation{length(permblocks),length.(permblocks)}( + flatten_tuples(permblocks) + ) @assert isperm(blockedperm) return blockedperm end +# +# ============================== BlockedTrivialPermutation =============================== +# trivialperm(length::Union{Integer,Val}) = ntuple(identity, length) -struct BlockedTrivialPermutation{BlockLength,Length,Blocks<:TupleOfTuples{BlockLength}} <: - AbstractBlockedPermutation{BlockLength,Length} - blocks::Blocks - global function _BlockedTrivialPermutation(blocklengths::Tuple{Vararg{Int}}) - len = sum(blocklengths; init=zero(Bool)) - blocklength = length(blocklengths) - permblocks = blocks(blockperm(trivialperm(len), blocklengths)) - return new{blocklength,len,typeof(permblocks)}(permblocks) - end +struct BlockedTrivialPermutation{BlockLength,BlockLengths} <: + AbstractBlockPermutation{BlockLength} end + +Base.Tuple(blockedperm::BlockedTrivialPermutation) = trivialperm(length(blockedperm)) + +# BlockArrays interface +function BlockArrays.blocklengths( + ::Type{<:BlockedTrivialPermutation{<:Any,BlockLengths}} +) where {BlockLengths} + return BlockLengths end -BlockArrays.blocks(blockedperm::BlockedTrivialPermutation) = getfield(blockedperm, :blocks) +blockedperm(tp::BlockedTrivialPermutation) = tp function blockedtrivialperm(blocklengths::Tuple{Vararg{Int}}) - return _BlockedTrivialPermutation(blocklengths) + return BlockedTrivialPermutation{length(blocklengths),blocklengths}() end -function trivialperm(blockedperm::AbstractBlockedPermutation) +function trivialperm(blockedperm::AbstractBlockTuple) return blockedtrivialperm(blocklengths(blockedperm)) end Base.invperm(blockedperm::BlockedTrivialPermutation) = blockedperm diff --git a/src/blockedtuple.jl b/src/blockedtuple.jl index 966c76c..40db9ee 100644 --- a/src/blockedtuple.jl +++ b/src/blockedtuple.jl @@ -1,5 +1,6 @@ -# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl -# like interface +# This file defines an abstract type AbstractBlockTuple and a concrete type BlockedTuple. +# These types allow to store a Tuple of heterogeneous Tuples with a BlockArrays.jl like +# interface. using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange @@ -8,7 +9,17 @@ using TypeParameterAccessors: unspecify_type_parameters # # ================================== AbstractBlockTuple ================================== # -abstract type AbstractBlockTuple end +# AbstractBlockTuple imposes BlockLength as first type parameter for easy dispatch +# it makes no assumotion on storage type +abstract type AbstractBlockTuple{BlockLength} end + +constructorof(type::Type{<:AbstractBlockTuple}) = unspecify_type_parameters(type) +widened_constructorof(type::Type{<:AbstractBlockTuple}) = constructorof(type) + +# Like `BlockRange`. +function blockeachindex(bt::AbstractBlockTuple) + return ntuple(i -> Block(i), blocklength(bt)) +end # Base interface Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),) @@ -22,9 +33,8 @@ Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r] Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)] function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1}) r = Int.(br) - T = unspecify_type_parameters(typeof(bt)) flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]] - return T{blocklengths(bt)[r]}(flat) + return widened_constructorof(typeof(bt))(flat, blocklengths(bt)[r]) end function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1}) return bt[Block(bi)][only(bi.indices)] @@ -33,12 +43,14 @@ end Base.iterate(bt::AbstractBlockTuple) = iterate(Tuple(bt)) Base.iterate(bt::AbstractBlockTuple, i::Int) = iterate(Tuple(bt), i) -Base.length(bt::AbstractBlockTuple) = length(Tuple(bt)) - Base.lastindex(bt::AbstractBlockTuple) = length(bt) +Base.length(bt::AbstractBlockTuple) = sum(blocklengths(bt); init=0) + function Base.map(f, bt::AbstractBlockTuple) - return unspecify_type_parameters(typeof(bt)){blocklengths(bt)}(map(f, Tuple(bt))) + BL = blocklengths(bt) + # use Val to preserve compile time knowledge of BL + return widened_constructorof(typeof(bt))(map(f, Tuple(bt)), Val(BL)) end # Broadcast interface @@ -57,19 +69,20 @@ end function Base.copy( bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}} ) where {BlockLengths,BT} - return BT{BlockLengths}(bc.f.((Tuple.(bc.args))...)) + return widened_constructorof(BT)(bc.f.((Tuple.(bc.args))...), Val(BlockLengths)) end # BlockArrays interface +BlockArrays.blockfirsts(::AbstractBlockTuple{0}) = () function BlockArrays.blockfirsts(bt::AbstractBlockTuple) return (0, cumsum(Base.front(blocklengths(bt)))...) .+ 1 end function BlockArrays.blocklasts(bt::AbstractBlockTuple) - return cumsum(blocklengths(bt)[begin:end]) + return cumsum(blocklengths(bt)) end -BlockArrays.blocklength(bt::AbstractBlockTuple) = length(blocklengths(bt)) +BlockArrays.blocklength(::AbstractBlockTuple{BlockLength}) where {BlockLength} = BlockLength BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt)) @@ -79,29 +92,46 @@ function BlockArrays.blocks(bt::AbstractBlockTuple) return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt)) end -# +# length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength")) + # ===================================== BlockedTuple ===================================== # -struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple +struct BlockedTuple{BlockLength,BlockLengths,Flat} <: AbstractBlockTuple{BlockLength} flat::Flat - function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths} - length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length")) - return new{BlockLengths,typeof(flat)}(flat) + function BlockedTuple{BlockLength,BlockLengths}( + flat::Tuple + ) where {BlockLength,BlockLengths} + length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength")) + length(flat) != sum(BlockLengths; init=0) && + throw(DimensionMismatch("Invalid total length")) + any(BlockLengths .< 0) && throw(DimensionMismatch("Invalid block length")) + return new{BlockLength,BlockLengths,typeof(flat)}(flat) end end # TensorAlgebra Interface -tuplemortar(tt::Tuple{Vararg{Tuple}}) = BlockedTuple{length.(tt)}(flatten_tuples(tt)) +function tuplemortar(tt::Tuple{Vararg{Tuple}}) + return BlockedTuple{length(tt),length.(tt)}(flatten_tuples(tt)) +end function BlockedTuple(flat::Tuple, BlockLengths::Tuple{Vararg{Int}}) - return BlockedTuple{BlockLengths}(flat) + return BlockedTuple{length(BlockLengths),BlockLengths}(flat) +end +function BlockedTuple(flat::Tuple, ::Val{BlockLengths}) where {BlockLengths} + # use Val to preserve compile time knowledge of BL + return BlockedTuple{length(BlockLengths),BlockLengths}(flat) +end +function BlockedTuple(bt::AbstractBlockTuple) + bl = blocklengths(bt) + return BlockedTuple{length(bl),bl}(Tuple(bt)) end -BlockedTuple(bt::AbstractBlockTuple) = BlockedTuple{blocklengths(bt)}(Tuple(bt)) # Base interface Base.Tuple(bt::BlockedTuple) = bt.flat # BlockArrays interface -function BlockArrays.blocklengths(::Type{<:BlockedTuple{BlockLengths}}) where {BlockLengths} +function BlockArrays.blocklengths( + ::Type{<:BlockedTuple{<:Any,BlockLengths}} +) where {BlockLengths} return BlockLengths end diff --git a/src/fusedims.jl b/src/fusedims.jl index 35a2823..2e87346 100644 --- a/src/fusedims.jl +++ b/src/fusedims.jl @@ -51,13 +51,13 @@ function fusedims(a::AbstractArray, permblocks...) end function fuseaxes( - axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockedPermutation + axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation ) axesblocks = blockpermute(axes, blockedperm) return map(block -> ⊗(block...), axesblocks) end -function fuseaxes(a::AbstractArray, blockedperm::AbstractBlockedPermutation) +function fuseaxes(a::AbstractArray, blockedperm::AbstractBlockPermutation) return fuseaxes(axes(a), blockedperm) end diff --git a/test/Project.toml b/test/Project.toml index c13f5f9..60a147b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" diff --git a/test/test_basics.jl b/test/test_basics.jl index 35b35e0..5a5bda2 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,68 +1,10 @@ -using BlockArrays: blockfirsts, blocklasts, blocklength, blocklengths, blocks -using Combinatorics: permutations using EllipsisNotation: var".." using LinearAlgebra: norm, qr -using TensorAlgebra: TensorAlgebra, blockedperm, blockedperm_indexin, fusedims, splitdims +using TensorAlgebra: TensorAlgebra, fusedims, splitdims default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) using Test: @test, @test_broken, @testset const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) -@testset "BlockedPermutation" begin - p = blockedperm((3, 4, 5), (2, 1)) - @test Tuple(p) === (3, 4, 5, 2, 1) - @test isperm(p) - @test length(p) == 5 - @test blocks(p) == ((3, 4, 5), (2, 1)) - @test blocklength(p) == 2 - @test blocklengths(p) == (3, 2) - @test blockfirsts(p) == (1, 4) - @test blocklasts(p) == (3, 5) - @test invperm(p) == blockedperm((5, 4, 1), (2, 3)) - # Empty block. - p = blockedperm((3, 2), (), (1,)) - @test Tuple(p) === (3, 2, 1) - @test isperm(p) - @test length(p) == 3 - @test blocks(p) == ((3, 2), (), (1,)) - @test blocklength(p) == 3 - @test blocklengths(p) == (2, 0, 1) - @test blockfirsts(p) == (1, 3, 3) - @test blocklasts(p) == (2, 2, 3) - @test invperm(p) == blockedperm((3, 2), (), (1,)) - - # Split collection into `BlockedPermutation`. - p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d")) - @test p == blockedperm((3, 1), (2, 4)) - - # Singleton dimensions. - p = blockedperm((2, 3), 1) - @test p == blockedperm((2, 3), (1,)) - - # First dimensions are unspecified. - p = blockedperm(.., (4, 3)) - @test p == blockedperm(1, 2, (4, 3)) - # Specify length - p = blockedperm(.., (4, 3); length=Val(6)) - @test p == blockedperm(1, 2, 5, 6, (4, 3)) - - # Last dimensions are unspecified. - p = blockedperm((4, 3), ..) - @test p == blockedperm((4, 3), 1, 2) - # Specify length - p = blockedperm((4, 3), ..; length=Val(6)) - @test p == blockedperm((4, 3), 1, 2, 5, 6) - - # Middle dimensions are unspecified. - p = blockedperm((4, 3), .., 1) - @test p == blockedperm((4, 3), 2, 1) - # Specify length - p = blockedperm((4, 3), .., 1; length=Val(6)) - @test p == blockedperm((4, 3), 2, 5, 6, 1) - - # No dimensions are unspecified. - p = blockedperm((3, 2), .., 1) - @test p == blockedperm((3, 2), 1) -end @testset "TensorAlgebra" begin @testset "fusedims (eltype=$elt)" for elt in elts a = randn(elt, 2, 3, 4, 5) diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl new file mode 100644 index 0000000..37f8e6a --- /dev/null +++ b/test/test_blockedpermutation.jl @@ -0,0 +1,124 @@ +using Test: @test, @test_broken, @testset + +using BlockArrays: blockfirsts, blocklasts, blocklength, blocklengths, blocks +using EllipsisNotation: var".." +using TestExtras: @constinferred + +using TensorAlgebra: + BlockedPermutation, + BlockedTrivialPermutation, + BlockedTuple, + blockedperm, + blockedperm_indexin, + blockedtrivialperm, + trivialperm + +@testset "BlockedPermutation" begin + p = @constinferred blockedperm((3, 4, 5), (2, 1)) + @test Tuple(p) === (3, 4, 5, 2, 1) + @test isperm(p) + @test length(p) == 5 + @test blocks(p) == ((3, 4, 5), (2, 1)) + @test blocklength(p) == 2 + @test blocklengths(p) == (3, 2) + @test blockfirsts(p) == (1, 4) + @test blocklasts(p) == (3, 5) + @test (@constinferred invperm(p)) == blockedperm((5, 4, 1), (2, 3)) + @test p isa BlockedPermutation{2} + + flat = (3, 4, 5, 2, 1) + @test_throws DimensionMismatch BlockedPermutation{2,(1, 2, 2)}(flat) + @test_throws DimensionMismatch BlockedPermutation{3,(1, 2, 3)}(flat) + @test_throws DimensionMismatch BlockedPermutation{3,(-1, 3, 3)}(flat) + + # Empty block. + p = @constinferred blockedperm((3, 2), (), (1,)) + @test Tuple(p) === (3, 2, 1) + @test isperm(p) + @test length(p) == 3 + @test blocks(p) == ((3, 2), (), (1,)) + @test blocklength(p) == 3 + @test blocklengths(p) == (2, 0, 1) + @test blockfirsts(p) == (1, 3, 3) + @test blocklasts(p) == (2, 2, 3) + @test invperm(p) == blockedperm((3, 2), (), (1,)) + @test p isa BlockedPermutation{3} + + p = @constinferred blockedperm((), ()) + @test Tuple(p) === () + @test blocklength(p) == 2 + @test blocklengths(p) == (0, 0) + @test isperm(p) + @test length(p) == 0 + @test blocks(p) == ((), ()) + @test p isa BlockedPermutation{2} + + p = @constinferred blockedperm() + @test Tuple(p) === () + @test blocklength(p) == 0 + @test blocklengths(p) == () + @test isperm(p) + @test length(p) == 0 + @test blocks(p) == () + @test p isa BlockedPermutation{0} + + p = blockedperm((3, 2), (), (1,)) + bt = BlockedTuple{3,(2, 0, 1)}((3, 2, 1)) + @test (@constinferred BlockedTuple(p)) == bt + @test (@constinferred map(identity, p)) == bt + @test (@constinferred p .+ p) == BlockedTuple{3,(2, 0, 1)}((6, 4, 2)) + @test (@constinferred blockedperm(p)) == p + @test (@constinferred blockedperm(bt)) == p + + # Split collection into `BlockedPermutation`. + p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d")) + @test p == blockedperm((3, 1), (2, 4)) + + # Singleton dimensions. + p = @constinferred blockedperm((2, 3), 1) + @test p == blockedperm((2, 3), (1,)) + + # First dimensions are unspecified. + p = blockedperm(.., (4, 3)) + @test p == blockedperm(1, 2, (4, 3)) + # Specify length + p = blockedperm(.., (4, 3); length=Val(6)) + @test p == blockedperm(1, 2, 5, 6, (4, 3)) + + # Last dimensions are unspecified. + p = blockedperm((4, 3), ..) + @test p == blockedperm((4, 3), 1, 2) + # Specify length + p = blockedperm((4, 3), ..; length=Val(6)) + @test p == blockedperm((4, 3), 1, 2, 5, 6) + + # Middle dimensions are unspecified. + p = blockedperm((4, 3), .., 1) + @test p == blockedperm((4, 3), 2, 1) + # Specify length + p = blockedperm((4, 3), .., 1; length=Val(6)) + @test p == blockedperm((4, 3), 2, 5, 6, 1) + + # No dimensions are unspecified. + p = blockedperm((3, 2), .., 1) + @test p == blockedperm((3, 2), 1) +end + +@testset "BlockedTrivialPermutation" begin + tp = blockedtrivialperm((2, 0, 1)) + + @test tp isa BlockedTrivialPermutation{3} + @test Tuple(tp) == (1, 2, 3) + @test blocklength(tp) == 3 + @test blocklengths(tp) == (2, 0, 1) + @test trivialperm(blockedperm((3, 2), (), (1,))) == tp + + bt = BlockedTuple{3,(2, 0, 1)}((1, 2, 3)) + @test (@constinferred BlockedTuple(tp)) == bt + @test (@constinferred blocks(tp)) == blocks(bt) + @test (@constinferred map(identity, tp)) == bt + @test (@constinferred tp .+ tp) == BlockedTuple{3,(2, 0, 1)}((2, 4, 6)) + @test (@constinferred blockedperm(tp)) == tp + @test (@constinferred trivialperm(tp)) == tp + @test (@constinferred trivialperm(bt)) == tp +end diff --git a/test/test_blockedtuple.jl b/test/test_blockedtuple.jl index fd599ce..6083097 100644 --- a/test/test_blockedtuple.jl +++ b/test/test_blockedtuple.jl @@ -3,24 +3,27 @@ using Test: @test, @test_throws using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks using TestExtras: @constinferred -using TensorAlgebra: BlockedTuple, tuplemortar +using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar @testset "BlockedTuple" begin flat = (true, 'a', 2, "b", 3.0) divs = (1, 2, 2) - bt = BlockedTuple{divs}(flat) + bt = @constinferred BlockedTuple{3,divs}(flat) + @test bt isa BlockedTuple{3} + @test (@constinferred blockeachindex(bt)) == (Block(1), Block(2), Block(3)) @test (@constinferred Tuple(bt)) == flat - @test bt == tuplemortar(((true,), ('a', 2), ("b", 3.0))) - @test bt == BlockedTuple(flat, divs) - @test BlockedTuple(bt) == bt + @test (@constinferred tuplemortar(((true,), ('a', 2), ("b", 3.0)))) == bt + @test BlockedTuple(flat, divs) == bt + @test (@constinferred BlockedTuple(bt)) == bt @test blocklength(bt) == 3 @test blocklengths(bt) == (1, 2, 2) @test (@constinferred blocks(bt)) == ((true,), ('a', 2), ("b", 3.0)) @test (@constinferred bt[1]) == true @test (@constinferred bt[2]) == 'a' + @test (@constinferred map(identity, bt)) == bt # it is hard to make bt[Block(1)] type stable as compile-time knowledge of 1 is lost in Block @test bt[Block(1)] == blocks(bt)[1] @@ -37,19 +40,41 @@ using TensorAlgebra: BlockedTuple, tuplemortar @test iterate(bt, 2) == ('a', 3) @test blockisequal(only(axes(bt)), blockedrange([1, 2, 2])) - @test_throws DimensionMismatch BlockedTuple{(1, 2, 3)}(flat) + @test_throws DimensionMismatch BlockedTuple{2,(1, 2, 2)}(flat) + @test_throws DimensionMismatch BlockedTuple{3,(1, 2, 3)}(flat) + @test_throws DimensionMismatch BlockedTuple{3,(-1, 3, 3)}(flat) bt = tuplemortar(((1,), (4, 2), (5, 3))) + @test bt isa BlockedTuple @test Tuple(bt) == (1, 4, 2, 5, 3) @test blocklengths(bt) == (1, 2, 2) - @test deepcopy(bt) == bt + @test (@constinferred deepcopy(bt)) == bt @test (@constinferred map(n -> n + 1, bt)) == - BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1) - @test bt .+ tuplemortar(((1,), (1, 1), (1, 1))) == - BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1) + BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1) + @test (@constinferred bt .+ tuplemortar(((1,), (1, 1), (1, 1)))) == + BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1) @test_throws DimensionMismatch bt .+ tuplemortar(((1, 1), (1, 1), (1,))) bt = tuplemortar(((1:2, 1:2), (1:3,))) @test length.(bt) == tuplemortar(((2, 2), (3,))) + + # empty blocks + bt = tuplemortar(((1,), (), (5, 3))) + @test bt isa BlockedTuple{3} + @test Tuple(bt) == (1, 5, 3) + @test blocklengths(bt) == (1, 0, 2) + @test (@constinferred blocks(bt)) == ((1,), (), (5, 3)) + + bt = tuplemortar(((), ())) + @test bt isa BlockedTuple{2} + @test Tuple(bt) == () + @test blocklengths(bt) == (0, 0) + @test (@constinferred blocks(bt)) == ((), ()) + + bt = tuplemortar(()) + @test bt isa BlockedTuple{0} + @test Tuple(bt) == () + @test blocklengths(bt) == () + @test (@constinferred blocks(bt)) == () end