Skip to content

Commit bc8a1e9

Browse files
committed
define length_domain
1 parent f28f180 commit bc8a1e9

File tree

4 files changed

+18
-6
lines changed

4 files changed

+18
-6
lines changed

src/contract/allocate_output.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717
# i.e. `ContractAdd`?
1818
function output_axes(
1919
::typeof(contract),
20-
biperm_a12_to_dest::AbstractBlockPermutation{2},
20+
biperm_dest::AbstractBlockPermutation{2},
2121
a1::AbstractArray,
2222
biperm1::AbstractBlockPermutation{2},
2323
a2::AbstractArray,
@@ -26,23 +26,23 @@ function output_axes(
2626
)
2727
axes_codomain, axes_contracted = blocks(axes(a1)[biperm1])
2828
axes_contracted2, axes_domain = blocks(axes(a2)[biperm2])
29-
@assert axes_contracted == axes_contracted2
29+
@assert length.(axes_contracted) == length.(axes_contracted2)
3030
# default: flatten biperm_out
31-
return genperm((axes_codomain..., axes_domain...), Tuple(biperm_a12_to_dest))
31+
return genperm((axes_codomain..., axes_domain...), Tuple(biperm_dest))
3232
end
3333

3434
# TODO: Use `ArrayLayouts`-like `MulAdd` object,
3535
# i.e. `ContractAdd`?
3636
function allocate_output(
3737
::typeof(contract),
38-
biperm_a12_to_dest::AbstractBlockPermutation,
38+
biperm_dest::AbstractBlockPermutation,
3939
a1::AbstractArray,
4040
biperm1::AbstractBlockPermutation,
4141
a2::AbstractArray,
4242
biperm2::AbstractBlockPermutation,
4343
α::Number=one(Bool),
4444
)
4545
check_input(contract, a1, biperm1, a2, biperm2)
46-
axes_dest = output_axes(contract, biperm_a12_to_dest, a1, biperm1, a2, biperm2, α)
46+
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
4747
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
4848
end

src/contract/blockedperms.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ length_codomain(t::AbstractBlockTuple{2}) = first(blocklengths(t))
1414
# Assume all dimensions are in the domain by default
1515
length_codomain(t) = 0
1616

17+
length_domain(t) = length(t) - length_codomain(t)
18+
1719
function blockedperms(
1820
f::typeof(contract), alg::Algorithm, dimnames_dest, dimnames1, dimnames2
1921
)

src/matricize.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ function unmatricize!(a_dest, m::AbstractMatrix, invbiperm::AbstractBlockPermuta
119119
blocked_axes = axes(a_dest)[invbiperm]
120120
a_perm = unmatricize(m, blocked_axes)
121121
biperm_dest = biperm(invperm(invbiperm), length_codomain(axes(a_dest)))
122-
123122
return permuteblockeddims!(a_dest, a_perm, biperm_dest)
124123
end
125124

test/test_basics.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ using TensorAlgebra:
1212
contract,
1313
contract!,
1414
matricize,
15+
length_codomain,
16+
length_domain,
1517
tuplemortar,
1618
unmatricize,
1719
unmatricize!
@@ -20,6 +22,15 @@ default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))
2022
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
2123

2224
@testset "TensorAlgebra" begin
25+
@testset "misc" begin
26+
t = (1, 2, 3)
27+
bt = tuplemortar(((1, 2), (3,)))
28+
@test length_codomain(t) == 0
29+
@test length_codomain(bt) == 2
30+
@test length_domain(t) == 3
31+
@test length_domain(bt) == 1
32+
end
33+
2334
@testset "permuteblockeddims (eltype=$elt)" for elt in elts
2435
a = randn(elt, 2, 3, 4, 5)
2536
a_perm = permuteblockeddims(a, blockedpermvcat((3, 1), (2, 4)))

0 commit comments

Comments
 (0)