Skip to content

Commit 9307831

Browse files
committed
test BlockedTrivialPermutation
1 parent ab98cf3 commit 9307831

File tree

3 files changed

+44
-4
lines changed

3 files changed

+44
-4
lines changed

src/blockedpermutation.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int,Ellipsis}...; kwar
8383
return blockedperm(collect_tuple.(permblocks)...; kwargs...)
8484
end
8585

86+
function blockedperm(bt::AbstractBlockTuple)
87+
return blockedperm(Val(length(bt)), blocks(bt)...)
88+
end
89+
8690
function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}})
8791
return maximum(specified_perm)
8892
end
@@ -164,11 +168,19 @@ Base.Tuple(blockedperm::BlockedTrivialPermutation) = flatten_tuples(blocks(block
164168

165169
BlockArrays.blocks(blockedperm::BlockedTrivialPermutation) = getfield(blockedperm, :blocks)
166170

171+
function BlockArrays.blocklengths(
172+
::Type{<:BlockedTrivialPermutation{<:Any,<:Any,Blocks}}
173+
) where {Blocks}
174+
return fieldcount.(fieldtypes(Blocks))
175+
end
176+
177+
blockedperm(tp::BlockedTrivialPermutation) = tp
178+
167179
function blockedtrivialperm(blocklengths::Tuple{Vararg{Int}})
168180
return _BlockedTrivialPermutation(blocklengths)
169181
end
170182

171-
function trivialperm(blockedperm::AbstractBlockedPermutation)
183+
function trivialperm(blockedperm::AbstractBlockTuple)
172184
return blockedtrivialperm(blocklengths(blockedperm))
173185
end
174186
Base.invperm(blockedperm::BlockedTrivialPermutation) = blockedperm

test/test_basics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using EllipsisNotation: var".."
22
using LinearAlgebra: norm, qr
3-
using TensorAlgebra, TensorAlgebra, fusedims, splitdims
3+
using TensorAlgebra: TensorAlgebra, fusedims, splitdims
44
# TODO: Remove dependency on NDTensors, create a GPUTestUtils.jl package.
55
using NDTensors: NDTensors
66
include(joinpath(pkgdir(NDTensors), "test", "NDTensorsTestUtils", "NDTensorsTestUtils.jl"))

test/test_blockedpermutation.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ using Test: @test, @test_broken, @testset
22

33
using BlockArrays: blockfirsts, blocklasts, blocklength, blocklengths, blocks
44
using Combinatorics: permutations
5+
using EllipsisNotation: var".."
6+
using TestExtras: @constinferred
57

6-
using TensorAlgebra: BlockedTuple, blockedperm, blockedperm_indexin
8+
using TensorAlgebra:
9+
BlockedTrivialPermutation, BlockedTuple, blockedperm, blockedperm_indexin, trivialperm
710

811
@testset "BlockedPermutation" begin
912
p = blockedperm((3, 4, 5), (2, 1))
@@ -28,7 +31,13 @@ using TensorAlgebra: BlockedTuple, blockedperm, blockedperm_indexin
2831
@test blockfirsts(p) == (1, 3, 3)
2932
@test blocklasts(p) == (2, 2, 3)
3033
@test invperm(p) == blockedperm((3, 2), (), (1,))
31-
@test BlockedTuple(p) == BlockedTuple{(2, 0, 1)}((3, 2, 1))
34+
35+
bt = BlockedTuple{(2, 0, 1)}((3, 2, 1))
36+
@test (@constinferred BlockedTuple(p)) == bt
37+
@test (@constinferred map(identity, p)) == bt
38+
@test (@constinferred p .+ p) == BlockedTuple{(2, 0, 1)}((6, 4, 2))
39+
@test (@constinferred blockedperm(p)) == p
40+
@test (@constinferred blockedperm(bt)) == p
3241

3342
# Split collection into `BlockedPermutation`.
3443
p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d"))
@@ -63,3 +72,22 @@ using TensorAlgebra: BlockedTuple, blockedperm, blockedperm_indexin
6372
p = blockedperm((3, 2), .., 1)
6473
@test p == blockedperm((3, 2), 1)
6574
end
75+
76+
@testset "BlockedTrivialPermutation" begin
77+
p = blockedperm((3, 2), (), (1,))
78+
tp = trivialperm(p)
79+
80+
@test tp isa BlockedTrivialPermutation
81+
@test Tuple(tp) == (1, 2, 3)
82+
@test blocklength(tp) == 3
83+
@test blocklengths(tp) == (2, 0, 1)
84+
85+
bt = BlockedTuple{(2, 0, 1)}((1, 2, 3))
86+
@test (@constinferred BlockedTuple(tp)) == bt
87+
@test (@constinferred blocks(tp)) == blocks(bt)
88+
@test (@constinferred map(identity, tp)) == bt
89+
@test (@constinferred tp .+ tp) == BlockedTuple{(2, 0, 1)}((2, 4, 6))
90+
@test (@constinferred blockedperm(tp)) == tp
91+
@test (@constinferred trivialperm(tp)) == tp
92+
@test (@constinferred trivialperm(bt)) == tp
93+
end

0 commit comments

Comments
 (0)