Skip to content

Commit ab98cf3

Browse files
committed
first draft for blockedpermutation
1 parent 20eaeb4 commit ab98cf3

File tree

4 files changed

+90
-101
lines changed

4 files changed

+90
-101
lines changed

src/TensorAlgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module TensorAlgebra
2+
include("blockedtuple.jl")
23
include("blockedpermutation.jl")
34
include("BaseExtensions/BaseExtensions.jl")
4-
include("blockedtuple.jl")
55
include("fusedims.jl")
66
include("splitdims.jl")
77
include("contract/contract.jl")

src/blockedpermutation.jl

Lines changed: 23 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -28,30 +28,12 @@ collect_tuple(t::Tuple) = t
2828

2929
const TupleOfTuples{N} = Tuple{Vararg{Tuple{Vararg{Int}},N}}
3030

31-
abstract type AbstractBlockedPermutation{BlockLength,Length} end
32-
33-
BlockArrays.blocks(blockedperm::AbstractBlockedPermutation) = error("Not implemented")
34-
35-
function Base.Tuple(blockedperm::AbstractBlockedPermutation)
36-
return flatten_tuples(blocks(blockedperm))
37-
end
38-
39-
function BlockArrays.blocklengths(blockedperm::AbstractBlockedPermutation)
40-
return length.(blocks(blockedperm))
41-
end
42-
43-
function BlockArrays.blockfirsts(blockedperm::AbstractBlockedPermutation)
44-
return _blockfirsts(blocklengths(blockedperm))
45-
end
46-
47-
function BlockArrays.blocklasts(blockedperm::AbstractBlockedPermutation)
48-
return _blocklasts(blocklengths(blockedperm))
49-
end
31+
#
32+
# ============================== AbstractBlockedPermutation ==============================
33+
#
34+
abstract type AbstractBlockedPermutation{BlockLength,Length} <: AbstractBlockTuple end
5035

51-
Base.iterate(permblocks::AbstractBlockedPermutation) = iterate(Tuple(permblocks))
52-
function Base.iterate(permblocks::AbstractBlockedPermutation, state)
53-
return iterate(Tuple(permblocks), state)
54-
end
36+
widened_constructorof(::Type{<:AbstractBlockedPermutation}) = BlockedTuple
5537

5638
# Block a permutation based on the specified lengths.
5739
# blockperm((4, 3, 2, 1), (2, 2)) == blockedperm((4, 3), (2, 1))
@@ -67,24 +49,6 @@ function Base.invperm(blockedperm::AbstractBlockedPermutation)
6749
return blockperm(invperm(Tuple(blockedperm)), blocklengths(blockedperm))
6850
end
6951

70-
Base.length(blockedperm::AbstractBlockedPermutation) = length(Tuple(blockedperm))
71-
function BlockArrays.blocklength(blockedperm::AbstractBlockedPermutation)
72-
return length(blocks(blockedperm))
73-
end
74-
75-
function Base.getindex(blockedperm::AbstractBlockedPermutation, i::Int)
76-
return Tuple(blockedperm)[i]
77-
end
78-
79-
function Base.getindex(blockedperm::AbstractBlockedPermutation, I::AbstractUnitRange)
80-
perm = Tuple(blockedperm)
81-
return [perm[i] for i in I]
82-
end
83-
84-
function Base.getindex(blockedperm::AbstractBlockedPermutation, b::Block)
85-
return blocks(blockedperm)[Int(b)]
86-
end
87-
8852
# Like `BlockRange`.
8953
function blockeachindex(blockedperm::AbstractBlockedPermutation)
9054
return ntuple(i -> Block(i), blocklength(blockedperm))
@@ -148,6 +112,9 @@ function blockedperm_indexin(collection, subs...)
148112
return blockedperm(map(sub -> BaseExtensions.indexin(sub, collection), subs)...)
149113
end
150114

115+
#
116+
# ================================== BlockedPermutation ==================================
117+
#
151118
struct BlockedPermutation{BlockLength,Length,Blocks<:TupleOfTuples{BlockLength}} <:
152119
AbstractBlockedPermutation{BlockLength,Length}
153120
blocks::Blocks
@@ -158,15 +125,28 @@ struct BlockedPermutation{BlockLength,Length,Blocks<:TupleOfTuples{BlockLength}}
158125
end
159126
end
160127

