Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.7"
version = "0.2.8"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
63 changes: 56 additions & 7 deletions src/blockedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange

using TypeParameterAccessors: unspecify_type_parameters
using TypeParameterAccessors: type_parameters, unspecify_type_parameters

#
# ================================== AbstractBlockTuple ==================================
#
# AbstractBlockTuple imposes BlockLength as first type parameter for easy dispatch
# it makes no assumotion on storage type
# it makes no assumption on storage type
abstract type AbstractBlockTuple{BlockLength} end

constructorof(type::Type{<:AbstractBlockTuple}) = unspecify_type_parameters(type)
Expand All @@ -23,6 +23,7 @@

# Base interface
Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),)
Base.axes(::AbstractBlockTuple{0}) = (blockedrange(Int[]),)

Check warning on line 26 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L26

Added line #L26 was not covered by tests

Base.deepcopy(bt::AbstractBlockTuple) = deepcopy.(bt)

Expand Down Expand Up @@ -70,12 +71,62 @@
return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}()
end

# BroadcastStyle is not called for two identical styles
# default
combine_types(::Type{<:AbstractBlockTuple}, ::Type{<:AbstractBlockTuple}) = BlockedTuple

Check warning on line 75 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L75

Added line #L75 was not covered by tests

# BroadcastStyle(::Style1, ::Style2) is not called when Style1 == Style2
# tuplemortar(((1,), (2,))) .== tuplemortar(((1,), (2,))) = tuplemortar(((true,), (true,)))
# tuplemortar(((1,), (2,))) .== tuplemortar(((1, 2),)) = tuplemortar(((true,), (true,)))
# tuplemortar(((1,), (2,))) .== tuplemortar(((1,), (2,), (3,))) = error DimensionMismatch
function Base.BroadcastStyle(
::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle
s1::AbstractBlockTupleBroadcastStyle, s2::AbstractBlockTupleBroadcastStyle
)
throw(DimensionMismatch("Incompatible blocks"))
blocklengths1 = type_parameters(s1, 1)
blocklengths2 = type_parameters(s2, 1)
sum(blocklengths1; init=0) != sum(blocklengths2; init=0) &&

Check warning on line 86 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L84-L86

Added lines #L84 - L86 were not covered by tests
throw(DimensionMismatch("blocked tuples could not be broadcast to a common size"))
new_blocklasts = static_mergesort(cumsum(blocklengths1), cumsum(blocklengths2))
new_blocklengths = (

Check warning on line 89 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L88-L89

Added lines #L88 - L89 were not covered by tests
first(new_blocklasts), Base.tail(new_blocklasts) .- Base.front(new_blocklasts)...
)
BT = combine_types(type_parameters(s1, 2), type_parameters(s2, 2))
return AbstractBlockTupleBroadcastStyle{new_blocklengths,BT}()

Check warning on line 93 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L92-L93

Added lines #L92 - L93 were not covered by tests
end

static_mergesort(::Tuple{}, ::Tuple{}) = ()
static_mergesort(a::Tuple, ::Tuple{}) = a
static_mergesort(::Tuple{}, b::Tuple) = b
function static_mergesort(a::Tuple, b::Tuple)
if first(a) == first(b)
return (first(a), static_mergesort(Base.tail(a), Base.tail(b))...)

Check warning on line 101 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L96-L101

Added lines #L96 - L101 were not covered by tests
end
if first(a) < first(b)
return (first(a), static_mergesort(Base.tail(a), b)...)

Check warning on line 104 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L103-L104

Added lines #L103 - L104 were not covered by tests
end
return (first(b), static_mergesort(a, Base.tail(b))...)

Check warning on line 106 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L106

Added line #L106 was not covered by tests
end

# tuplemortar(((1,), (2,))) .== (1, 2) = (true, true)
function Base.BroadcastStyle(

Check warning on line 110 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L110

Added line #L110 was not covered by tests
s::AbstractBlockTupleBroadcastStyle, ::Base.Broadcast.Style{Tuple}
)
return s

Check warning on line 113 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L113

Added line #L113 was not covered by tests
end

# tuplemortar(((1,), (2,))) .== 1 = (true, false)
function Base.BroadcastStyle(

Check warning on line 117 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L117

Added line #L117 was not covered by tests
::Base.Broadcast.DefaultArrayStyle{0}, s::AbstractBlockTupleBroadcastStyle
)
return s

Check warning on line 120 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L120

Added line #L120 was not covered by tests
end

# tuplemortar(((1,), (2,))) .== [1, 1] = BlockVector([true, false], [1, 1])
function Base.BroadcastStyle(

Check warning on line 124 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L124

Added line #L124 was not covered by tests
a::Base.Broadcast.AbstractArrayStyle, ::AbstractBlockTupleBroadcastStyle
)
return a

Check warning on line 127 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L127

Added line #L127 was not covered by tests
end

function Base.copy(
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
) where {BlockLengths,BT}
Expand Down Expand Up @@ -104,8 +155,6 @@
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))
end

# length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength"))

