Skip to content

Commit 2342979

Browse files
committed
reorder file
1 parent 05fc957 commit 2342979

File tree

2 files changed

+51
-41
lines changed

2 files changed

+51
-41
lines changed

src/blockedtuple.jl

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,19 @@
33

44
using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange
55

6+
#
7+
# ================================== AbstractBlockTuple ==================================
8+
#
69
abstract type AbstractBlockTuple end
710

8-
struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple
9-
flat::Flat
10-
11-
function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths}
12-
length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length"))
13-
return new{BlockLengths,typeof(flat)}(flat)
14-
end
15-
end
16-
17-
# TensorAlgebra Interface
18-
BlockedTuple(tt::Vararg{Tuple}) = BlockedTuple{length.(tt)}(flatten_tuples(tt))
19-
BlockedTuple(bt::BlockedTuple) = bt
20-
2111
# Base interface
22-
Base.Tuple(bt::BlockedTuple) = bt.flat
23-
2412
Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),)
2513

14+
Base.firstindex(::AbstractBlockTuple) = 1
15+
16+
Base.lastindex(bt::AbstractBlockTuple) = length(bt)
17+
18+
# Broadcast interface
2619
Base.broadcastable(bt::AbstractBlockTuple) = bt
2720
struct BlockedTupleBroadcastStyle{BlockLengths} <: Broadcast.BroadcastStyle end
2821
function Base.BroadcastStyle(type::Type{<:AbstractBlockTuple})
@@ -39,12 +32,48 @@ function Base.copy(
3932
return BlockedTuple{BlockLengths}(bc.f.((Tuple.(bc.args))...))
4033
end
4134

35+
# BlockArrays interface
36+
function BlockArrays.blockfirsts(bt::AbstractBlockTuple)
37+
return (0, cumsum(blocklengths(bt)[begin:(end - 1)])...) .+ 1
38+
end
39+
40+
function BlockArrays.blocklasts(bt::AbstractBlockTuple)
41+
return cumsum(blocklengths(bt)[begin:end])
42+
end
43+
44+
BlockArrays.blocklength(bt::AbstractBlockTuple) = length(blocklengths(bt))
45+
46+
BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt))
47+
48+
function BlockArrays.blocks(bt::AbstractBlockTuple)
49+
bf = blockfirsts(bt)
50+
bl = blocklasts(bt)
51+
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))
52+
end
53+
54+
#
55+
# ===================================== BlockedTuple =====================================
56+
#
57+
struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple
58+
flat::Flat
59+
60+
function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths}
61+
length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length"))
62+
return new{BlockLengths,typeof(flat)}(flat)
63+
end
64+
end
65+
66+
# TensorAlgebra Interface
67+
BlockedTuple(tt::Vararg{Tuple}) = BlockedTuple{length.(tt)}(flatten_tuples(tt))
68+
BlockedTuple(bt::BlockedTuple) = bt
69+
70+
# Base interface
71+
Base.Tuple(bt::BlockedTuple) = bt.flat
72+
4273
Base.copy(bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(copy.(Tuple(bt)))
4374

4475
Base.deepcopy(bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(deepcopy.(Tuple(bt)))
4576

46-
Base.firstindex(::AbstractBlockTuple) = 1
47-
4877
Base.getindex(bt::BlockedTuple, i::Integer) = Tuple(bt)[i]
4978
Base.getindex(bt::BlockedTuple, r::AbstractUnitRange) = Tuple(bt)[r]
5079
Base.getindex(bt::BlockedTuple, b::Block{1}) = blocks(bt)[Int(b)]
@@ -54,30 +83,11 @@ Base.getindex(bt::BlockedTuple, bi::BlockIndexRange{1}) = bt[Block(bi)][only(bi.
5483
Base.iterate(bt::BlockedTuple) = iterate(Tuple(bt))
5584
Base.iterate(bt::BlockedTuple, i::Int) = iterate(Tuple(bt), i)
5685

57-
Base.lastindex(bt::AbstractBlockTuple) = length(bt)
58-
5986
Base.length(bt::BlockedTuple) = length(Tuple(bt))
6087

6188
Base.map(f, bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(map(f, Tuple(bt)))
6289

6390
# BlockArrays interface
64-
function BlockArrays.blockfirsts(bt::AbstractBlockTuple)
65-
return (0, cumsum(blocklengths(bt)[begin:(end - 1)])...) .+ 1
66-
end
67-
68-
function BlockArrays.blocklasts(bt::AbstractBlockTuple)
69-
return cumsum(blocklengths(bt)[begin:end])
70-
end
71-
72-
BlockArrays.blocklength(bt::AbstractBlockTuple) = length(blocklengths(bt))
73-
74-
BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt))
7591
function BlockArrays.blocklengths(::Type{<:BlockedTuple{BlockLengths}}) where {BlockLengths}
7692
return BlockLengths
7793
end
78-
79-
function BlockArrays.blocks(bt::AbstractBlockTuple)
80-
bf = blockfirsts(bt)
81-
bl = blocklasts(bt)
82-
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))
83-
end

test/test_blockedtuple.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,25 @@ using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal,
55
using TensorAlgebra: BlockedTuple
66

77
@testset "BlockedTuple" begin
8-
flat = (1, 'a', 2, 'b', 3)
8+
flat = (true, 'a', 2, "b", 3.0)
99
divs = (1, 2, 2)
1010

1111
bt = BlockedTuple{divs}(flat)
1212

1313
@test Tuple(bt) == flat
14-
@test bt == BlockedTuple((1,), ('a', 2), ('b', 3))
14+
@test bt == BlockedTuple((true,), ('a', 2), ("b", 3.0))
1515
@test BlockedTuple(bt) == bt
1616
@test blocklength(bt) == 3
1717
@test blocklengths(bt) == (1, 2, 2)
18-
@test blocks(bt) == ((1,), ('a', 2), ('b', 3))
18+
@test blocks(bt) == ((true,), ('a', 2), ("b", 3.0))
1919

20-
@test bt[1] == 1
20+
@test bt[1] == true
2121
@test bt[2] == 'a'
2222
@test bt[Block(1)] == blocks(bt)[1]
2323
@test bt[Block(2)] == blocks(bt)[2]
2424
@test bt[Block(1):Block(2)] == blocks(bt)[1:2]
2525
@test bt[Block(2)[1:2]] == ('a', 2)
26-
@test bt[2:4] == ('a', 2, 'b')
26+
@test bt[2:4] == ('a', 2, "b")
2727

2828
@test firstindex(bt) == 1
2929
@test lastindex(bt) == 5

0 commit comments

Comments
 (0)