Skip to content
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

[weakdeps]
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
Expand All @@ -23,4 +24,5 @@ EllipsisNotation = "1.8.0"
GradedUnitRanges = "0.1.0"
LinearAlgebra = "1.10"
TupleTools = "1.6.0"
TypeParameterAccessors = "0.2.1"
julia = "1.10"
1 change: 1 addition & 0 deletions src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module TensorAlgebra
include("blockedpermutation.jl")
include("BaseExtensions/BaseExtensions.jl")
include("blockedtuple.jl")
include("fusedims.jl")
include("splitdims.jl")
include("contract/contract.jl")
Expand Down
107 changes: 107 additions & 0 deletions src/blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl
# like interface

using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange

using TypeParameterAccessors: unspecify_type_parameters

#
# ================================== AbstractBlockTuple ==================================
#
abstract type AbstractBlockTuple end

# Base interface
Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),)

Check warning on line 14 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L14

Added line #L14 was not covered by tests

Base.deepcopy(bt::AbstractBlockTuple) = deepcopy.(bt)

Check warning on line 16 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L16

Added line #L16 was not covered by tests

Base.firstindex(::AbstractBlockTuple) = 1

Check warning on line 18 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L18

Added line #L18 was not covered by tests

Base.getindex(bt::AbstractBlockTuple, i::Integer) = Tuple(bt)[i]
Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r]
Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)]
function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1})
r = Int.(br)
T = unspecify_type_parameters(typeof(bt))
flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]]
return T{blocklengths(bt)[r]}(flat)

Check warning on line 27 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L20-L27

Added lines #L20 - L27 were not covered by tests
end
function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1})
return bt[Block(bi)][only(bi.indices)]

Check warning on line 30 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L29-L30

Added lines #L29 - L30 were not covered by tests
end

Base.iterate(bt::AbstractBlockTuple) = iterate(Tuple(bt))
Base.iterate(bt::AbstractBlockTuple, i::Int) = iterate(Tuple(bt), i)

Check warning on line 34 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L33-L34

Added lines #L33 - L34 were not covered by tests

Base.length(bt::AbstractBlockTuple) = length(Tuple(bt))

Check warning on line 36 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L36

Added line #L36 was not covered by tests

Base.lastindex(bt::AbstractBlockTuple) = length(bt)

Check warning on line 38 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L38

Added line #L38 was not covered by tests

function Base.map(f, bt::AbstractBlockTuple)
return unspecify_type_parameters(typeof(bt)){blocklengths(bt)}(map(f, Tuple(bt)))

Check warning on line 41 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L40-L41

Added lines #L40 - L41 were not covered by tests
end

# Broadcast interface
Base.broadcastable(bt::AbstractBlockTuple) = bt

Check warning on line 45 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L45

Added line #L45 was not covered by tests
struct AbstractBlockTupleBroadcastStyle{BlockLengths,BT} <: Broadcast.BroadcastStyle end
function Base.BroadcastStyle(T::Type{<:AbstractBlockTuple})
return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}()

Check warning on line 48 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L47-L48

Added lines #L47 - L48 were not covered by tests
end

# BroadcastStyle is not called for two identical styles
function Base.BroadcastStyle(

Check warning on line 52 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L52

Added line #L52 was not covered by tests
::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle
)
throw(DimensionMismatch("Incompatible blocks"))

Check warning on line 55 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L55

Added line #L55 was not covered by tests
end
function Base.copy(

Check warning on line 57 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L57

Added line #L57 was not covered by tests
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
) where {BlockLengths,BT}
return BT{BlockLengths}(bc.f.((Tuple.(bc.args))...))

Check warning on line 60 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L60

Added line #L60 was not covered by tests
end

# BlockArrays interface
function BlockArrays.blockfirsts(bt::AbstractBlockTuple)
return (0, cumsum(Base.front(blocklengths(bt)))...) .+ 1

Check warning on line 65 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L64-L65

Added lines #L64 - L65 were not covered by tests
end

function BlockArrays.blocklasts(bt::AbstractBlockTuple)
return cumsum(blocklengths(bt)[begin:end])

Check warning on line 69 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L68-L69

Added lines #L68 - L69 were not covered by tests
end

BlockArrays.blocklength(bt::AbstractBlockTuple) = length(blocklengths(bt))

