Skip to content

Commit 6e71131

Browse files
committed
pass tests
1 parent 3aaed6c commit 6e71131

File tree

4 files changed

+31
-18
lines changed

4 files changed

+31
-18
lines changed

src/contract/allocate_output.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717
# i.e. `ContractAdd`?
1818
function output_axes(
1919
::typeof(contract),
20-
biperm_dest::AbstractBlockPermutation{2},
20+
biperm_out::AbstractBlockPermutation{2},
2121
a1::AbstractArray,
2222
biperm1::AbstractBlockPermutation{2},
2323
a2::AbstractArray,
@@ -27,23 +27,22 @@ function output_axes(
2727
axes_codomain, axes_contracted = blocks(axes(a1)[biperm1])
2828
axes_contracted2, axes_domain = blocks(axes(a2)[biperm2])
2929
@assert axes_contracted == axes_contracted2
30-
return genperm((axes_codomain..., axes_domain...), invperm(Tuple(biperm_dest)))
30+
# default: flatten biperm_out
31+
return genperm((axes_codomain..., axes_domain...), Tuple(biperm_out))
3132
end
3233

3334
# TODO: Use `ArrayLayouts`-like `MulAdd` object,
3435
# i.e. `ContractAdd`?
3536
function allocate_output(
3637
::typeof(contract),
37-
biperm_dest::AbstractBlockPermutation,
38+
biperm_out::AbstractBlockPermutation,
3839
a1::AbstractArray,
3940
biperm1::AbstractBlockPermutation,
4041
a2::AbstractArray,
4142
biperm2::AbstractBlockPermutation,
4243
α::Number=one(Bool),
4344
)
4445
check_input(contract, a1, biperm1, a2, biperm2)
45-
blocklengths(biperm_dest) == (length(biperm1[Block(1)]), length(biperm2[Block(2)])) ||
46-
throw(ArgumentError("Invalid permutation for destination tensor"))
47-
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
46+
axes_dest = output_axes(contract, biperm_out, a1, biperm1, a2, biperm2, α)
4847
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
4948
end

src/contract/blockedperms.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
using .BaseExtensions: BaseExtensions
2+
using BlockArrays: blocklengths
23

34
function blockedperms(
45
f::typeof(contract), alg::Algorithm, dimnames_dest, dimnames1, dimnames2
56
)
67
return blockedperms(f, dimnames_dest, dimnames1, dimnames2)
78
end
89

10+
function invbiperm(perm, ::Val{N1}) where {N1}
11+
perm_out = invperm(Tuple(perm))
12+
return blockedpermvcat(perm_out[begin:N1], (perm_out[(N1 + 1):end]))
13+
end
14+
915
# codomain <-- domain
1016
function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
17+
# default: if no bipartion is specified, all axes to domain
18+
dimnames_dest_bt = tuplemortar(((), Tuple(dimnames_dest)))
19+
return blockedperms(contract, dimnames_dest_bt, dimnames1, dimnames2)
20+
end
21+
22+
function blockedperms(::typeof(contract), dimnames_dest::BlockedTuple, dimnames1, dimnames2)
1123
dimnames = collect(Iterators.flatten((dimnames_dest, dimnames1, dimnames2)))
1224
for i in unique(dimnames)
1325
count(==(i), dimnames) == 2 || throw(ArgumentError("Invalid contraction labels"))
@@ -19,18 +31,19 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
1931

2032
perm_codomain_dest = BaseExtensions.indexin(codomain, dimnames_dest)
2133
perm_domain_dest = BaseExtensions.indexin(domain, dimnames_dest)
34+
biperm_out = invbiperm(
35+
(perm_codomain_dest..., perm_domain_dest...), Val(first(blocklengths(dimnames_dest)))
36+
)
2237

2338
perm_codomain1 = BaseExtensions.indexin(codomain, dimnames1)
2439
perm_domain1 = BaseExtensions.indexin(contracted, dimnames1)
2540

2641
perm_codomain2 = BaseExtensions.indexin(contracted, dimnames2)
2742
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)
2843

29-
permblocks_dest = (perm_codomain_dest, perm_domain_dest)
30-
biperm_dest = blockedpermvcat(permblocks_dest...)
3144
permblocks1 = (perm_codomain1, perm_domain1)
3245
biperm1 = blockedpermvcat(permblocks1...)
3346
permblocks2 = (perm_codomain2, perm_domain2)
3447
biperm2 = blockedpermvcat(permblocks2...)
35-
return biperm_dest, biperm1, biperm2
48+
return biperm_out, biperm1, biperm2
3649
end

src/contract/contract.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ default_contract_alg() = Matricize()
1313
function contract!(
1414
alg::Algorithm,
1515
a_dest::AbstractArray,
16-
biperm_dest::AbstractBlockPermutation,
16+
biperm_out::AbstractBlockPermutation,
1717
a1::AbstractArray,
1818
biperm1::AbstractBlockPermutation,
1919
a2::AbstractArray,
@@ -89,8 +89,8 @@ function contract(
8989
kwargs...,
9090
)
9191
check_input(contract, a1, labels1, a2, labels2)
92-
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
93-
return contract(alg, biperm_dest, a1, biperm1, a2, biperm2, α; kwargs...)
92+
biperm_out, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
93+
return contract(alg, biperm_out, a1, biperm1, a2, biperm2, α; kwargs...)
9494
end
9595

9696
function contract!(
@@ -106,13 +106,13 @@ function contract!(
106106
kwargs...,
107107
)
108108
check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2)
109-
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
110-
return contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
109+
biperm_out, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
110+
return contract!(alg, a_dest, biperm_out, a1, biperm1, a2, biperm2, α, β; kwargs...)
111111
end
112112

113113
function contract(
114114
alg::Algorithm,
115-
biperm_dest::AbstractBlockPermutation,
115+
biperm_out::AbstractBlockPermutation,
116116
a1::AbstractArray,
117117
biperm1::AbstractBlockPermutation,
118118
a2::AbstractArray,
@@ -121,7 +121,7 @@ function contract(
121121
kwargs...,
122122
)
123123
check_input(contract, a1, biperm1, a2, biperm2)
124-
a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
125-
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...)
124+
a_dest = allocate_output(contract, biperm_out, a1, biperm1, a2, biperm2, α)
125+
contract!(alg, a_dest, biperm_out, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...)
126126
return a_dest
127127
end

src/contract/contract_matricize/contract.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ using LinearAlgebra: mul!
33
function contract!(
44
::Matricize,
55
a_dest::AbstractArray,
6-
biperm_dest::AbstractBlockPermutation{2},
6+
biperm_out::AbstractBlockPermutation{2},
77
a1::AbstractArray,
88
biperm1::AbstractBlockPermutation{2},
99
a2::AbstractArray,
1010
biperm2::AbstractBlockPermutation{2},
1111
α::Number,
1212
β::Number,
1313
)
14+
biperm_dest = invbiperm(biperm_out, Val(first(blocklengths(biperm1))))
1415
check_input(contract, a_dest, biperm_dest, a1, biperm1, a2, biperm2)
1516
a_dest_mat = matricize(a_dest, biperm_dest)
1617
a1_mat = matricize(a1, biperm1)

0 commit comments

Comments
 (0)