Skip to content

Commit 9dd6cab

Browse files
committed
rewrite with invbiperm. BlockArray fails.
1 parent a1344e8 commit 9dd6cab

File tree

6 files changed

+73
-98
lines changed

6 files changed

+73
-98
lines changed

src/contract/allocate_output.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717
# i.e. `ContractAdd`?
1818
function output_axes(
1919
::typeof(contract),
20-
biperm_out::AbstractBlockPermutation{2},
20+
biperm_a12_to_dest::AbstractBlockPermutation{2},
2121
a1::AbstractArray,
2222
biperm1::AbstractBlockPermutation{2},
2323
a2::AbstractArray,
@@ -28,21 +28,21 @@ function output_axes(
2828
axes_contracted2, axes_domain = blocks(axes(a2)[biperm2])
2929
@assert axes_contracted == axes_contracted2
3030
# default: flatten biperm_out
31-
return genperm((axes_codomain..., axes_domain...), Tuple(biperm_out))
31+
return genperm((axes_codomain..., axes_domain...), Tuple(biperm_a12_to_dest))
3232
end
3333

3434
# TODO: Use `ArrayLayouts`-like `MulAdd` object,
3535
# i.e. `ContractAdd`?
3636
function allocate_output(
3737
::typeof(contract),
38-
biperm_out::AbstractBlockPermutation,
38+
biperm_a12_to_dest::AbstractBlockPermutation,
3939
a1::AbstractArray,
4040
biperm1::AbstractBlockPermutation,
4141
a2::AbstractArray,
4242
biperm2::AbstractBlockPermutation,
4343
α::Number=one(Bool),
4444
)
4545
check_input(contract, a1, biperm1, a2, biperm2)
46-
axes_dest = output_axes(contract, biperm_out, a1, biperm1, a2, biperm2, α)
46+
axes_dest = output_axes(contract, biperm_a12_to_dest, a1, biperm1, a2, biperm2, α)
4747
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
4848
end

src/contract/blockedperms.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
using .BaseExtensions: BaseExtensions
22
using BlockArrays: blocklengths
33

4-
function blockedperms(
5-
f::typeof(contract), alg::Algorithm, dimnames_dest, dimnames1, dimnames2
6-
)
7-
return blockedperms(f, dimnames_dest, dimnames1, dimnames2)
8-
end
4+
# default: if no bipartion is specified, all axes to domain
5+
invbiperm(perm, ::Any) = invbiperm(perm, Val(0))
6+
invbiperm(perm, t::Tuple{Tuple,Tuple}) = invbiperm(perm, tuplemortar(t))
7+
invbiperm(perm, t::AbstractBlockTuple{2}) = invbiperm(perm, Val(first(blocklength(t))))
98

109
function invbiperm(perm, ::Val{N1}) where {N1}
1110
perm_out = invperm(Tuple(perm))
11+
length(perm) <= N1 && return blockedpermvcat(perm_out, ())
1212
return blockedpermvcat(perm_out[begin:N1], (perm_out[(N1 + 1):end]))
1313
end
1414

15-
# codomain <-- domain
16-
function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
17-
# default: if no bipartion is specified, all axes to domain
18-
dimnames_dest_bt = tuplemortar(((), Tuple(dimnames_dest)))
19-
return blockedperms(contract, dimnames_dest_bt, dimnames1, dimnames2)
15+
function blockedperms(
16+
f::typeof(contract), alg::Algorithm, dimnames_dest, dimnames1, dimnames2
17+
)
18+
return blockedperms(f, dimnames_dest, dimnames1, dimnames2)
2019
end
2120

22-
function blockedperms(::typeof(contract), dimnames_dest::BlockedTuple, dimnames1, dimnames2)
21+
# codomain <-- domain
22+
function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
2323
dimnames = collect(Iterators.flatten((dimnames_dest, dimnames1, dimnames2)))
2424
for i in unique(dimnames)
2525
count(==(i), dimnames) == 2 || throw(ArgumentError("Invalid contraction labels"))
@@ -31,9 +31,8 @@ function blockedperms(::typeof(contract), dimnames_dest::BlockedTuple, dimnames1
3131

3232
perm_codomain_dest = BaseExtensions.indexin(codomain, dimnames_dest)
3333
perm_domain_dest = BaseExtensions.indexin(domain, dimnames_dest)
34-
biperm_out = invbiperm(
35-
(perm_codomain_dest..., perm_domain_dest...), Val(first(blocklengths(dimnames_dest)))
36-
)
34+
biperm_dest_to_a12 = (perm_codomain_dest..., perm_domain_dest...)
35+
biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, dimnames_dest)
3736

3837
perm_codomain1 = BaseExtensions.indexin(codomain, dimnames1)
3938
perm_domain1 = BaseExtensions.indexin(contracted, dimnames1)
@@ -45,5 +44,5 @@ function blockedperms(::typeof(contract), dimnames_dest::BlockedTuple, dimnames1
4544
biperm1 = blockedpermvcat(permblocks1...)
4645
permblocks2 = (perm_codomain2, perm_domain2)
4746
biperm2 = blockedpermvcat(permblocks2...)
48-
return biperm_out, biperm1, biperm2
47+
return biperm_a12_to_dest, biperm1, biperm2
4948
end

src/contract/contract.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ default_contract_alg() = Matricize()
1313
function contract!(
1414
alg::Algorithm,
1515
a_dest::AbstractArray,
16-
biperm_out::AbstractBlockPermutation,
16+
biperm_a12_to_dest::AbstractBlockPermutation,
1717
a1::AbstractArray,
1818
biperm1::AbstractBlockPermutation,
1919
a2::AbstractArray,
@@ -89,8 +89,10 @@ function contract(
8989
kwargs...,
9090
)
9191
check_input(contract, a1, labels1, a2, labels2)
92-
biperm_out, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
93-
return contract(alg, biperm_out, a1, biperm1, a2, biperm2, α; kwargs...)
92+
biperm_a12_to_dest, biperm1, biperm2 = blockedperms(
93+
contract, labels_dest, labels1, labels2
94+
)
95+
return contract(alg, biperm_a12_to_dest, a1, biperm1, a2, biperm2, α; kwargs...)
9496
end
9597

9698
function contract!(
@@ -106,13 +108,17 @@ function contract!(
106108
kwargs...,
107109
)
108110
check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2)
109-
biperm_out, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
110-
return contract!(alg, a_dest, biperm_out, a1, biperm1, a2, biperm2, α, β; kwargs...)
111+
biperm_a12_to_dest, biperm1, biperm2 = blockedperms(
112+
contract, labels_dest, labels1, labels2
113+
)
114+
return contract!(
115+
alg, a_dest, biperm_a12_to_dest, a1, biperm1, a2, biperm2, α, β; kwargs...
116+
)
111117
end
112118

113119
function contract(
114120
alg::Algorithm,
115-
biperm_out::AbstractBlockPermutation,
121+
biperm_a12_to_dest::AbstractBlockPermutation,
116122
a1::AbstractArray,
117123
biperm1::AbstractBlockPermutation,
118124
a2::AbstractArray,
@@ -121,7 +127,9 @@ function contract(
121127
kwargs...,
122128
)
123129
check_input(contract, a1, biperm1, a2, biperm2)
124-
a_dest = allocate_output(contract, biperm_out, a1, biperm1, a2, biperm2, α)
125-
contract!(alg, a_dest, biperm_out, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...)
130+
a_dest = allocate_output(contract, biperm_a12_to_dest, a1, biperm1, a2, biperm2, α)
131+
contract!(
132+
alg, a_dest, biperm_a12_to_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...
133+
)
126134
return a_dest
127135
end
Lines changed: 4 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,21 @@
11
using LinearAlgebra: mul!
22

3-
function isinplace(::AbstractArray, biperm_out)
4-
return istrivialperm(Tuple(biperm_out))
5-
end
6-
73
function contract!(
8-
alg::Matricize,
9-
a_dest::AbstractArray,
10-
biperm_out::AbstractBlockPermutation{2},
11-
a1::AbstractArray,
12-
biperm1::AbstractBlockPermutation{2},
13-
a2::AbstractArray,
14-
biperm2::AbstractBlockPermutation{2},
15-
α::Number,
16-
β::Number,
17-
)
18-
if isinplace(a_dest, biperm_out)
19-
return contract_inplace!(alg, a_dest, biperm_out, a1, biperm1, a2, biperm2, α, β)
20-
else
21-
return contract_outofplace!(alg, a_dest, biperm_out, a1, biperm1, a2, biperm2, α, β)
22-
end
23-
end
24-
25-
function contract_inplace!(
26-
::Matricize,
27-
a_dest::AbstractArray,
28-
biperm_out::AbstractBlockPermutation{2},
29-
a1::AbstractArray,
30-
biperm1::AbstractBlockPermutation{2},
31-
a2::AbstractArray,
32-
biperm2::AbstractBlockPermutation{2},
33-
α::Number,
34-
β::Number,
35-
)
36-
biperm_dest = invbiperm(biperm_out, Val(first(blocklengths(biperm1))))
37-
check_input(contract, a_dest, biperm_dest, a1, biperm1, a2, biperm2)
38-
a_dest_mat = matricize(a_dest, biperm_dest; copy=false)
39-
a1_mat = matricize(a1, biperm1; copy=false)
40-
a2_mat = matricize(a2, biperm2; copy=false)
41-
mul!(a_dest_mat, a1_mat, a2_mat, α, β)
42-
return a_dest
43-
end
44-
45-
function contract_outofplace!(
464
::Matricize,
475
a_dest::AbstractArray,
48-
biperm_out::AbstractBlockPermutation{2},
6+
biperm_a12_to_dest::AbstractBlockPermutation{2},
497
a1::AbstractArray,
508
biperm1::AbstractBlockPermutation{2},
519
a2::AbstractArray,
5210
biperm2::AbstractBlockPermutation{2},
5311
α::Number,
5412
β::Number,
5513
)
56-
biperm_dest = invbiperm(biperm_out, Val(first(blocklengths(biperm1))))
57-
check_input(contract, a_dest, biperm_dest, a1, biperm1, a2, biperm2)
14+
biperm_dest_to_a12 = invbiperm(biperm_a12_to_dest, Val(first(blocklengths(biperm1))))
15+
check_input(contract, a_dest, biperm_dest_to_a12, a1, biperm1, a2, biperm2)
5816
a1_mat = matricize(a1, biperm1)
5917
a2_mat = matricize(a2, biperm2)
6018
a_dest_mat = a1_mat * a2_mat
61-
unmatricize_add!(a_dest, a_dest_mat, biperm_dest, α, β)
19+
unmatricize_add!(a_dest, a_dest_mat, biperm_dest_to_a12, α, β)
6220
return a_dest
6321
end

src/matricize.jl

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,20 @@ end
4545
# matrix factorizations assume copy
4646
# maybe: copy=false kwarg
4747

48-
function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2}; copy=false)
48+
function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2})
4949
ndims(a) == length(biperm) || throw(ArgumentError("Invalid bipermutation"))
50-
return matricize(FusionStyle(a), a, biperm; copy)
50+
return matricize(FusionStyle(a), a, biperm)
5151
end
5252

5353
function matricize(
54-
style::FusionStyle, a::AbstractArray, biperm::AbstractBlockPermutation{2}; copy=false
54+
style::FusionStyle, a::AbstractArray, biperm::AbstractBlockPermutation{2}
5555
)
56-
if istrivialperm(Tuple(biperm)) && !copy
57-
return matricize(style, a, trivialperm(biperm))
58-
end
5956
a_perm = permuteblockeddims(a, biperm)
6057
return matricize(style, a_perm, trivialperm(biperm))
6158
end
6259

6360
function matricize(
64-
style::FusionStyle, a::AbstractArray, biperm::BlockedTrivialPermutation{2}; copy=false
61+
style::FusionStyle, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
6562
)
6663
return throw(MethodError(matricize, Tuple{typeof(style),typeof(a),typeof(biperm)}))
6764
end
@@ -72,22 +69,29 @@ function matricize(::ReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPerm
7269
return reshape(a, new_axes...)
7370
end
7471

75-
function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple; copy=false)
76-
return matricize(a, blockedpermvcat(permblock1, permblock2; length=Val(ndims(a))); copy)
72+
function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple)
73+
return matricize(a, blockedpermvcat(permblock1, permblock2; length=Val(ndims(a))))
7774
end
7875