# ===================================== BlockedTuple =====================================
#
struct BlockedTuple{BlockLength,BlockLengths,Flat} <: AbstractBlockTuple{BlockLength}
Expand Down
4 changes: 4 additions & 0 deletions test/test_blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ using TensorAlgebra:
@test (@constinferred BlockedTuple(p)) == bt
@test (@constinferred map(identity, p)) == bt
@test (@constinferred p .+ p) == tuplemortar(((6, 4), (), (2,)))
@test (@constinferred p .+ bt) == tuplemortar(((6, 4), (), (2,)))
@test (@constinferred bt .+ p) == tuplemortar(((6, 4), (), (2,)))
@test (@constinferred blockedperm(p)) == p
@test (@constinferred blockedperm(bt)) == p

Expand Down Expand Up @@ -149,6 +151,8 @@ end
@test (@constinferred blocks(tp)) == blocks(bt)
@test (@constinferred map(identity, tp)) == bt
@test (@constinferred tp .+ tp) == tuplemortar(((2, 4), (), (6,)))
@test (@constinferred tp .+ Tuple(tp)) == tuplemortar(((2, 4), (), (6,)))
@test (@constinferred tp .+ BlockedTuple(tp)) == tuplemortar(((2, 4), (), (6,)))
@test (@constinferred blockedperm(tp)) == tp
@test (@constinferred trivialperm(tp)) == tp
@test (@constinferred trivialperm(bt)) == tp
Expand Down
61 changes: 52 additions & 9 deletions test/test_blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test: @test, @test_throws
using Test: @test, @test_throws, @testset

using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks
using BlockArrays:
Block, BlockVector, blocklength, blocklengths, blockedrange, blockisequal, blocks
using TestExtras: @constinferred

using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar
Expand Down Expand Up @@ -54,28 +55,70 @@ using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar
BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1)
@test (@constinferred bt .+ tuplemortar(((1,), (1, 1), (1, 1)))) ==
BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1)
@test_throws DimensionMismatch bt .+ tuplemortar(((1, 1), (1, 1), (1,)))
@test (@constinferred bt .+ tuplemortar(((1,), (1, 1, 1), (1,)))) isa
BlockedTuple{4,(1, 2, 1, 1),NTuple{5,Int64}}
@test bt .+ tuplemortar(((1,), (1, 1, 1), (1,))) ==
tuplemortar(((2,), (5, 3), (6,), (4,)))

bt = tuplemortar(((1:2, 1:2), (1:3,)))
@test length.(bt) == tuplemortar(((2, 2), (3,)))
@test length.(length.(bt)) == tuplemortar(((1, 1), (1,)))

bt = tuplemortar(((1,), (2,)))
@test (@constinferred bt .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (bt .== bt) == tuplemortar(((true,), (true,)))
@test (@constinferred bt .== tuplemortar(((1, 2),))) isa
BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (bt .== tuplemortar(((1, 2),))) == tuplemortar(((true,), (true,)))
@test_throws DimensionMismatch bt .== tuplemortar(((1,), (2,), (3,)))
@test (@constinferred bt .== (1, 2)) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (bt .== (1, 2)) == tuplemortar(((true,), (true,)))
@test_throws DimensionMismatch bt .== (1, 2, 3)
@test (@constinferred bt .== 1) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (bt .== 1) == tuplemortar(((true,), (false,)))
@test (@constinferred bt .== (1,)) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}

@test (bt .== (1,)) == tuplemortar(((true,), (false,)))
# BlockedTuple .== AbstractVector is not type stable. Requires fix in BlockArrays
@test (bt .== [1, 1]) isa BlockVector{Bool}
@test blocks(bt .== [1, 1]) == [[true], [false]]
@test_throws DimensionMismatch bt .== [1, 2, 3]

@test (@constinferred (1, 2) .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test ((1, 2) .== bt) == tuplemortar(((true,), (true,)))
@test_throws DimensionMismatch (1, 2, 3) .== bt
@test (@constinferred 1 .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (1 .== bt) == tuplemortar(((true,), (false,)))
@test (@constinferred (1,) .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test ((1,) .== bt) == tuplemortar(((true,), (false,)))
@test ([1, 1] .== bt) isa BlockVector{Bool}
@test blocks([1, 1] .== bt) == [[true], [false]]

# empty blocks
bt = tuplemortar(((1,), (), (5, 3)))
@test bt isa BlockedTuple{3}
@test Tuple(bt) == (1, 5, 3)
@test blocklengths(bt) == (1, 0, 2)
@test (@constinferred blocks(bt)) == ((1,), (), (5, 3))
@test blockisequal(only(axes(bt)), blockedrange([1, 0, 2]))

bt = tuplemortar(((), ()))
@test bt isa BlockedTuple{2}
@test Tuple(bt) == ()
@test blocklengths(bt) == (0, 0)
@test (@constinferred blocks(bt)) == ((), ())

bt = tuplemortar(())
@test bt isa BlockedTuple{0}
@test Tuple(bt) == ()
@test blocklengths(bt) == ()
@test (@constinferred blocks(bt)) == ()
@test blockisequal(only(axes(bt)), blockedrange([0, 0]))
@test bt == bt .+ bt

bt0 = tuplemortar(())
bt1 = tuplemortar(((),))
@test bt0 isa BlockedTuple{0}
@test Tuple(bt0) == ()
@test blocklengths(bt0) == ()
@test (@constinferred blocks(bt0)) == ()
@test blockisequal(only(axes(bt0)), blockedrange(zeros(Int, 0)))
@test bt0 == bt0
@test bt != bt1
@test (@constinferred bt0 .+ bt0) == bt0
@test (@constinferred bt0 .+ bt1) == bt1
end
Loading