33
44using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange
55
6+ #
7+ # ================================== AbstractBlockTuple ==================================
8+ #
69abstract type AbstractBlockTuple end
710
8- struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple
9- flat:: Flat
10-
11- function BlockedTuple {BlockLengths} (flat:: Tuple ) where {BlockLengths}
12- length (flat) != sum (BlockLengths) && throw (DimensionMismatch (" Invalid total length" ))
13- return new {BlockLengths,typeof(flat)} (flat)
14- end
15- end
16-
17- # TensorAlgebra Interface
18- BlockedTuple (tt:: Vararg{Tuple} ) = BlockedTuple {length.(tt)} (flatten_tuples (tt))
19- BlockedTuple (bt:: BlockedTuple ) = bt
20-
2111# Base interface
22- Base. Tuple (bt:: BlockedTuple ) = bt. flat
23-
2412Base. axes (bt:: AbstractBlockTuple ) = (blockedrange ([blocklengths (bt)... ]),)
2513
14+ Base. firstindex (:: AbstractBlockTuple ) = 1
15+
16+ Base. lastindex (bt:: AbstractBlockTuple ) = length (bt)
17+
18+ # Broadcast interface
2619Base. broadcastable (bt:: AbstractBlockTuple ) = bt
2720struct BlockedTupleBroadcastStyle{BlockLengths} <: Broadcast.BroadcastStyle end
2821function Base. BroadcastStyle (type:: Type{<:AbstractBlockTuple} )
@@ -39,12 +32,48 @@ function Base.copy(
3932 return BlockedTuple {BlockLengths} (bc. f .((Tuple .(bc. args)). .. ))
4033end
4134
35+ # BlockArrays interface
36+ function BlockArrays. blockfirsts (bt:: AbstractBlockTuple )
37+ return (0 , cumsum (blocklengths (bt)[begin : (end - 1 )])... ) .+ 1
38+ end
39+
40+ function BlockArrays. blocklasts (bt:: AbstractBlockTuple )
41+ return cumsum (blocklengths (bt)[begin : end ])
42+ end
43+
44+ BlockArrays. blocklength (bt:: AbstractBlockTuple ) = length (blocklengths (bt))
45+
46+ BlockArrays. blocklengths (bt:: AbstractBlockTuple ) = blocklengths (typeof (bt))
47+
48+ function BlockArrays. blocks (bt:: AbstractBlockTuple )
49+ bf = blockfirsts (bt)
50+ bl = blocklasts (bt)
51+ return ntuple (i -> Tuple (bt)[bf[i]: bl[i]], blocklength (bt))
52+ end
53+
54+ #
55+ # ===================================== BlockedTuple =====================================
56+ #
57+ struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple
58+ flat:: Flat
59+
60+ function BlockedTuple {BlockLengths} (flat:: Tuple ) where {BlockLengths}
61+ length (flat) != sum (BlockLengths) && throw (DimensionMismatch (" Invalid total length" ))
62+ return new {BlockLengths,typeof(flat)} (flat)
63+ end
64+ end
65+
66+ # TensorAlgebra Interface
67+ BlockedTuple (tt:: Vararg{Tuple} ) = BlockedTuple {length.(tt)} (flatten_tuples (tt))
68+ BlockedTuple (bt:: BlockedTuple ) = bt
69+
70+ # Base interface
71+ Base. Tuple (bt:: BlockedTuple ) = bt. flat
72+
4273Base. copy (bt:: BlockedTuple ) = BlockedTuple {blocklengths(bt)} (copy .(Tuple (bt)))
4374
4475Base. deepcopy (bt:: BlockedTuple ) = BlockedTuple {blocklengths(bt)} (deepcopy .(Tuple (bt)))
4576
46- Base. firstindex (:: AbstractBlockTuple ) = 1
47-
4877Base. getindex (bt:: BlockedTuple , i:: Integer ) = Tuple (bt)[i]
4978Base. getindex (bt:: BlockedTuple , r:: AbstractUnitRange ) = Tuple (bt)[r]
5079Base. getindex (bt:: BlockedTuple , b:: Block{1} ) = blocks (bt)[Int (b)]
@@ -54,30 +83,11 @@ Base.getindex(bt::BlockedTuple, bi::BlockIndexRange{1}) = bt[Block(bi)][only(bi.
5483Base. iterate (bt:: BlockedTuple ) = iterate (Tuple (bt))
5584Base. iterate (bt:: BlockedTuple , i:: Int ) = iterate (Tuple (bt), i)
5685
57- Base. lastindex (bt:: AbstractBlockTuple ) = length (bt)
58-
5986Base. length (bt:: BlockedTuple ) = length (Tuple (bt))
6087
6188Base. map (f, bt:: BlockedTuple ) = BlockedTuple {blocklengths(bt)} (map (f, Tuple (bt)))
6289
6390# BlockArrays interface
64- function BlockArrays. blockfirsts (bt:: AbstractBlockTuple )
65- return (0 , cumsum (blocklengths (bt)[begin : (end - 1 )])... ) .+ 1
66- end
67-
68- function BlockArrays. blocklasts (bt:: AbstractBlockTuple )
69- return cumsum (blocklengths (bt)[begin : end ])
70- end
71-
72- BlockArrays. blocklength (bt:: AbstractBlockTuple ) = length (blocklengths (bt))
73-
74- BlockArrays. blocklengths (bt:: AbstractBlockTuple ) = blocklengths (typeof (bt))
7591function BlockArrays. blocklengths (:: Type{<:BlockedTuple{BlockLengths}} ) where {BlockLengths}
7692 return BlockLengths
7793end
78-
79- function BlockArrays. blocks (bt:: AbstractBlockTuple )
80- bf = blockfirsts (bt)
81- bl = blocklasts (bt)
82- return ntuple (i -> Tuple (bt)[bf[i]: bl[i]], blocklength (bt))
83- end
0 commit comments