128+
Base.Tuple(blockedperm::BlockedPermutation) = flatten_tuples(blocks(blockedperm))
129+
130+
BlockedTuple(bp::BlockedPermutation) = tuplemortar(blocks(bp))
131+
161132
BlockArrays.blocks(blockedperm::BlockedPermutation) = getfield(blockedperm, :blocks)
162133

134+
function BlockArrays.blocklengths(
135+
::Type{<:BlockedPermutation{<:Any,<:Any,Blocks}}
136+
) where {Blocks}
137+
return fieldcount.(fieldtypes(Blocks))
138+
end
139+
163140
function blockedperm(length::Val, permblocks::Tuple{Vararg{Int}}...)
164141
@assert value(length) == sum(Base.length, permblocks; init=zero(Bool))
165142
blockedperm = _BlockedPermutation(permblocks)
166143
@assert isperm(blockedperm)
167144
return blockedperm
168145
end
169146

147+
#
148+
# ============================== BlockedTrivialPermutation ===============================
149+
#
170150
trivialperm(length::Union{Integer,Val}) = ntuple(identity, length)
171151

172152
struct BlockedTrivialPermutation{BlockLength,Length,Blocks<:TupleOfTuples{BlockLength}} <:
@@ -180,6 +160,8 @@ struct BlockedTrivialPermutation{BlockLength,Length,Blocks<:TupleOfTuples{BlockL
180160
end
181161
end
182162

163+
Base.Tuple(blockedperm::BlockedTrivialPermutation) = flatten_tuples(blocks(blockedperm))
164+
183165
BlockArrays.blocks(blockedperm::BlockedTrivialPermutation) = getfield(blockedperm, :blocks)
184166

185167
function blockedtrivialperm(blocklengths::Tuple{Vararg{Int}})

test/test_basics.jl

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,13 @@
1-
using BlockArrays: blockfirsts, blocklasts, blocklength, blocklengths, blocks
2-
using Combinatorics: permutations
31
using EllipsisNotation: var".."
42
using LinearAlgebra: norm, qr
5-
using TensorAlgebra: TensorAlgebra, blockedperm, blockedperm_indexin, fusedims, splitdims
3+
using TensorAlgebra, TensorAlgebra, fusedims, splitdims
64
# TODO: Remove dependency on NDTensors, create a GPUTestUtils.jl package.
75
using NDTensors: NDTensors
86
include(joinpath(pkgdir(NDTensors), "test", "NDTensorsTestUtils", "NDTensorsTestUtils.jl"))
97
using .NDTensorsTestUtils: default_rtol
108
using Test: @test, @test_broken, @testset
119
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
12-
@testset "BlockedPermutation" begin
13-
p = blockedperm((3, 4, 5), (2, 1))
14-
@test Tuple(p) === (3, 4, 5, 2, 1)
15-
@test isperm(p)
16-
@test length(p) == 5
17-
@test blocks(p) == ((3, 4, 5), (2, 1))
18-
@test blocklength(p) == 2
19-
@test blocklengths(p) == (3, 2)
20-
@test blockfirsts(p) == (1, 4)
21-
@test blocklasts(p) == (3, 5)
22-
@test invperm(p) == blockedperm((5, 4, 1), (2, 3))
2310

24-
# Empty block.
25-
p = blockedperm((3, 2), (), (1,))
26-
@test Tuple(p) === (3, 2, 1)
27-
@test isperm(p)
28-
@test length(p) == 3
29-
@test blocks(p) == ((3, 2), (), (1,))
30-
@test blocklength(p) == 3
31-
@test blocklengths(p) == (2, 0, 1)
32-
@test blockfirsts(p) == (1, 3, 3)
33-
@test blocklasts(p) == (2, 2, 3)
34-
@test invperm(p) == blockedperm((3, 2), (), (1,))
35-
36-
# Split collection into `BlockedPermutation`.
37-
p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d"))
38-
@test p == blockedperm((3, 1), (2, 4))
39-
40-
# Singleton dimensions.
41-
p = blockedperm((2, 3), 1)
42-
@test p == blockedperm((2, 3), (1,))
43-
44-
# First dimensions are unspecified.
45-
p = blockedperm(.., (4, 3))
46-
@test p == blockedperm(1, 2, (4, 3))
47-
# Specify length
48-
p = blockedperm(.., (4, 3); length=Val(6))
49-
@test p == blockedperm(1, 2, 5, 6, (4, 3))
50-
51-
# Last dimensions are unspecified.
52-
p = blockedperm((4, 3), ..)
53-
@test p == blockedperm((4, 3), 1, 2)
54-
# Specify length
55-
p = blockedperm((4, 3), ..; length=Val(6))
56-
@test p == blockedperm((4, 3), 1, 2, 5, 6)
57-
58-
# Middle dimensions are unspecified.
59-
p = blockedperm((4, 3), .., 1)
60-
@test p == blockedperm((4, 3), 2, 1)
61-
# Specify length
62-
p = blockedperm((4, 3), .., 1; length=Val(6))
63-
@test p == blockedperm((4, 3), 2, 5, 6, 1)
64-
65-
# No dimensions are unspecified.
66-
p = blockedperm((3, 2), .., 1)
67-
@test p == blockedperm((3, 2), 1)
68-
end
6911
@testset "TensorAlgebra" begin
7012
@testset "fusedims (eltype=$elt)" for elt in elts
7113
a = randn(elt, 2, 3, 4, 5)

test/test_blockedpermutation.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
using Test: @test, @test_broken, @testset
2+
3+
using BlockArrays: blockfirsts, blocklasts, blocklength, blocklengths, blocks
4+
using Combinatorics: permutations
5+
6+
using TensorAlgebra: BlockedTuple, blockedperm, blockedperm_indexin
7+
8+
@testset "BlockedPermutation" begin
9+
p = blockedperm((3, 4, 5), (2, 1))
10+
@test Tuple(p) === (3, 4, 5, 2, 1)
11+
@test isperm(p)
12+
@test length(p) == 5
13+
@test blocks(p) == ((3, 4, 5), (2, 1))
14+
@test blocklength(p) == 2
15+
@test blocklengths(p) == (3, 2)
16+
@test blockfirsts(p) == (1, 4)
17+
@test blocklasts(p) == (3, 5)
18+
@test invperm(p) == blockedperm((5, 4, 1), (2, 3))
19+
20+
# Empty block.
21+
p = blockedperm((3, 2), (), (1,))
22+
@test Tuple(p) === (3, 2, 1)
23+
@test isperm(p)
24+
@test length(p) == 3
25+
@test blocks(p) == ((3, 2), (), (1,))
26+
@test blocklength(p) == 3
27+
@test blocklengths(p) == (2, 0, 1)
28+
@test blockfirsts(p) == (1, 3, 3)
29+
@test blocklasts(p) == (2, 2, 3)
30+
@test invperm(p) == blockedperm((3, 2), (), (1,))
31+
@test BlockedTuple(p) == BlockedTuple{(2, 0, 1)}((3, 2, 1))
32+
33+
# Split collection into `BlockedPermutation`.
34+
p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d"))
35+
@test p == blockedperm((3, 1), (2, 4))
36+
37+
# Singleton dimensions.
38+
p = blockedperm((2, 3), 1)
39+
@test p == blockedperm((2, 3), (1,))
40+
41+
# First dimensions are unspecified.
42+
p = blockedperm(.., (4, 3))
43+
@test p == blockedperm(1, 2, (4, 3))
44+
# Specify length
45+
p = blockedperm(.., (4, 3); length=Val(6))
46+
@test p == blockedperm(1, 2, 5, 6, (4, 3))
47+
48+
# Last dimensions are unspecified.
49+
p = blockedperm((4, 3), ..)
50+
@test p == blockedperm((4, 3), 1, 2)
51+
# Specify length
52+
p = blockedperm((4, 3), ..; length=Val(6))
53+
@test p == blockedperm((4, 3), 1, 2, 5, 6)
54+
55+
# Middle dimensions are unspecified.
56+
p = blockedperm((4, 3), .., 1)
57+
@test p == blockedperm((4, 3), 2, 1)
58+
# Specify length
59+
p = blockedperm((4, 3), .., 1; length=Val(6))
60+
@test p == blockedperm((4, 3), 2, 5, 6, 1)
61+
62+
# No dimensions are unspecified.
63+
p = blockedperm((3, 2), .., 1)
64+
@test p == blockedperm((3, 2), 1)
65+
end

0 commit comments

Comments
 (0)