@@ -10,6 +10,9 @@ using TypeParameterAccessors: unspecify_type_parameters
1010#
1111abstract type AbstractBlockTuple end
1212
13+ constructorof (type:: Type{<:AbstractBlockTuple} ) = unspecify_type_parameters (type)
14+ widened_constructorof (type:: Type{<:AbstractBlockTuple} ) = constructorof (type)
15+
1316# Base interface
1417Base. axes (bt:: AbstractBlockTuple ) = (blockedrange ([blocklengths (bt)... ]),)
1518
@@ -22,9 +25,8 @@ Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r]
2225Base. getindex (bt:: AbstractBlockTuple , b:: Block{1} ) = blocks (bt)[Int (b)]
2326function Base. getindex (bt:: AbstractBlockTuple , br:: BlockRange{1} )
2427 r = Int .(br)
25- T = unspecify_type_parameters (typeof (bt))
2628 flat = Tuple (bt)[blockfirsts (bt)[first (r)]: blocklasts (bt)[last (r)]]
27- return T { blocklengths(bt)[r]} (flat )
29+ return widened_constructorof ( typeof (bt))(flat, blocklengths (bt)[r])
2830end
2931function Base. getindex (bt:: AbstractBlockTuple , bi:: BlockIndexRange{1} )
3032 return bt[Block (bi)][only (bi. indices)]
@@ -40,7 +42,7 @@ Base.lastindex(bt::AbstractBlockTuple) = length(bt)
4042function Base. map (f, bt:: AbstractBlockTuple )
4143 BL = blocklengths (bt)
4244 # use Val to preserve compile time knowledge of BL
43- return unspecify_type_parameters (typeof (bt))(map (f, Tuple (bt)), Val (BL))
45+ return widened_constructorof (typeof (bt))(map (f, Tuple (bt)), Val (BL))
4446end
4547
4648# Broadcast interface
5961function Base. copy (
6062 bc:: Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
6163) where {BlockLengths,BT}
62- return BT (bc. f .((Tuple .(bc. args)). .. ), Val (BlockLengths))
64+ return widened_constructorof (BT) (bc. f .((Tuple .(bc. args)). .. ), Val (BlockLengths))
6365end
6466
6567# BlockArrays interface
@@ -89,6 +91,7 @@ struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple
8991
9092 function BlockedTuple {BlockLengths} (flat:: Tuple ) where {BlockLengths}
9193 length (flat) != sum (BlockLengths) && throw (DimensionMismatch (" Invalid total length" ))
94+ any (BlockLengths .< 0 ) && throw (DimensionMismatch (" Invalid block length" ))
9295 return new {BlockLengths,typeof(flat)} (flat)
9396 end
9497end
0 commit comments