Skip to content

Commit 4b0e372

Browse files
authored
more generic broadcast for AbstractBlockTuple (#46)
1 parent db42098 commit 4b0e372

File tree

4 files changed

+113
-17
lines changed

4 files changed

+113
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.7"
4+
version = "0.2.8"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/blockedtuple.jl

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange
66

7-
using TypeParameterAccessors: unspecify_type_parameters
7+
using TypeParameterAccessors: type_parameters, unspecify_type_parameters
88

99
#
1010
# ================================== AbstractBlockTuple ==================================
1111
#
1212
# AbstractBlockTuple imposes BlockLength as first type parameter for easy dispatch
13-
# it makes no assumotion on storage type
13+
# it makes no assumption on storage type
1414
abstract type AbstractBlockTuple{BlockLength} end
1515

1616
constructorof(type::Type{<:AbstractBlockTuple}) = unspecify_type_parameters(type)
@@ -23,6 +23,7 @@ end
2323

2424
# Base interface
2525
Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),)
26+
Base.axes(::AbstractBlockTuple{0}) = (blockedrange(Int[]),)
2627

2728
Base.deepcopy(bt::AbstractBlockTuple) = deepcopy.(bt)
2829

@@ -70,12 +71,62 @@ function Base.BroadcastStyle(T::Type{<:AbstractBlockTuple})
7071
return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}()
7172
end
7273

73-
# BroadcastStyle is not called for two identical styles
74+
# default
75+
combine_types(::Type{<:AbstractBlockTuple}, ::Type{<:AbstractBlockTuple}) = BlockedTuple
76+
77+
# BroadcastStyle(::Style1, ::Style2) is not called when Style1 == Style2
78+
# tuplemortar(((1,), (2,))) .== tuplemortar(((1,), (2,))) = tuplemortar(((true,), (true,)))
79+
# tuplemortar(((1,), (2,))) .== tuplemortar(((1, 2),)) = tuplemortar(((true,), (true,)))
80+
# tuplemortar(((1,), (2,))) .== tuplemortar(((1,), (2,), (3,))) = error DimensionMismatch
7481
function Base.BroadcastStyle(
75-
::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle
82+
s1::AbstractBlockTupleBroadcastStyle, s2::AbstractBlockTupleBroadcastStyle
7683
)
77-
throw(DimensionMismatch("Incompatible blocks"))
84+
blocklengths1 = type_parameters(s1, 1)
85+
blocklengths2 = type_parameters(s2, 1)
86+
sum(blocklengths1; init=0) != sum(blocklengths2; init=0) &&
87+
throw(DimensionMismatch("blocked tuples could not be broadcast to a common size"))
88+
new_blocklasts = static_mergesort(cumsum(blocklengths1), cumsum(blocklengths2))
89+
new_blocklengths = (
90+
first(new_blocklasts), Base.tail(new_blocklasts) .- Base.front(new_blocklasts)...
91+
)
92+
BT = combine_types(type_parameters(s1, 2), type_parameters(s2, 2))
93+
return AbstractBlockTupleBroadcastStyle{new_blocklengths,BT}()
94+
end
95+
96+
static_mergesort(::Tuple{}, ::Tuple{}) = ()
97+
static_mergesort(a::Tuple, ::Tuple{}) = a
98+
static_mergesort(::Tuple{}, b::Tuple) = b
99+
function static_mergesort(a::Tuple, b::Tuple)
100+
if first(a) == first(b)
101+
return (first(a), static_mergesort(Base.tail(a), Base.tail(b))...)
102+
end
103+
if first(a) < first(b)
104+
return (first(a), static_mergesort(Base.tail(a), b)...)
105+
end
106+
return (first(b), static_mergesort(a, Base.tail(b))...)
78107
end
108+
109+
# tuplemortar(((1,), (2,))) .== (1, 2) = (true, true)
110+
function Base.BroadcastStyle(
111+
s::AbstractBlockTupleBroadcastStyle, ::Base.Broadcast.Style{Tuple}
112+
)
113+
return s
114+
end
115+
116+
# tuplemortar(((1,), (2,))) .== 1 = (true, false)
117+
function Base.BroadcastStyle(
118+
::Base.Broadcast.DefaultArrayStyle{0}, s::AbstractBlockTupleBroadcastStyle
119+
)
120+
return s
121+
end
122+
123+
# tuplemortar(((1,), (2,))) .== [1, 1] = BlockVector([true, false], [1, 1])
124+
function Base.BroadcastStyle(
125+
a::Base.Broadcast.AbstractArrayStyle, ::AbstractBlockTupleBroadcastStyle
126+
)
127+
return a
128+
end
129+
79130
function Base.copy(
80131
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
81132
) where {BlockLengths,BT}
@@ -104,8 +155,6 @@ function BlockArrays.blocks(bt::AbstractBlockTuple)
104155
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))
105156
end
106157

107-
# length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength"))
108-
109158
# ===================================== BlockedTuple =====================================
110159
#
111160
struct BlockedTuple{BlockLength,BlockLengths,Flat} <: AbstractBlockTuple{BlockLength}

test/test_blockedpermutation.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ using TensorAlgebra:
7777
@test (@constinferred BlockedTuple(p)) == bt
7878
@test (@constinferred map(identity, p)) == bt
7979
@test (@constinferred p .+ p) == tuplemortar(((6, 4), (), (2,)))
80+
@test (@constinferred p .+ bt) == tuplemortar(((6, 4), (), (2,)))
81+
@test (@constinferred bt .+ p) == tuplemortar(((6, 4), (), (2,)))
8082
@test (@constinferred blockedperm(p)) == p
8183
@test (@constinferred blockedperm(bt)) == p
8284

