Skip to content

Commit 5d9371f

Browse files
committed
Introduce bipermutedims
1 parent c468c4d commit 5d9371f

File tree

4 files changed

+52
-17
lines changed

4 files changed

+52
-17
lines changed

src/factorizations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ for f in (
3636
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
3737
kwargs...,
3838
)
39-
A_perm = permuteblockeddims(A, codomain_perm, domain_perm)
39+
A_perm = bipermutedims(A, codomain_perm, domain_perm)
4040
return $f(A_perm, Val(length(codomain_perm)), Val(length(domain_perm)); kwargs...)
4141
end
4242
function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)

src/matricize.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,34 @@ function fuseaxes(
3939
end
4040

4141
# Inner version takes a list of sub-permutations, overload this one if needed.
42+
# TODO: Remove _permutedims once support for Julia 1.10 is dropped
43+
# define permutedims with a BlockedPermuation. Default is to flatten it.
44+
# TODO: Deprecate `permuteblockeddims` in favor of `bipermutedims`.
45+
# Keeping it here for backwards compatibility.
46+
function bipermutedims(a::AbstractArray, perm1, perm2)
47+
return permuteblockeddims(a, perm1, perm2)
48+
end
49+
function bipermutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2)
50+
return permuteblockeddims!(a_dest, a_src, perm1, perm2)
51+
end
52+
function bipermutedims(a::AbstractArray, biperm::AbstractBlockPermutation{2})
53+
return permuteblockeddims(a, biperm)
54+
end
55+
function bipermutedims!(
56+
a_dest::AbstractArray, a_src::AbstractArray, biperm::AbstractBlockPermutation{2}
57+
)
58+
return permuteblockeddims!(a_dest, a_src, biperm)
59+
end
60+
61+
# Older interface.
62+
# TODO: Deprecate in favor of `bipermutedims` (or decide if we want to keep it
63+
# in case there are applications of more general partitionings).
4264
function permuteblockeddims(a::AbstractArray, perm1, perm2)
4365
return _permutedims(a, (perm1..., perm2...))
4466
end
4567
function permuteblockeddims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2)
4668
return _permutedims!(a_dest, a_src, (perm1..., perm2...))
4769
end
48-
49-
# TODO remove _permutedims once support for Julia 1.10 is dropped
50-
# define permutedims with a BlockedPermuation. Default is to flatten it.
5170
function permuteblockeddims(a::AbstractArray, biperm::AbstractBlockPermutation{2})
5271
return permuteblockeddims(a, blocks(biperm)...)
5372
end
@@ -87,7 +106,7 @@ function matricize(
87106
) where {N1, N2}
88107
ndims(a) == length(permblock1) + length(permblock2) ||
89108
throw(ArgumentError("Invalid bipermutation"))
90-
a_perm = permuteblockeddims(a, permblock1, permblock2)
109+
a_perm = bipermutedims(a, permblock1, permblock2)
91110
return matricize(style, a_perm, Val(length(permblock1)), Val(length(permblock2)))
92111
end
93112

@@ -179,7 +198,7 @@ function unmatricize(
179198
blocked_axes = axes_dest[invbiperm]
180199
a12 = unmatricize(style, m, blocked_axes)
181200
biperm_dest = biperm(invperm(invbiperm), length_codomain(axes_dest))
182-
return permuteblockeddims(a12, biperm_dest)
201+
return bipermutedims(a12, biperm_dest)
183202
end
184203

185204
function unmatricize(m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2})
@@ -208,7 +227,7 @@ function unmatricize!(
208227
blocked_axes = axes(a_dest)[invbiperm]
209228
a_perm = unmatricize(style, m, blocked_axes)
210229
biperm_dest = biperm(invperm(invbiperm), length_codomain(axes(a_dest)))
211-
return permuteblockeddims!(a_dest, a_perm, biperm_dest)
230+
return bipermutedims!(a_dest, a_perm, biperm_dest)
212231
end
213232

214233
function unmatricize!(

src/matrixfunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ for f in MATRIX_FUNCTIONS
4848
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
4949
kwargs...,
5050
)
51-
a_perm = permuteblockeddims(a, codomain_perm, domain_perm)
51+
a_perm = bipermutedims(a, codomain_perm, domain_perm)
5252
return $f(a_perm, Val(length(codomain_perm)), Val(length(domain_perm)); kwargs...)
5353
end
5454
function $f(a::AbstractArray, labels_a, labels_codomain, labels_domain; kwargs...)

test/test_basics.jl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ using TensorAlgebra:
1414
length_codomain,
1515
length_domain,
1616
matricize,
17+
bipermutedims,
18+
bipermutedims!,
1719
permuteblockeddims,
1820
permuteblockeddims!,
1921
tuplemortar,
@@ -33,15 +35,29 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
3335
@test length_domain(bt) == 1
3436
end
3537

36-
@testset "permuteblockeddims (eltype=$elt)" for elt in elts
37-
a = randn(elt, 2, 3, 4, 5)
38-
a_perm = permuteblockeddims(a, blockedpermvcat((3, 1), (2, 4)))
39-
@test a_perm == permutedims(a, (3, 1, 2, 4))
40-
41-
a = randn(elt, 2, 3, 4, 5)
42-
a_perm = Array{elt}(undef, (4, 2, 3, 5))
43-
permuteblockeddims!(a_perm, a, blockedpermvcat((3, 1), (2, 4)))
44-
@test a_perm == permutedims(a, (3, 1, 2, 4))
38+
@testset "bipermutedims/permuteblockeddims (eltype=$elt)" for f in
39+
(:bipermutedims, :permuteblockeddims),
40+
elt in elts
41+
f! = Symbol(f, :!)
42+
@eval begin
43+
a = randn($elt, 2, 3, 4, 5)
44+
a_perm = $f(a, blockedpermvcat((3, 1), (2, 4)))
45+
@test a_perm == permutedims(a, (3, 1, 2, 4))
46+
47+
a = randn($elt, 2, 3, 4, 5)
48+
a_perm = $f(a, (3, 1), (2, 4))
49+
@test a_perm == permutedims(a, (3, 1, 2, 4))
50+
51+
a = randn($elt, 2, 3, 4, 5)
52+
a_perm = Array{$elt}(undef, (4, 2, 3, 5))
53+
$f!(a_perm, a, blockedpermvcat((3, 1), (2, 4)))
54+
@test a_perm == permutedims(a, (3, 1, 2, 4))
55+
56+
a = randn($elt, 2, 3, 4, 5)
57+
a_perm = Array{$elt}(undef, (4, 2, 3, 5))
58+
$f!(a_perm, a, (3, 1), (2, 4))
59+
@test a_perm == permutedims(a, (3, 1, 2, 4))
60+
end
4561
end
4662
@testset "matricize (eltype=$elt)" for elt in elts
4763
a = randn(elt, 2, 3, 4, 5)

0 commit comments

Comments
 (0)