Skip to content

Commit 130e540

Browse files
committed
no assumption on AbstractBlockTuple parameters
1 parent fae747c commit 130e540

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

src/blockedtuple.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ Base.length(bt::AbstractBlockTuple) = length(Tuple(bt))
3838
Base.lastindex(bt::AbstractBlockTuple) = length(bt)
3939

4040
function Base.map(f, bt::AbstractBlockTuple)
41-
return unspecify_type_parameters(typeof(bt)){blocklengths(bt)}(map(f, Tuple(bt)))
41+
BL = blocklengths(bt)
42+
# use Val to preserve compile time knowledge of BL
43+
return unspecify_type_parameters(typeof(bt))(map(f, Tuple(bt)), Val(BL))
4244
end
4345

4446
# Broadcast interface
@@ -57,7 +59,7 @@ end
5759
function Base.copy(
5860
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
5961
) where {BlockLengths,BT}
60-
return BT{BlockLengths}(bc.f.((Tuple.(bc.args))...))
62+
return BT(bc.f.((Tuple.(bc.args))...), Val(BlockLengths))
6163
end
6264

6365
# BlockArrays interface
@@ -96,6 +98,10 @@ tuplemortar(tt::Tuple{Vararg{Tuple}}) = BlockedTuple{length.(tt)}(flatten_tuples
9698
function BlockedTuple(flat::Tuple, BlockLengths::Tuple{Vararg{Int}})
9799
return BlockedTuple{BlockLengths}(flat)
98100
end
101+
function BlockedTuple(flat::Tuple, ::Val{BlockLengths}) where {BlockLengths}
102+
# use Val to preserve compile time knowledge of BL
103+
return BlockedTuple{BlockLengths}(flat)
104+
end
99105
BlockedTuple(bt::AbstractBlockTuple) = BlockedTuple{blocklengths(bt)}(Tuple(bt))
100106

101107
# Base interface

test/test_blockedtuple.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ using TensorAlgebra: BlockedTuple, tuplemortar
4242
bt = tuplemortar(((1,), (4, 2), (5, 3)))
4343
@test Tuple(bt) == (1, 4, 2, 5, 3)
4444
@test blocklengths(bt) == (1, 2, 2)
45-
@test deepcopy(bt) == bt
45+
@test (@constinferred deepcopy(bt)) == bt
4646

4747
@test (@constinferred map(n -> n + 1, bt)) ==
4848
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
49-
@test bt .+ tuplemortar(((1,), (1, 1), (1, 1))) ==
49+
@test (@constinferred bt .+ tuplemortar(((1,), (1, 1), (1, 1)))) ==
5050
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
5151
@test_throws DimensionMismatch bt .+ tuplemortar(((1, 1), (1, 1), (1,)))
5252

0 commit comments

Comments
 (0)