Skip to content
1 change: 1 addition & 0 deletions src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module TensorAlgebra
include("blockedpermutation.jl")
include("BaseExtensions/BaseExtensions.jl")
include("blockedtuple.jl")
include("fusedims.jl")
include("splitdims.jl")
include("contract/contract.jl")
Expand Down
76 changes: 76 additions & 0 deletions src/blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl
# like interface

using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange

struct BlockedTuple{Divs,Flat}
flat::Flat

function BlockedTuple{Divs}(flat::Tuple) where {Divs}
length(flat) != sum(Divs) && throw(DimensionMismatch("Invalid total length"))
return new{Divs,typeof(flat)}(flat)

Check warning on line 11 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L9-L11

Added lines #L9 - L11 were not covered by tests
end
end

# TensorAlgebra Interface
BlockedTuple(tt::Vararg{Tuple}) = BlockedTuple{length.(tt)}(flatten_tuples(tt))
BlockedTuple(bt::BlockedTuple) = bt
flatten_tuples(bt::BlockedTuple) = Tuple(bt)

Check warning on line 18 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L16-L18

Added lines #L16 - L18 were not covered by tests

# Base interface
Base.Tuple(bt::BlockedTuple) = bt.flat

Check warning on line 21 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L21

Added line #L21 was not covered by tests

Base.axes(bt::BlockedTuple) = (blockedrange([blocklengths(bt)...]),)

Check warning on line 23 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L23

Added line #L23 was not covered by tests

Base.broadcastable(bt::BlockedTuple) = bt

Check warning on line 25 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L25

Added line #L25 was not covered by tests
struct BlockedTupleBroadcastStyle{Divs} <: Broadcast.BroadcastStyle end
function Base.BroadcastStyle(::Type{<:BlockedTuple{Divs}}) where {Divs}
return BlockedTupleBroadcastStyle{Divs}()

Check warning on line 28 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L27-L28

Added lines #L27 - L28 were not covered by tests
end
function Base.BroadcastStyle(::BlockedTupleBroadcastStyle, ::BlockedTupleBroadcastStyle)
throw(DimensionMismatch("Incompatible blocks"))

Check warning on line 31 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L30-L31

Added lines #L30 - L31 were not covered by tests
end
# BroadcastStyle is not called for two identical styles
function Base.copy(bc::Broadcast.Broadcasted{BlockedTupleBroadcastStyle{Divs}}) where {Divs}
return BlockedTuple{Divs}(bc.f.((Tuple.(bc.args))...))

Check warning on line 35 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
end

Base.copy(bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(copy.(Tuple(bt)))

Check warning on line 38 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L38

Added line #L38 was not covered by tests

Base.deepcopy(bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(deepcopy.(Tuple(bt)))

Check warning on line 40 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L40

Added line #L40 was not covered by tests

Base.firstindex(::BlockedTuple) = 1

Check warning on line 42 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L42

Added line #L42 was not covered by tests

Base.getindex(bt::BlockedTuple, i::Integer) = Tuple(bt)[i]
Base.getindex(bt::BlockedTuple, r::AbstractUnitRange) = Tuple(bt)[r]
Base.getindex(bt::BlockedTuple, b::Block{1}) = blocks(bt)[Int(b)]
Base.getindex(bt::BlockedTuple, br::BlockRange{1}) = blocks(bt)[Int.(br)]
Base.getindex(bt::BlockedTuple, bi::BlockIndexRange{1}) = bt[Block(bi)][only(bi.indices)]

Check warning on line 48 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L44-L48

Added lines #L44 - L48 were not covered by tests

Base.iterate(bt::BlockedTuple) = iterate(Tuple(bt))
Base.iterate(bt::BlockedTuple, i::Int) = iterate(Tuple(bt), i)

Check warning on line 51 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L50-L51

Added lines #L50 - L51 were not covered by tests

Base.lastindex(bt::BlockedTuple) = length(bt)

Check warning on line 53 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L53

Added line #L53 was not covered by tests

Base.length(bt::BlockedTuple) = length(Tuple(bt))

Check warning on line 55 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L55

Added line #L55 was not covered by tests

Base.map(f, bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(map(f, Tuple(bt)))

Check warning on line 57 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L57

Added line #L57 was not covered by tests

# BlockArrays interface
function BlockArrays.blockfirsts(bt::BlockedTuple)
return (0, cumsum(blocklengths(bt)[begin:(end - 1)])...) .+ 1

Check warning on line 61 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L60-L61

Added lines #L60 - L61 were not covered by tests
end

function BlockArrays.blocklasts(bt::BlockedTuple)
return cumsum(blocklengths(bt)[begin:end])

Check warning on line 65 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L64-L65

Added lines #L64 - L65 were not covered by tests
end

BlockArrays.blocklength(::BlockedTuple{Divs}) where {Divs} = length(Divs)

Check warning on line 68 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L68

Added line #L68 was not covered by tests

BlockArrays.blocklengths(::BlockedTuple{Divs}) where {Divs} = Divs

Check warning on line 70 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L70

Added line #L70 was not covered by tests

function BlockArrays.blocks(bt::BlockedTuple)
bf = blockfirsts(bt)
bl = blocklasts(bt)
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))

Check warning on line 75 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L72-L75

Added lines #L72 - L75 were not covered by tests
end
52 changes: 52 additions & 0 deletions test/test_blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using Test: @test, @test_throws

using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks

using TensorAlgebra: BlockedTuple, flatten_tuples

@testset "BlockedTuple" begin
flat = (1, 'a', 2, 'b', 3)
divs = (1, 2, 2)

bt = BlockedTuple{divs}(flat)

@test Tuple(bt) == flat
@test flatten_tuples(bt) == flat
@test bt == BlockedTuple((1,), ('a', 2), ('b', 3))
@test BlockedTuple(bt) == bt
@test blocklength(bt) == 3
@test blocklengths(bt) == (1, 2, 2)
@test blocks(bt) == ((1,), ('a', 2), ('b', 3))

@test bt[1] == 1
@test bt[2] == 'a'
@test bt[Block(1)] == blocks(bt)[1]
@test bt[Block(2)] == blocks(bt)[2]
@test bt[Block(1):Block(2)] == blocks(bt)[1:2]
@test bt[Block(2)[1:2]] == ('a', 2)
@test bt[2:4] == ('a', 2, 'b')

@test firstindex(bt) == 1
@test lastindex(bt) == 5
@test length(bt) == 5

@test iterate(bt) == (1, 2)
@test iterate(bt, 2) == ('a', 3)
@test blockisequal(only(axes(bt)), blockedrange([1, 2, 2]))

@test_throws DimensionMismatch BlockedTuple{(1, 2, 3)}(flat)

bt = BlockedTuple((1,), (4, 2), (5, 3))
@test Tuple(bt) == (1, 4, 2, 5, 3)
@test blocklengths(bt) == (1, 2, 2)
@test copy(bt) == bt
@test deepcopy(bt) == bt

@test map(n -> n + 1, bt) == BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
@test bt .+ BlockedTuple((1,), (1, 1), (1, 1)) ==
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
@test_throws DimensionMismatch bt .+ BlockedTuple((1, 1), (1, 1), (1,))

bt = BlockedTuple((1:2, 1:2), (1:3,))
@test length.(bt) == BlockedTuple((2, 2), (3,))
end
Loading