Skip to content

Commit b185c8e

Browse files
committed
remove invbiperm
1 parent 26c9765 commit b185c8e

File tree

3 files changed

+22
-22
lines changed

3 files changed

+22
-22
lines changed

src/contract/blockedperms.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@ using .BaseExtensions: BaseExtensions
22
using BlockArrays: blocklengths
33

44
# default: if no bipartion is specified, all axes to domain
5-
invbiperm(perm, ::Any) = invbiperm(perm, Val(0))
6-
invbiperm(perm, t::Tuple{Tuple,Tuple}) = invbiperm(perm, tuplemortar(t))
7-
invbiperm(perm, t::AbstractBlockTuple{2}) = invbiperm(perm, Val(first(blocklength(t))))
8-
9-
function invbiperm(perm, ::Val{N1}) where {N1}
10-
perm_out = invperm(Tuple(perm))
11-
length(perm) <= N1 && return blockedpermvcat(perm_out, ())
12-
return blockedpermvcat(perm_out[begin:N1], (perm_out[(N1 + 1):end]))
5+
function biperm(perm, blocklength1::Integer)
6+
return biperm(perm, Val(blocklength1))
137
end
8+
function biperm(perm, ::Val{BlockLength1}) where {BlockLength1}
9+
length(perm) < BlockLength1 && throw(ArgumentError("Invalid codomain length"))
10+
return blockedperm(Tuple(perm), (BlockLength1, length(perm) - BlockLength1))
11+
end
12+
13+
length_codomain(t::AbstractBlockTuple{2}) = first(blocklengths(t))
14+
# Assume all dimensions are in the domain by default
15+
length_codomain(t) = 0
1416

1517
function blockedperms(
1618
f::typeof(contract), alg::Algorithm, dimnames_dest, dimnames1, dimnames2
@@ -32,7 +34,7 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
3234
perm_codomain_dest = BaseExtensions.indexin(codomain, dimnames_dest)
3335
perm_domain_dest = BaseExtensions.indexin(domain, dimnames_dest)
3436
biperm_dest_to_a12 = (perm_codomain_dest..., perm_domain_dest...)
35-
biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, dimnames_dest)
37+
biperm_a12_to_dest = biperm(invperm(biperm_dest_to_a12), length_codomain(dimnames_dest))
3638

3739
perm_codomain1 = BaseExtensions.indexin(codomain, dimnames1)
3840
perm_domain1 = BaseExtensions.indexin(contracted, dimnames1)

src/contract/contract_matricize/contract.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ function contract!(
1111
α::Number,
1212
β::Number,
1313
)
14-
biperm_dest_to_a12 = invbiperm(biperm_a12_to_dest, Val(first(blocklengths(biperm1))))
14+
biperm_dest_to_a12 = biperm(invperm(biperm_a12_to_dest), length_codomain(biperm1))
15+
1516
check_input(contract, a_dest, biperm_dest_to_a12, a1, biperm1, a2, biperm2)
1617
a1_mat = matricize(a1, biperm1)
1718
a2_mat = matricize(a2, biperm2)

src/matricize.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,23 +74,19 @@ function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple)
7474
end
7575

7676
# ==================================== unmatricize =======================================
77-
function unmatricize(
78-
m::AbstractMatrix, axes_dest, biperm_dest_to_a12::AbstractBlockPermutation{2}
79-
)
80-
length(axes_dest) == length(biperm_dest_to_a12) ||
77+
function unmatricize(m::AbstractMatrix, axes, biperm_dest::AbstractBlockPermutation{2})
78+
length(axes) == length(biperm_dest) ||
8179
throw(ArgumentError("axes do not match permutation"))
82-
return unmatricize(FusionStyle(m), m, axes_dest, biperm_dest_to_a12)
80+
return unmatricize(FusionStyle(m), m, axes, biperm_dest)
8381
end
8482

8583
function unmatricize(
86-
::FusionStyle,
87-
m::AbstractMatrix,
88-
axes_dest,
89-
biperm_dest_to_a12::AbstractBlockPermutation{2},
84+
::FusionStyle, m::AbstractMatrix, axes, biperm_dest_to_a12::AbstractBlockPermutation{2}
9085
)
91-
blocked_axes = axes_dest[biperm_dest_to_a12]
86+
blocked_axes = axes[biperm_dest_to_a12]
9287
a12 = unmatricize(m, blocked_axes)
93-
biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, axes_dest)
88+
biperm_a12_to_dest = biperm(invperm(biperm_dest_to_a12), length_codomain(axes))
89+
9490
return permuteblockeddims(a12, biperm_a12_to_dest)
9591
end
9692

@@ -122,7 +118,8 @@ function unmatricize!(
122118
throw(ArgumentError("destination does not match permutation"))
123119
blocked_axes = axes(a_dest)[biperm_dest_to_a12]
124120
a_perm = unmatricize(m, blocked_axes)
125-
biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, axes(a_dest))
121+
biperm_a12_to_dest = biperm(invperm(biperm_dest_to_a12), length_codomain(axes(a_dest)))
122+
126123
return permuteblockeddims!(a_dest, a_perm, biperm_a12_to_dest)
127124
end
128125

0 commit comments

Comments
 (0)