Skip to content

Commit fae747c

Browse files
authored
Define BlockedTuple (#9)
1 parent 74c6607 commit fae747c

File tree

5 files changed

+171
-5
lines changed

5 files changed

+171
-5
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
88
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
99
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
12+
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1213

1314
[weakdeps]
1415
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
@@ -23,4 +24,5 @@ EllipsisNotation = "1.8.0"
2324
GradedUnitRanges = "0.1.0"
2425
LinearAlgebra = "1.10"
2526
TupleTools = "1.6.0"
27+
TypeParameterAccessors = "0.2.1"
2628
julia = "1.10"

src/TensorAlgebra.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module TensorAlgebra
22
include("blockedpermutation.jl")
33
include("BaseExtensions/BaseExtensions.jl")
4+
include("blockedtuple.jl")
45
include("fusedims.jl")
56
include("splitdims.jl")
67
include("contract/contract.jl")

src/blockedtuple.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl
2+
# like interface
3+
4+
using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange
5+
6+
using TypeParameterAccessors: unspecify_type_parameters
7+
8+
#
9+
# ================================== AbstractBlockTuple ==================================
10+
#
11+
abstract type AbstractBlockTuple end
12+
13+
# Base interface
14+
Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),)
15+
16+
Base.deepcopy(bt::AbstractBlockTuple) = deepcopy.(bt)
17+
18+
Base.firstindex(::AbstractBlockTuple) = 1
19+
20+
Base.getindex(bt::AbstractBlockTuple, i::Integer) = Tuple(bt)[i]
21+
Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r]
22+
Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)]
23+
function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1})
24+
r = Int.(br)
25+
T = unspecify_type_parameters(typeof(bt))
26+
flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]]
27+
return T{blocklengths(bt)[r]}(flat)
28+
end
29+
function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1})
30+
return bt[Block(bi)][only(bi.indices)]
31+
end
32+
33+
Base.iterate(bt::AbstractBlockTuple) = iterate(Tuple(bt))
34+
Base.iterate(bt::AbstractBlockTuple, i::Int) = iterate(Tuple(bt), i)
35+
36+
Base.length(bt::AbstractBlockTuple) = length(Tuple(bt))
37+
38+
Base.lastindex(bt::AbstractBlockTuple) = length(bt)
39+
40+
function Base.map(f, bt::AbstractBlockTuple)
41+
return unspecify_type_parameters(typeof(bt)){blocklengths(bt)}(map(f, Tuple(bt)))
42+
end
43+
44+
# Broadcast interface
45+
Base.broadcastable(bt::AbstractBlockTuple) = bt
46+
struct AbstractBlockTupleBroadcastStyle{BlockLengths,BT} <: Broadcast.BroadcastStyle end
47+
function Base.BroadcastStyle(T::Type{<:AbstractBlockTuple})
48+
return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}()
49+
end
50+
51+
# BroadcastStyle is not called for two identical styles
52+
function Base.BroadcastStyle(
53+
::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle
54+
)
55+
throw(DimensionMismatch("Incompatible blocks"))
56+
end
57+
function Base.copy(
58+
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
59+
) where {BlockLengths,BT}
60+
return BT{BlockLengths}(bc.f.((Tuple.(bc.args))...))
61+
end
62+
63+
# BlockArrays interface
64+
function BlockArrays.blockfirsts(bt::AbstractBlockTuple)
65+
return (0, cumsum(Base.front(blocklengths(bt)))...) .+ 1
66+
end
67+
68+
function BlockArrays.blocklasts(bt::AbstractBlockTuple)
69+
return cumsum(blocklengths(bt)[begin:end])
70+
end
71+
72+
BlockArrays.blocklength(bt::AbstractBlockTuple) = length(blocklengths(bt))
73+
74+
BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt))
75+
76+
function BlockArrays.blocks(bt::AbstractBlockTuple)
77+
bf = blockfirsts(bt)
78+
bl = blocklasts(bt)
79+
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))
80+
end
81+
82+
#
83+
# ===================================== BlockedTuple =====================================
84+
#
85+
struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple
86+
flat::Flat
87+
88+
function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths}
89+
length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length"))
90+
return new{BlockLengths,typeof(flat)}(flat)
91+
end
92+
end
93+
94+
# TensorAlgebra Interface
95+
tuplemortar(tt::Tuple{Vararg{Tuple}}) = BlockedTuple{length.(tt)}(flatten_tuples(tt))
96+
function BlockedTuple(flat::Tuple, BlockLengths::Tuple{Vararg{Int}})
97+
return BlockedTuple{BlockLengths}(flat)
98+
end
99+
BlockedTuple(bt::AbstractBlockTuple) = BlockedTuple{blocklengths(bt)}(Tuple(bt))
100+
101+
# Base interface
102+
Base.Tuple(bt::BlockedTuple) = bt.flat
103+
104+
# BlockArrays interface
105+
function BlockArrays.blocklengths(::Type{<:BlockedTuple{BlockLengths}}) where {BlockLengths}
106+
return BlockLengths
107+
end