@@ -149,6 +151,8 @@ end
149151
@test (@constinferred blocks(tp)) == blocks(bt)
150152
@test (@constinferred map(identity, tp)) == bt
151153
@test (@constinferred tp .+ tp) == tuplemortar(((2, 4), (), (6,)))
154+
@test (@constinferred tp .+ Tuple(tp)) == tuplemortar(((2, 4), (), (6,)))
155+
@test (@constinferred tp .+ BlockedTuple(tp)) == tuplemortar(((2, 4), (), (6,)))
152156
@test (@constinferred blockedperm(tp)) == tp
153157
@test (@constinferred trivialperm(tp)) == tp
154158
@test (@constinferred trivialperm(bt)) == tp

test/test_blockedtuple.jl

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
using Test: @test, @test_throws
1+
using Test: @test, @test_throws, @testset
22

3-
using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks
3+
using BlockArrays:
4+
Block, BlockVector, blocklength, blocklengths, blockedrange, blockisequal, blocks
45
using TestExtras: @constinferred
56

67
using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar
@@ -54,28 +55,70 @@ using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar
5455
BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1)
5556
@test (@constinferred bt .+ tuplemortar(((1,), (1, 1), (1, 1)))) ==
5657
BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1)
57-
@test_throws DimensionMismatch bt .+ tuplemortar(((1, 1), (1, 1), (1,)))
58+
@test (@constinferred bt .+ tuplemortar(((1,), (1, 1, 1), (1,)))) isa
59+
BlockedTuple{4,(1, 2, 1, 1),NTuple{5,Int64}}
60+
@test bt .+ tuplemortar(((1,), (1, 1, 1), (1,))) ==
61+
tuplemortar(((2,), (5, 3), (6,), (4,)))
5862

5963
bt = tuplemortar(((1:2, 1:2), (1:3,)))
6064
@test length.(bt) == tuplemortar(((2, 2), (3,)))
6165
@test length.(length.(bt)) == tuplemortar(((1, 1), (1,)))
6266

67+
bt = tuplemortar(((1,), (2,)))
68+
@test (@constinferred bt .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
69+
@test (bt .== bt) == tuplemortar(((true,), (true,)))
70+
@test (@constinferred bt .== tuplemortar(((1, 2),))) isa
71+
BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
72+
@test (bt .== tuplemortar(((1, 2),))) == tuplemortar(((true,), (true,)))
73+
@test_throws DimensionMismatch bt .== tuplemortar(((1,), (2,), (3,)))
74+
@test (@constinferred bt .== (1, 2)) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
75+
@test (bt .== (1, 2)) == tuplemortar(((true,), (true,)))
76+
@test_throws DimensionMismatch bt .== (1, 2, 3)
77+
@test (@constinferred bt .== 1) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
78+
@test (bt .== 1) == tuplemortar(((true,), (false,)))
79+
@test (@constinferred bt .== (1,)) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
80+
81+
@test (bt .== (1,)) == tuplemortar(((true,), (false,)))
82+
# BlockedTuple .== AbstractVector is not type stable. Requires fix in BlockArrays
83+
@test (bt .== [1, 1]) isa BlockVector{Bool}
84+
@test blocks(bt .== [1, 1]) == [[true], [false]]
85+
@test_throws DimensionMismatch bt .== [1, 2, 3]
86+
87+
@test (@constinferred (1, 2) .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
88+
@test ((1, 2) .== bt) == tuplemortar(((true,), (true,)))
89+
@test_throws DimensionMismatch (1, 2, 3) .== bt
90+
@test (@constinferred 1 .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
91+
@test (1 .== bt) == tuplemortar(((true,), (false,)))
92+
@test (@constinferred (1,) .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
93+
@test ((1,) .== bt) == tuplemortar(((true,), (false,)))
94+
@test ([1, 1] .== bt) isa BlockVector{Bool}
95+
@test blocks([1, 1] .== bt) == [[true], [false]]
96+
6397
# empty blocks
6498
bt = tuplemortar(((1,), (), (5, 3)))
6599
@test bt isa BlockedTuple{3}
66100
@test Tuple(bt) == (1, 5, 3)
67101
@test blocklengths(bt) == (1, 0, 2)
68102
@test (@constinferred blocks(bt)) == ((1,), (), (5, 3))
103+
@test blockisequal(only(axes(bt)), blockedrange([1, 0, 2]))
69104

70105
bt = tuplemortar(((), ()))
71106
@test bt isa BlockedTuple{2}
72107
@test Tuple(bt) == ()
73108
@test blocklengths(bt) == (0, 0)
74109
@test (@constinferred blocks(bt)) == ((), ())
75-
76-
bt = tuplemortar(())
77-
@test bt isa BlockedTuple{0}
78-
@test Tuple(bt) == ()
79-
@test blocklengths(bt) == ()
80-
@test (@constinferred blocks(bt)) == ()
110+
@test blockisequal(only(axes(bt)), blockedrange([0, 0]))
111+
@test bt == bt .+ bt
112+
113+
bt0 = tuplemortar(())
114+
bt1 = tuplemortar(((),))
115+
@test bt0 isa BlockedTuple{0}
116+
@test Tuple(bt0) == ()
117+
@test blocklengths(bt0) == ()
118+
@test (@constinferred blocks(bt0)) == ()
119+
@test blockisequal(only(axes(bt0)), blockedrange(zeros(Int, 0)))
120+
@test bt0 == bt0
121+
@test bt != bt1
122+
@test (@constinferred bt0 .+ bt0) == bt0
123+
@test (@constinferred bt0 .+ bt1) == bt1
81124
end

0 commit comments

Comments
 (0)