Check warning on line 72 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L72

Added line #L72 was not covered by tests

BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt))

Check warning on line 74 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L74

Added line #L74 was not covered by tests

function BlockArrays.blocks(bt::AbstractBlockTuple)
bf = blockfirsts(bt)
bl = blocklasts(bt)
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))

Check warning on line 79 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L76-L79

Added lines #L76 - L79 were not covered by tests
end

#
# ===================================== BlockedTuple =====================================
#
struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple
flat::Flat

function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths}
length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length"))
return new{BlockLengths,typeof(flat)}(flat)

Check warning on line 90 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L88-L90

Added lines #L88 - L90 were not covered by tests
end
end

# TensorAlgebra Interface
tuplemortar(tt::Vararg{Tuple}) = BlockedTuple{length.(tt)}(flatten_tuples(tt))
function BlockedTuple(flat::Tuple, BlockLengths::Tuple{Vararg{Int}})
return BlockedTuple{BlockLengths}(flat)

Check warning on line 97 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L95-L97

Added lines #L95 - L97 were not covered by tests
end
BlockedTuple(bt::AbstractBlockTuple) = BlockedTuple{blocklengths(bt)}(Tuple(bt))

Check warning on line 99 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L99

Added line #L99 was not covered by tests

# Base interface
Base.Tuple(bt::BlockedTuple) = bt.flat

Check warning on line 102 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L102

Added line #L102 was not covered by tests

# BlockArrays interface
function BlockArrays.blocklengths(::Type{<:BlockedTuple{BlockLengths}}) where {BlockLengths}
return BlockLengths

Check warning on line 106 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L105-L106

Added lines #L105 - L106 were not covered by tests
end
9 changes: 5 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
Expand All @@ -8,16 +9,16 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
TensorOperations = "4.1.1"
Aqua = "0.8.9"
SafeTestsets = "0.1"
Suppressor = "0.2"
TensorOperations = "4.1.1"
Test = "1.10"
55 changes: 55 additions & 0 deletions test/test_blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using Test: @test, @test_throws

using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks
using TestExtras: @constinferred

using TensorAlgebra: BlockedTuple, tuplemortar

@testset "BlockedTuple" begin
flat = (true, 'a', 2, "b", 3.0)
divs = (1, 2, 2)

bt = BlockedTuple{divs}(flat)

@test (@constinferred Tuple(bt)) == flat
@test bt == tuplemortar((true,), ('a', 2), ("b", 3.0))
@test bt == BlockedTuple(flat, divs)
@test BlockedTuple(bt) == bt
@test blocklength(bt) == 3
@test blocklengths(bt) == (1, 2, 2)
@test (@constinferred blocks(bt)) == ((true,), ('a', 2), ("b", 3.0))

@test (@constinferred bt[1]) == true
@test (@constinferred bt[2]) == 'a'

# it is hard to make bt[Block(1)] type stable as compile-time knowledge of 1 is lost in Block
@test bt[Block(1)] == blocks(bt)[1]
@test bt[Block(2)] == blocks(bt)[2]
@test bt[Block(1):Block(2)] == tuplemortar((true,), ('a', 2))
@test bt[Block(2)[1:2]] == ('a', 2)
@test bt[2:4] == ('a', 2, "b")

@test firstindex(bt) == 1
@test lastindex(bt) == 5
@test length(bt) == 5

@test iterate(bt) == (1, 2)
@test iterate(bt, 2) == ('a', 3)
@test blockisequal(only(axes(bt)), blockedrange([1, 2, 2]))

@test_throws DimensionMismatch BlockedTuple{(1, 2, 3)}(flat)

bt = tuplemortar((1,), (4, 2), (5, 3))
@test Tuple(bt) == (1, 4, 2, 5, 3)
@test blocklengths(bt) == (1, 2, 2)
@test deepcopy(bt) == bt

@test (@constinferred map(n -> n + 1, bt)) ==
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
@test bt .+ tuplemortar((1,), (1, 1), (1, 1)) ==
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
@test_throws DimensionMismatch bt .+ tuplemortar((1, 1), (1, 1), (1,))

bt = tuplemortar((1:2, 1:2), (1:3,))
@test length.(bt) == tuplemortar((2, 2), (3,))
end
Loading