|
| 1 | +# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl |
| 2 | +# like interface |
| 3 | + |
| 4 | +using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange |
| 5 | + |
| 6 | +using TypeParameterAccessors: unspecify_type_parameters |
| 7 | + |
| 8 | +# |
| 9 | +# ================================== AbstractBlockTuple ================================== |
| 10 | +# |
| 11 | +abstract type AbstractBlockTuple end |
| 12 | + |
| 13 | +# Base interface |
| 14 | +Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),) |
| 15 | + |
| 16 | +Base.deepcopy(bt::AbstractBlockTuple) = deepcopy.(bt) |
| 17 | + |
| 18 | +Base.firstindex(::AbstractBlockTuple) = 1 |
| 19 | + |
| 20 | +Base.getindex(bt::AbstractBlockTuple, i::Integer) = Tuple(bt)[i] |
| 21 | +Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r] |
| 22 | +Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)] |
| 23 | +function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1}) |
| 24 | + r = Int.(br) |
| 25 | + T = unspecify_type_parameters(typeof(bt)) |
| 26 | + flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]] |
| 27 | + return T{blocklengths(bt)[r]}(flat) |
| 28 | +end |
| 29 | +function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1}) |
| 30 | + return bt[Block(bi)][only(bi.indices)] |
| 31 | +end |
| 32 | + |
| 33 | +Base.iterate(bt::AbstractBlockTuple) = iterate(Tuple(bt)) |
| 34 | +Base.iterate(bt::AbstractBlockTuple, i::Int) = iterate(Tuple(bt), i) |
| 35 | + |
| 36 | +Base.length(bt::AbstractBlockTuple) = length(Tuple(bt)) |
| 37 | + |
| 38 | +Base.lastindex(bt::AbstractBlockTuple) = length(bt) |
| 39 | + |
| 40 | +function Base.map(f, bt::AbstractBlockTuple) |
| 41 | + return unspecify_type_parameters(typeof(bt)){blocklengths(bt)}(map(f, Tuple(bt))) |
| 42 | +end |
| 43 | + |
| 44 | +# Broadcast interface |
| 45 | +Base.broadcastable(bt::AbstractBlockTuple) = bt |
| 46 | +struct AbstractBlockTupleBroadcastStyle{BlockLengths,BT} <: Broadcast.BroadcastStyle end |
| 47 | +function Base.BroadcastStyle(T::Type{<:AbstractBlockTuple}) |
| 48 | + return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}() |
| 49 | +end |
| 50 | + |
| 51 | +# BroadcastStyle is not called for two identical styles |
| 52 | +function Base.BroadcastStyle( |
| 53 | + ::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle |
| 54 | +) |
| 55 | + throw(DimensionMismatch("Incompatible blocks")) |
| 56 | +end |
| 57 | +function Base.copy( |
| 58 | + bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}} |
| 59 | +) where {BlockLengths,BT} |
| 60 | + return BT{BlockLengths}(bc.f.((Tuple.(bc.args))...)) |
| 61 | +end |
| 62 | + |
| 63 | +# BlockArrays interface |
| 64 | +function BlockArrays.blockfirsts(bt::AbstractBlockTuple) |
| 65 | + return (0, cumsum(Base.front(blocklengths(bt)))...) .+ 1 |
| 66 | +end |
| 67 | + |
| 68 | +function BlockArrays.blocklasts(bt::AbstractBlockTuple) |
| 69 | + return cumsum(blocklengths(bt)[begin:end]) |
| 70 | +end |
| 71 | + |
| 72 | +BlockArrays.blocklength(bt::AbstractBlockTuple) = length(blocklengths(bt)) |
| 73 | + |
| 74 | +BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt)) |
| 75 | + |
| 76 | +function BlockArrays.blocks(bt::AbstractBlockTuple) |
| 77 | + bf = blockfirsts(bt) |
| 78 | + bl = blocklasts(bt) |
| 79 | + return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt)) |
| 80 | +end |
| 81 | + |
| 82 | +# |
| 83 | +# ===================================== BlockedTuple ===================================== |
| 84 | +# |
| 85 | +struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple |
| 86 | + flat::Flat |
| 87 | + |
| 88 | + function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths} |
| 89 | + length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length")) |
| 90 | + return new{BlockLengths,typeof(flat)}(flat) |
| 91 | + end |
| 92 | +end |
| 93 | + |
| 94 | +# TensorAlgebra Interface |
| 95 | +tuplemortar(tt::Tuple{Vararg{Tuple}}) = BlockedTuple{length.(tt)}(flatten_tuples(tt)) |
| 96 | +function BlockedTuple(flat::Tuple, BlockLengths::Tuple{Vararg{Int}}) |
| 97 | + return BlockedTuple{BlockLengths}(flat) |
| 98 | +end |
| 99 | +BlockedTuple(bt::AbstractBlockTuple) = BlockedTuple{blocklengths(bt)}(Tuple(bt)) |
| 100 | + |
| 101 | +# Base interface |
| 102 | +Base.Tuple(bt::BlockedTuple) = bt.flat |
| 103 | + |
| 104 | +# BlockArrays interface |
| 105 | +function BlockArrays.blocklengths(::Type{<:BlockedTuple{BlockLengths}}) where {BlockLengths} |
| 106 | + return BlockLengths |
| 107 | +end |
0 commit comments