Skip to content

Commit 20eaeb4

Browse files
committed
use widened_constructorof
1 parent 130e540 commit 20eaeb4

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

src/blockedtuple.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ using TypeParameterAccessors: unspecify_type_parameters
1010
#
1111
abstract type AbstractBlockTuple end
1212

13+
constructorof(type::Type{<:AbstractBlockTuple}) = unspecify_type_parameters(type)
14+
widened_constructorof(type::Type{<:AbstractBlockTuple}) = constructorof(type)
15+
1316
# Base interface
1417
Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),)
1518

@@ -22,9 +25,8 @@ Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r]
2225
Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)]
2326
function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1})
2427
r = Int.(br)
25-
T = unspecify_type_parameters(typeof(bt))
2628
flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]]
27-
return T{blocklengths(bt)[r]}(flat)
29+
return widened_constructorof(typeof(bt))(flat, blocklengths(bt)[r])
2830
end
2931
function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1})
3032
return bt[Block(bi)][only(bi.indices)]
@@ -40,7 +42,7 @@ Base.lastindex(bt::AbstractBlockTuple) = length(bt)
4042
function Base.map(f, bt::AbstractBlockTuple)
4143
BL = blocklengths(bt)
4244
# use Val to preserve compile time knowledge of BL
43-
return unspecify_type_parameters(typeof(bt))(map(f, Tuple(bt)), Val(BL))
45+
return widened_constructorof(typeof(bt))(map(f, Tuple(bt)), Val(BL))
4446
end
4547

4648
# Broadcast interface
@@ -59,7 +61,7 @@ end
5961
function Base.copy(
6062
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
6163
) where {BlockLengths,BT}
62-
return BT(bc.f.((Tuple.(bc.args))...), Val(BlockLengths))
64+
return widened_constructorof(BT)(bc.f.((Tuple.(bc.args))...), Val(BlockLengths))
6365
end
6466

6567
# BlockArrays interface
@@ -89,6 +91,7 @@ struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple
8991

9092
function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths}
9193
length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length"))
94+
any(BlockLengths .< 0) && throw(DimensionMismatch("Invalid block length"))
9295
return new{BlockLengths,typeof(flat)}(flat)
9396
end
9497
end

test/test_blockedtuple.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using TensorAlgebra: BlockedTuple, tuplemortar
1010
divs = (1, 2, 2)
1111

1212
bt = BlockedTuple{divs}(flat)
13+
@test bt isa BlockedTuple
1314

1415
@test (@constinferred Tuple(bt)) == flat
1516
@test bt == tuplemortar(((true,), ('a', 2), ("b", 3.0)))
@@ -21,6 +22,7 @@ using TensorAlgebra: BlockedTuple, tuplemortar
2122

2223
@test (@constinferred bt[1]) == true
2324
@test (@constinferred bt[2]) == 'a'
25+
@test (@constinferred map(identity, bt)) == bt
2426

2527
# it is hard to make bt[Block(1)] type stable as compile-time knowledge of 1 is lost in Block
2628
@test bt[Block(1)] == blocks(bt)[1]
@@ -40,6 +42,7 @@ using TensorAlgebra: BlockedTuple, tuplemortar
4042
@test_throws DimensionMismatch BlockedTuple{(1, 2, 3)}(flat)
4143

4244
bt = tuplemortar(((1,), (4, 2), (5, 3)))
45+
@test bt isa BlockedTuple
4346
@test Tuple(bt) == (1, 4, 2, 5, 3)
4447
@test blocklengths(bt) == (1, 2, 2)
4548
@test (@constinferred deepcopy(bt)) == bt

0 commit comments

Comments
 (0)