test/Project.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
23
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
34
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
45
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
@@ -8,16 +9,16 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
1011
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
11-
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
12-
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
13-
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
1412
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1513
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
14+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
15+
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
1616
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1718

1819
[compat]
19-
TensorOperations = "4.1.1"
2020
Aqua = "0.8.9"
2121
SafeTestsets = "0.1"
2222
Suppressor = "0.2"
23+
TensorOperations = "5.1.3"
2324
Test = "1.10"

test/test_blockedtuple.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
using Test: @test, @test_throws
2+
3+
using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks
4+
using TestExtras: @constinferred
5+
6+
using TensorAlgebra: BlockedTuple, tuplemortar
7+
8+
@testset "BlockedTuple" begin
9+
flat = (true, 'a', 2, "b", 3.0)
10+
divs = (1, 2, 2)
11+
12+
bt = BlockedTuple{divs}(flat)
13+
14+
@test (@constinferred Tuple(bt)) == flat
15+
@test bt == tuplemortar(((true,), ('a', 2), ("b", 3.0)))
16+
@test bt == BlockedTuple(flat, divs)
17+
@test BlockedTuple(bt) == bt
18+
@test blocklength(bt) == 3
19+
@test blocklengths(bt) == (1, 2, 2)
20+
@test (@constinferred blocks(bt)) == ((true,), ('a', 2), ("b", 3.0))
21+
22+
@test (@constinferred bt[1]) == true
23+
@test (@constinferred bt[2]) == 'a'
24+
25+
# it is hard to make bt[Block(1)] type stable as compile-time knowledge of 1 is lost in Block
26+
@test bt[Block(1)] == blocks(bt)[1]
27+
@test bt[Block(2)] == blocks(bt)[2]
28+
@test bt[Block(1):Block(2)] == tuplemortar(((true,), ('a', 2)))
29+
@test bt[Block(2)[1:2]] == ('a', 2)
30+
@test bt[2:4] == ('a', 2, "b")
31+
32+
@test firstindex(bt) == 1
33+
@test lastindex(bt) == 5
34+
@test length(bt) == 5
35+
36+
@test iterate(bt) == (1, 2)
37+
@test iterate(bt, 2) == ('a', 3)
38+
@test blockisequal(only(axes(bt)), blockedrange([1, 2, 2]))
39+
40+
@test_throws DimensionMismatch BlockedTuple{(1, 2, 3)}(flat)
41+
42+
bt = tuplemortar(((1,), (4, 2), (5, 3)))
43+
@test Tuple(bt) == (1, 4, 2, 5, 3)
44+
@test blocklengths(bt) == (1, 2, 2)
45+
@test deepcopy(bt) == bt
46+
47+
@test (@constinferred map(n -> n + 1, bt)) ==
48+
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
49+
@test bt .+ tuplemortar(((1,), (1, 1), (1, 1))) ==
50+
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
51+
@test_throws DimensionMismatch bt .+ tuplemortar(((1, 1), (1, 1), (1,)))
52+
53+
bt = tuplemortar(((1:2, 1:2), (1:3,)))
54+
@test length.(bt) == tuplemortar(((2, 2), (3,)))
55+
end

0 commit comments

Comments
 (0)