Skip to content

Commit 2b5922a

Browse files
committed
generic broadcast
1 parent 831c906 commit 2b5922a

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
99
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
12+
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1213

1314
[weakdeps]
1415
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
@@ -23,4 +24,5 @@ EllipsisNotation = "1.8.0"
2324
GradedUnitRanges = "0.1.0"
2425
LinearAlgebra = "1.10"
2526
TupleTools = "1.6.0"
27+
TypeParameterAccessors = "0.2.1"
2628
julia = "1.10"

src/blockedtuple.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange
55

6+
using TypeParameterAccessors: unspecify_type_parameters
7+
68
#
79
# ================================== AbstractBlockTuple ==================================
810
#
@@ -18,7 +20,12 @@ Base.firstindex(::AbstractBlockTuple) = 1
1820
Base.getindex(bt::AbstractBlockTuple, i::Integer) = Tuple(bt)[i]
1921
Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r]
2022
Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)]
21-
Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1}) = blocks(bt)[Int.(br)]
23+
function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1})
24+
r = Int.(br)
25+
T = unspecify_type_parameters(typeof(bt))
26+
flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]]
27+
return T{blocklengths(bt)[r]}(flat)
28+
end
2229
function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1})
2330
return bt[Block(bi)][only(bi.indices)]
2431
end
@@ -32,19 +39,21 @@ Base.lastindex(bt::AbstractBlockTuple) = length(bt)
3239

3340
# Broadcast interface
3441
Base.broadcastable(bt::AbstractBlockTuple) = bt
35-
struct BlockedTupleBroadcastStyle{BlockLengths} <: Broadcast.BroadcastStyle end
36-
function Base.BroadcastStyle(type::Type{<:AbstractBlockTuple})
37-
return BlockedTupleBroadcastStyle{blocklengths(type)}()
42+
struct AbstractBlockTupleBroadcastStyle{BlockLengths,BT} <: Broadcast.BroadcastStyle end
43+
function Base.BroadcastStyle(T::Type{<:AbstractBlockTuple})
44+
return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}()
3845
end
3946

4047
# BroadcastStyle is not called for two identical styles
41-
function Base.BroadcastStyle(::BlockedTupleBroadcastStyle, ::BlockedTupleBroadcastStyle)
48+
function Base.BroadcastStyle(
49+
::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle
50+
)
4251
throw(DimensionMismatch("Incompatible blocks"))
4352
end
4453
function Base.copy(
45-
bc::Broadcast.Broadcasted{BlockedTupleBroadcastStyle{BlockLengths}}
46-
) where {BlockLengths}
47-
return BlockedTuple{BlockLengths}(bc.f.((Tuple.(bc.args))...))
54+
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
55+
) where {BlockLengths,BT}
56+
return BT{BlockLengths}(bc.f.((Tuple.(bc.args))...))
4857
end
4958

5059
# BlockArrays interface

test/test_blockedtuple.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ using TensorAlgebra: BlockedTuple, tuplemortar
2525
# it is hard to make bt[Block(1)] type stable as compile-time knowledge of 1 is lost in Block
2626
@test bt[Block(1)] == blocks(bt)[1]
2727
@test bt[Block(2)] == blocks(bt)[2]
28-
@test bt[Block(1):Block(2)] == blocks(bt)[1:2]
28+
@test bt[Block(1):Block(2)] == tuplemortar((true,), ('a', 2))
2929
@test bt[Block(2)[1:2]] == ('a', 2)
3030
@test bt[2:4] == ('a', 2, "b")
3131

0 commit comments

Comments
 (0)