Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.7"
version = "0.2.8"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
63 changes: 56 additions & 7 deletions src/blockedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange

using TypeParameterAccessors: unspecify_type_parameters
using TypeParameterAccessors: type_parameters, unspecify_type_parameters

#
# ================================== AbstractBlockTuple ==================================
#
# AbstractBlockTuple imposes BlockLength as first type parameter for easy dispatch
# it makes no assumotion on storage type
# it makes no assumption on storage type
abstract type AbstractBlockTuple{BlockLength} end

constructorof(type::Type{<:AbstractBlockTuple}) = unspecify_type_parameters(type)
Expand All @@ -23,6 +23,7 @@

# Base interface
Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),)
Base.axes(::AbstractBlockTuple{0}) = (blockedrange(zeros(Int, 0)),)

Check warning on line 26 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L26

Added line #L26 was not covered by tests

Base.deepcopy(bt::AbstractBlockTuple) = deepcopy.(bt)

Expand Down Expand Up @@ -70,12 +71,62 @@
return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}()
end

# BroadcastStyle is not called for two identical styles
# default
combine_types(::Type{<:AbstractBlockTuple}, ::Type{<:AbstractBlockTuple}) = BlockedTuple

Check warning on line 75 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L75

Added line #L75 was not covered by tests

# BroadcastStyle(::Style1, ::Style2) is not called when Style1 == Style2
# tuplemortar(((1,), (2,))) .== tuplemortar(((1,), (2,))) = tuplemortar(((true,), (true,)))
# tuplemortar(((1,), (2,))) .== tuplemortar(((1, 2),)) = tuplemortar(((true,), (true,)))
# tuplemortar(((1,), (2,))) .== tuplemortar(((1,), (2,), (3,))) = error DimensionMismatch
function Base.BroadcastStyle(
::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle
s1::AbstractBlockTupleBroadcastStyle, s2::AbstractBlockTupleBroadcastStyle
)
throw(DimensionMismatch("Incompatible blocks"))
blocklengths1 = type_parameters(s1, 1)
blocklengths2 = type_parameters(s2, 1)
sum(blocklengths1; init=0) != sum(blocklengths2; init=0) &&

Check warning on line 86 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L84-L86

Added lines #L84 - L86 were not covered by tests
throw(DimensionMismatch("blocked tuples could not be broadcast to a common size"))
new_blocklasts = static_mergesort(cumsum(blocklengths1), cumsum(blocklengths2))
new_blocklengths = (

Check warning on line 89 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L88-L89

Added lines #L88 - L89 were not covered by tests
first(new_blocklasts), Base.tail(new_blocklasts) .- Base.front(new_blocklasts)...
)
BT = combine_types(type_parameters(s1, 2), type_parameters(s2, 2))
return AbstractBlockTupleBroadcastStyle{new_blocklengths,BT}()

Check warning on line 93 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L92-L93

Added lines #L92 - L93 were not covered by tests
end

static_mergesort(::Tuple{}, ::Tuple{}) = ()
static_mergesort(a::Tuple, ::Tuple{}) = a
static_mergesort(::Tuple{}, b::Tuple) = b
function static_mergesort(a::Tuple, b::Tuple)
if first(a) == first(b)
return (first(a), static_mergesort(Base.tail(a), Base.tail(b))...)

Check warning on line 101 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L96-L101

Added lines #L96 - L101 were not covered by tests
end
if first(a) < first(b)
return (first(a), static_mergesort(Base.tail(a), b)...)

Check warning on line 104 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L103-L104

Added lines #L103 - L104 were not covered by tests
end
return (first(b), static_mergesort(a, Base.tail(b))...)

Check warning on line 106 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L106

Added line #L106 was not covered by tests
end

# tuplemortar(((1,), (2,))) .== (1, 2) = (true, true)
function Base.BroadcastStyle(

Check warning on line 110 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L110

Added line #L110 was not covered by tests
s::AbstractBlockTupleBroadcastStyle, ::Base.Broadcast.Style{Tuple}
)
return s

Check warning on line 113 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L113

Added line #L113 was not covered by tests
end

# tuplemortar(((1,), (2,))) .== 1 = (true, false)
function Base.BroadcastStyle(

Check warning on line 117 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L117

Added line #L117 was not covered by tests
::Base.Broadcast.DefaultArrayStyle{0}, s::AbstractBlockTupleBroadcastStyle
)
return s

Check warning on line 120 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L120

Added line #L120 was not covered by tests
end

# tuplemortar(((1,), (2,))) .== [1, 1] = BlockVector([true, false], [1, 1])
function Base.BroadcastStyle(

Check warning on line 124 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L124

Added line #L124 was not covered by tests
a::Base.Broadcast.AbstractArrayStyle, ::AbstractBlockTupleBroadcastStyle
)
return a

Check warning on line 127 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L127

Added line #L127 was not covered by tests
end

function Base.copy(
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
) where {BlockLengths,BT}
Expand Down Expand Up @@ -104,8 +155,6 @@
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))
end

# length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength"))