7976
# ==================================== unmatricize =======================================
80-
function unmatricize(m::AbstractMatrix, axes, biperm::AbstractBlockPermutation{2})
81-
length(axes) == length(biperm) || throw(ArgumentError("axes do not match permutation"))
82-
return unmatricize(FusionStyle(m), m, axes, biperm)
77+
function unmatricize(
78+
m::AbstractMatrix, axes_dest, biperm_dest_to_a12::AbstractBlockPermutation{2}
79+
)
80+
length(axes_dest) == length(biperm_dest_to_a12) ||
81+
throw(ArgumentError("axes do not match permutation"))
82+
return unmatricize(FusionStyle(m), m, axes_dest, biperm_dest_to_a12)
8383
end
8484

8585
function unmatricize(
86-
::FusionStyle, m::AbstractMatrix, axes, biperm::AbstractBlockPermutation{2}
86+
::FusionStyle,
87+
m::AbstractMatrix,
88+
axes_dest,
89+
biperm_dest_to_a12::AbstractBlockPermutation{2},
8790
)
88-
blocked_axes = axes[biperm]
89-
a_perm = unmatricize(m, blocked_axes)
90-
return permuteblockeddims(a_perm, invperm(biperm))
91+
blocked_axes = axes_dest[biperm_dest_to_a12]
92+
a12 = unmatricize(m, blocked_axes)
93+
biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, axes_dest)
94+
return permuteblockeddims(a12, biperm_a12_to_dest)
9195
end
9296

