diff --git a/Project.toml b/Project.toml index 5d3d28f..bbd6420 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.2.7" +version = "0.2.8" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/blockedtuple.jl b/src/blockedtuple.jl index 5cf0c9a..30db0bd 100644 --- a/src/blockedtuple.jl +++ b/src/blockedtuple.jl @@ -4,13 +4,13 @@ using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange -using TypeParameterAccessors: unspecify_type_parameters +using TypeParameterAccessors: type_parameters, unspecify_type_parameters # # ================================== AbstractBlockTuple ================================== # # AbstractBlockTuple imposes BlockLength as first type parameter for easy dispatch -# it makes no assumotion on storage type +# it makes no assumption on storage type abstract type AbstractBlockTuple{BlockLength} end constructorof(type::Type{<:AbstractBlockTuple}) = unspecify_type_parameters(type) @@ -23,6 +23,7 @@ end # Base interface Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),) +Base.axes(::AbstractBlockTuple{0}) = (blockedrange(Int[]),) Base.deepcopy(bt::AbstractBlockTuple) = deepcopy.(bt) @@ -70,12 +71,62 @@ function Base.BroadcastStyle(T::Type{<:AbstractBlockTuple}) return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}() end -# BroadcastStyle is not called for two identical styles +# default +combine_types(::Type{<:AbstractBlockTuple}, ::Type{<:AbstractBlockTuple}) = BlockedTuple + +# BroadcastStyle(::Style1, ::Style2) is not called when Style1 == Style2 +# tuplemortar(((1,), (2,))) .== tuplemortar(((1,), (2,))) = tuplemortar(((true,), (true,))) +# tuplemortar(((1,), (2,))) .== tuplemortar(((1, 2),)) = tuplemortar(((true,), (true,))) +# tuplemortar(((1,), (2,))) .== tuplemortar(((1,), (2,), (3,))) = error DimensionMismatch function Base.BroadcastStyle( - ::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle + s1::AbstractBlockTupleBroadcastStyle, s2::AbstractBlockTupleBroadcastStyle ) - throw(DimensionMismatch("Incompatible blocks")) + blocklengths1 = type_parameters(s1, 1) + blocklengths2 = type_parameters(s2, 1) + sum(blocklengths1; init=0) != sum(blocklengths2; init=0) && + throw(DimensionMismatch("blocked tuples could not be broadcast to a common size")) + new_blocklasts = static_mergesort(cumsum(blocklengths1), cumsum(blocklengths2)) + new_blocklengths = ( + first(new_blocklasts), Base.tail(new_blocklasts) .- Base.front(new_blocklasts)... + ) + BT = combine_types(type_parameters(s1, 2), type_parameters(s2, 2)) + return AbstractBlockTupleBroadcastStyle{new_blocklengths,BT}() +end + +static_mergesort(::Tuple{}, ::Tuple{}) = () +static_mergesort(a::Tuple, ::Tuple{}) = a +static_mergesort(::Tuple{}, b::Tuple) = b +function static_mergesort(a::Tuple, b::Tuple) + if first(a) == first(b) + return (first(a), static_mergesort(Base.tail(a), Base.tail(b))...) + end + if first(a) < first(b) + return (first(a), static_mergesort(Base.tail(a), b)...) + end + return (first(b), static_mergesort(a, Base.tail(b))...) end + +# tuplemortar(((1,), (2,))) .== (1, 2) = (true, true) +function Base.BroadcastStyle( + s::AbstractBlockTupleBroadcastStyle, ::Base.Broadcast.Style{Tuple} +) + return s +end + +# tuplemortar(((1,), (2,))) .== 1 = (true, false) +function Base.BroadcastStyle( + ::Base.Broadcast.DefaultArrayStyle{0}, s::AbstractBlockTupleBroadcastStyle +) + return s +end + +# tuplemortar(((1,), (2,))) .== [1, 1] = BlockVector([true, false], [1, 1]) +function Base.BroadcastStyle( + a::Base.Broadcast.AbstractArrayStyle, ::AbstractBlockTupleBroadcastStyle +) + return a +end + function Base.copy( bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}} ) where {BlockLengths,BT} @@ -104,8 +155,6 @@ function BlockArrays.blocks(bt::AbstractBlockTuple) return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt)) end -# length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength")) - # ===================================== BlockedTuple ===================================== # struct BlockedTuple{BlockLength,BlockLengths,Flat} <: AbstractBlockTuple{BlockLength} diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index e411540..e9198a9 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -77,6 +77,8 @@ using TensorAlgebra: @test (@constinferred BlockedTuple(p)) == bt @test (@constinferred map(identity, p)) == bt @test (@constinferred p .+ p) == tuplemortar(((6, 4), (), (2,))) + @test (@constinferred p .+ bt) == tuplemortar(((6, 4), (), (2,))) + @test (@constinferred bt .+ p) == tuplemortar(((6, 4), (), (2,))) @test (@constinferred blockedperm(p)) == p @test (@constinferred blockedperm(bt)) == p @@ -149,6 +151,8 @@ end @test (@constinferred blocks(tp)) == blocks(bt) @test (@constinferred map(identity, tp)) == bt @test (@constinferred tp .+ tp) == tuplemortar(((2, 4), (), (6,))) + @test (@constinferred tp .+ Tuple(tp)) == tuplemortar(((2, 4), (), (6,))) + @test (@constinferred tp .+ BlockedTuple(tp)) == tuplemortar(((2, 4), (), (6,))) @test (@constinferred blockedperm(tp)) == tp @test (@constinferred trivialperm(tp)) == tp @test (@constinferred trivialperm(bt)) == tp diff --git a/test/test_blockedtuple.jl b/test/test_blockedtuple.jl index 5973f6f..b19b8ab 100644 --- a/test/test_blockedtuple.jl +++ b/test/test_blockedtuple.jl @@ -1,6 +1,7 @@ -using Test: @test, @test_throws +using Test: @test, @test_throws, @testset -using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks +using BlockArrays: + Block, BlockVector, blocklength, blocklengths, blockedrange, blockisequal, blocks using TestExtras: @constinferred using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar @@ -54,28 +55,70 @@ using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1) @test (@constinferred bt .+ tuplemortar(((1,), (1, 1), (1, 1)))) == BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1) - @test_throws DimensionMismatch bt .+ tuplemortar(((1, 1), (1, 1), (1,))) + @test (@constinferred bt .+ tuplemortar(((1,), (1, 1, 1), (1,)))) isa + BlockedTuple{4,(1, 2, 1, 1),NTuple{5,Int64}} + @test bt .+ tuplemortar(((1,), (1, 1, 1), (1,))) == + tuplemortar(((2,), (5, 3), (6,), (4,))) bt = tuplemortar(((1:2, 1:2), (1:3,))) @test length.(bt) == tuplemortar(((2, 2), (3,))) @test length.(length.(bt)) == tuplemortar(((1, 1), (1,))) + bt = tuplemortar(((1,), (2,))) + @test (@constinferred bt .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} + @test (bt .== bt) == tuplemortar(((true,), (true,))) + @test (@constinferred bt .== tuplemortar(((1, 2),))) isa + BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} + @test (bt .== tuplemortar(((1, 2),))) == tuplemortar(((true,), (true,))) + @test_throws DimensionMismatch bt .== tuplemortar(((1,), (2,), (3,))) + @test (@constinferred bt .== (1, 2)) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} + @test (bt .== (1, 2)) == tuplemortar(((true,), (true,))) + @test_throws DimensionMismatch bt .== (1, 2, 3) + @test (@constinferred bt .== 1) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} + @test (bt .== 1) == tuplemortar(((true,), (false,))) + @test (@constinferred bt .== (1,)) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} + + @test (bt .== (1,)) == tuplemortar(((true,), (false,))) + # BlockedTuple .== AbstractVector is not type stable. Requires fix in BlockArrays + @test (bt .== [1, 1]) isa BlockVector{Bool} + @test blocks(bt .== [1, 1]) == [[true], [false]] + @test_throws DimensionMismatch bt .== [1, 2, 3] + + @test (@constinferred (1, 2) .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} + @test ((1, 2) .== bt) == tuplemortar(((true,), (true,))) + @test_throws DimensionMismatch (1, 2, 3) .== bt + @test (@constinferred 1 .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} + @test (1 .== bt) == tuplemortar(((true,), (false,))) + @test (@constinferred (1,) .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}} + @test ((1,) .== bt) == tuplemortar(((true,), (false,))) + @test ([1, 1] .== bt) isa BlockVector{Bool} + @test blocks([1, 1] .== bt) == [[true], [false]] + # empty blocks bt = tuplemortar(((1,), (), (5, 3))) @test bt isa BlockedTuple{3} @test Tuple(bt) == (1, 5, 3) @test blocklengths(bt) == (1, 0, 2) @test (@constinferred blocks(bt)) == ((1,), (), (5, 3)) + @test blockisequal(only(axes(bt)), blockedrange([1, 0, 2])) bt = tuplemortar(((), ())) @test bt isa BlockedTuple{2} @test Tuple(bt) == () @test blocklengths(bt) == (0, 0) @test (@constinferred blocks(bt)) == ((), ()) - - bt = tuplemortar(()) - @test bt isa BlockedTuple{0} - @test Tuple(bt) == () - @test blocklengths(bt) == () - @test (@constinferred blocks(bt)) == () + @test blockisequal(only(axes(bt)), blockedrange([0, 0])) + @test bt == bt .+ bt + + bt0 = tuplemortar(()) + bt1 = tuplemortar(((),)) + @test bt0 isa BlockedTuple{0} + @test Tuple(bt0) == () + @test blocklengths(bt0) == () + @test (@constinferred blocks(bt0)) == () + @test blockisequal(only(axes(bt0)), blockedrange(zeros(Int, 0))) + @test bt0 == bt0 + @test bt != bt1 + @test (@constinferred bt0 .+ bt0) == bt0 + @test (@constinferred bt0 .+ bt1) == bt1 end