Skip to content

Commit 200f88d

Browse files
committed
define BlockedTuple
1 parent 74c6607 commit 200f88d

File tree

3 files changed

+128
-0
lines changed

3 files changed

+128
-0
lines changed

src/TensorAlgebra.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module TensorAlgebra
22
include("blockedpermutation.jl")
33
include("BaseExtensions/BaseExtensions.jl")
4+
include("blockedtuple.jl")
45
include("fusedims.jl")
56
include("splitdims.jl")
67
include("contract/contract.jl")

src/blockedtuple.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl
2+
# like interface
3+
4+
using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange
5+
6+
struct BlockedTuple{Divs,Flat}
7+
flat::Flat
8+
9+
function BlockedTuple{Divs}(flat::Tuple) where {Divs}
10+
length(flat) != sum(Divs) && throw(DimensionMismatch("Invalid total length"))
11+
return new{Divs,typeof(flat)}(flat)
12+
end
13+
end
14+
15+
# TensorAlgebra Interface
16+
BlockedTuple(tt::Vararg{Tuple}) = BlockedTuple{length.(tt)}(flatten_tuples(tt))
17+
BlockedTuple(bt::BlockedTuple) = bt
18+
flatten_tuples(bt::BlockedTuple) = Tuple(bt)
19+
20+
# Base interface
21+
Base.Tuple(bt::BlockedTuple) = bt.flat
22+
23+
Base.axes(bt::BlockedTuple) = (blockedrange([blocklengths(bt)...]),)
24+
25+
Base.broadcastable(bt::BlockedTuple) = bt
26+
struct BlockedTupleBroadcastStyle{Divs} <: Broadcast.BroadcastStyle end
27+
function Base.BroadcastStyle(::Type{<:BlockedTuple{Divs}}) where {Divs}
28+
return BlockedTupleBroadcastStyle{Divs}()
29+
end
30+
function Base.BroadcastStyle(::BlockedTupleBroadcastStyle, ::BlockedTupleBroadcastStyle)
31+
throw(DimensionMismatch("Incompatible blocks"))
32+
end
33+
Base.BroadcastStyle(::T, ::T) where {T<:BlockedTupleBroadcastStyle} = T
34+
function Base.copy(bc::Broadcast.Broadcasted{BlockedTupleBroadcastStyle{Divs}}) where {Divs}
35+
return BlockedTuple{Divs}(bc.f.((Tuple.(bc.args))...))
36+
end
37+
38+
Base.copy(bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(copy.(Tuple(bt)))
39+
40+
Base.deepcopy(bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(deepcopy.(Tuple(bt)))
41+
42+
Base.firstindex(::BlockedTuple) = 1
43+
44+
Base.getindex(bt::BlockedTuple, i::Integer) = Tuple(bt)[i]
45+
Base.getindex(bt::BlockedTuple, r::AbstractUnitRange) = Tuple(bt)[r]
46+
Base.getindex(bt::BlockedTuple, b::Block{1}) = blocks(bt)[Int(b)]
47+
Base.getindex(bt::BlockedTuple, br::BlockRange{1}) = blocks(bt)[Int.(br)]
48+
Base.getindex(bt::BlockedTuple, bi::BlockIndexRange{1}) = bt[Block(bi)][only(bi.indices)]
49+
50+
Base.iterate(bt::BlockedTuple) = iterate(Tuple(bt))
51+
Base.iterate(bt::BlockedTuple, i::Int) = iterate(Tuple(bt), i)
52+
53+
Base.lastindex(bt::BlockedTuple) = length(bt)
54+
55+
Base.length(bt::BlockedTuple) = length(Tuple(bt))
56+
57+
Base.map(f, bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(map(f, Tuple(bt)))
58+
59+
# BlockArrays interface
60+
function BlockArrays.blockfirsts(bt::BlockedTuple)
61+
return (0, cumsum(blocklengths(bt)[begin:(end - 1)])...) .+ 1
62+
end
63+
64+
function BlockArrays.blocklasts(bt::BlockedTuple)
65+
return cumsum(blocklengths(bt)[begin:end])
66+
end
67+
68+
BlockArrays.blocklength(::BlockedTuple{Divs}) where {Divs} = length(Divs)
69+
70+
BlockArrays.blocklengths(::BlockedTuple{Divs}) where {Divs} = Divs
71+
72+
function BlockArrays.blocks(bt::BlockedTuple)
73+
bf = blockfirsts(bt)
74+
bl = blocklasts(bt)
75+
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))
76+
end

test/test_blockedtuple.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
using Test: @test, @test_throws
2+
3+
using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks
4+
5+
using TensorAlgebra: BlockedTuple, flatten_tuples
6+
7+
@testset "BlockedTuple" begin
8+
flat = (1, 'a', 2, 'b', 3)
9+
divs = (1, 2, 2)
10+
11+
bt = BlockedTuple{divs}(flat)
12+
13+
@test Tuple(bt) == flat
14+
@test flatten_tuples(bt) == flat
15+
@test bt == BlockedTuple((1,), ('a', 2), ('b', 3))
16+
@test BlockedTuple(bt) == bt
17+
@test blocklength(bt) == 3
18+
@test blocklengths(bt) == (1, 2, 2)
19+
@test blocks(bt) == ((1,), ('a', 2), ('b', 3))
20+
21+
@test bt[1] == 1
22+
@test bt[2] == 'a'
23+
@test bt[Block(1)] == blocks(bt)[1]
24+
@test bt[Block(2)] == blocks(bt)[2]
25+
@test bt[Block(1):Block(2)] == blocks(bt)[1:2]
26+
@test bt[Block(2)[1:2]] == ('a', 2)
27+
28+
@test firstindex(bt) == 1
29+
@test lastindex(bt) == 5
30+
@test length(bt) == 5
31+
32+
@test iterate(bt) == (1, 2)
33+
@test iterate(bt, 2) == ('a', 3)
34+
@test blockisequal(only(axes(bt)), blockedrange([1, 2, 2]))
35+
36+
@test_throws DimensionMismatch BlockedTuple{(1, 2, 3)}(flat)
37+
38+
bt = BlockedTuple((1,), (4, 2), (5, 3))
39+
@test Tuple(bt) == (1, 4, 2, 5, 3)
40+
@test blocklengths(bt) == (1, 2, 2)
41+
@test copy(bt) == bt
42+
@test deepcopy(bt) == bt
43+
44+
@test map(n -> n + 1, bt) == BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
45+
@test bt .+ BlockedTuple((1,), (1, 1), (1, 1)) ==
46+
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
47+
@test_throws DimensionMismatch bt .+ BlockedTuple((1, 1), (1, 1), (1,))
48+
49+
bt = BlockedTuple((1:2, 1:2), (1:3,))
50+
@test length.(bt) == BlockedTuple((2, 2), (3,))
51+
end

0 commit comments

Comments
 (0)