9397
function unmatricize(
@@ -111,14 +115,19 @@ function unmatricize(
111115
return unmatricize(m, blocked_axes)
112116
end
113117

114-
function unmatricize!(a, m::AbstractMatrix, biperm::AbstractBlockPermutation{2})
115-
ndims(a) == length(biperm) ||
118+
function unmatricize!(
119+
a_dest, m::AbstractMatrix, biperm_dest_to_a12::AbstractBlockPermutation{2}
120+
)
121+
ndims(a_dest) == length(biperm_dest_to_a12) ||
116122
throw(ArgumentError("destination does not match permutation"))
117-
blocked_axes = axes(a)[biperm]
123+
blocked_axes = axes(a_dest)[biperm_dest_to_a12]
118124
a_perm = unmatricize(m, blocked_axes)
119-
return permuteblockeddims!(a, a_perm, invperm(biperm))
125+
biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, axes(a_dest))
126+
return permuteblockeddims!(a_dest, a_perm, biperm_a12_to_dest)
120127
end
121128

122-
function unmatricize_add!(a_dest, a_dest_mat, biperm_dest, α, β)
123-
return mul!(a_dest, 1.0, unmatricize(a_dest_mat, axes(a_dest), biperm_dest), α, β)
129+
function unmatricize_add!(a_dest, a_dest_mat, biperm_dest_to_a12, α, β)
130+
a12 = unmatricize(a_dest_mat, axes(a_dest), biperm_dest_to_a12)
131+
a_dest .= α .* a12 .+ β .* a_dest
132+
return a_dest
124133
end

test/test_basics.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
9595
@test a a0
9696

9797
bp = blockedpermvcat((4, 2), (1, 3))
98-
a = unmatricize(m, map(i -> axes0[i], invperm(Tuple(bp))), bp)
98+
bpinv = blockedpermvcat((3, 2), (4, 1))
99+
a = unmatricize(m, map(i -> axes0[i], bp), bpinv)
99100
@test eltype(a) === elt
100-
@test a permutedims(a0, invperm(Tuple(bp)))
101+
@test a permutedims(a0, Tuple(bp))
101102

102103
a = similar(a0)
103104
unmatricize!(a, m, blockedpermvcat((1, 2), (3, 4)))
@@ -109,7 +110,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
109110

110111
a1 = permutedims(a0, Tuple(bp))
111112
a = similar(a1)
112-
unmatricize!(a, m, invperm(bp))
113+
unmatricize!(a, m, bpinv)
113114
@test a a1
114115

115116
a = unmatricize(m, (), axes0)

0 commit comments

Comments
 (0)