# ===================================== BlockedTuple =====================================
#
struct BlockedTuple{BlockLength,BlockLengths,Flat} <: AbstractBlockTuple{BlockLength}
Expand Down
4 changes: 4 additions & 0 deletions test/test_blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ using TensorAlgebra:
@test (@constinferred BlockedTuple(p)) == bt
@test (@constinferred map(identity, p)) == bt
@test (@constinferred p .+ p) == tuplemortar(((6, 4), (), (2,)))
@test (@constinferred p .+ bt) == tuplemortar(((6, 4), (), (2,)))
@test (@constinferred bt .+ p) == tuplemortar(((6, 4), (), (2,)))
@test (@constinferred blockedperm(p)) == p
@test (@constinferred blockedperm(bt)) == p

Expand Down Expand Up @@ -149,6 +151,8 @@ end
@test (@constinferred blocks(tp)) == blocks(bt)
@test (@constinferred map(identity, tp)) == bt
@test (@constinferred tp .+ tp) == tuplemortar(((2, 4), (), (6,)))
@test (@constinferred tp .+ Tuple(tp)) == tuplemortar(((2, 4), (), (6,)))
@test (@constinferred tp .+ BlockedTuple(tp)) == tuplemortar(((2, 4), (), (6,)))
@test (@constinferred blockedperm(tp)) == tp
@test (@constinferred trivialperm(tp)) == tp
@test (@constinferred trivialperm(bt)) == tp
Expand Down
61 changes: 52 additions & 9 deletions test/test_blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test: @test, @test_throws
using Test: @test, @test_throws, @testset

using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks
using BlockArrays:
Block, BlockVector, blocklength, blocklengths, blockedrange, blockisequal, blocks
using TestExtras: @constinferred

using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar
Expand Down Expand Up @@ -54,28 +55,70 @@ using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar
BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1)
@test (@constinferred bt .+ tuplemortar(((1,), (1, 1), (1, 1)))) ==
BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1)
@test_throws DimensionMismatch bt .+ tuplemortar(((1, 1), (1, 1), (1,)))
@test (@constinferred bt .+ tuplemortar(((1,), (1, 1, 1), (1,)))) isa
BlockedTuple{4,(1, 2, 1, 1),NTuple{5,Int64}}
@test bt .+ tuplemortar(((1,), (1, 1, 1), (1,))) ==
tuplemortar(((2,), (5, 3), (6,), (4,)))

bt = tuplemortar(((1:2, 1:2), (1:3,)))
@test length.(bt) == tuplemortar(((2, 2), (3,)))
@test length.(length.(bt)) == tuplemortar(((1, 1), (1,)))

bt = tuplemortar(((1,), (2,)))
@test (@constinferred bt .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (bt .== bt) == tuplemortar(((true,), (true,)))
@test (@constinferred bt .== tuplemortar(((1, 2),))) isa
BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (bt .== tuplemortar(((1, 2),))) == tuplemortar(((true,), (true,)))
@test_throws DimensionMismatch bt .== tuplemortar(((1,), (2,), (3,)))
@test (@constinferred bt .== (1, 2)) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (bt .== (1, 2)) == tuplemortar(((true,), (true,)))
@test_throws DimensionMismatch bt .== (1, 2, 3)
@test (@constinferred bt .== 1) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (bt .== 1) == tuplemortar(((true,), (false,)))
@test (@constinferred bt .== (1,)) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}

@test (bt .== (1,)) == tuplemortar(((true,), (false,)))
# BlockedTuple .== AbstractVector is not type stable. Requires fix in BlockArrays
@test (bt .== [1, 1]) isa BlockVector{Bool}
@test blocks(bt .== [1, 1]) == [[true], [false]]
@test_throws DimensionMismatch bt .== [1, 2, 3]

@test (@constinferred (1, 2) .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test ((1, 2) .== bt) == tuplemortar(((true,), (true,)))
@test_throws DimensionMismatch (1, 2, 3) .== bt
@test (@constinferred 1 .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (1 .== bt) == tuplemortar(((true,), (false,)))
@test (@constinferred (1,) .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test ((1,) .== bt) == tuplemortar(((true,), (false,)))
@test ([1, 1] .== bt) isa BlockVector{Bool}
@test blocks([1, 1] .== bt) == [[true], [false]]

# empty blocks
bt = tuplemortar(((1,), (), (5, 3)))
@test bt isa BlockedTuple{3}
@test Tuple(bt) == (1, 5, 3)
@test blocklengths(bt) == (1, 0, 2)
@test (@constinferred blocks(bt)) == ((1,), (), (5, 3))
@test blockisequal(only(axes(bt)), blockedrange([1, 0, 2]))

bt = tuplemortar(((), ()))
@test bt isa BlockedTuple{2}
@test Tuple(bt) == ()
@test blocklengths(bt) == (0, 0)
@test (@constinferred blocks(bt)) == ((), ())

bt = tuplemortar(())
@test bt isa BlockedTuple{0}
@test Tuple(bt) == ()
@test blocklengths(bt) == ()
@test (@constinferred blocks(bt)) == ()
@test blockisequal(only(axes(bt)), blockedrange([0, 0]))
@test bt == bt .+ bt

bt0 = tuplemortar(())
bt1 = tuplemortar(((),))
@test bt0 isa BlockedTuple{0}
@test Tuple(bt0) == ()
@test blocklengths(bt0) == ()
@test (@constinferred blocks(bt0)) == ()
@test blockisequal(only(axes(bt0)), blockedrange(zeros(Int, 0)))
@test bt0 == bt0
@test bt != bt1
@test (@constinferred bt0 .+ bt0) == bt0
@test (@constinferred bt0 .+ bt1) == bt1
end
Loading