Skip to content

Fix bipermutations in contract #75

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.11"
version = "0.3.12"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
11 changes: 5 additions & 6 deletions src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end
# i.e. `ContractAdd`?
function output_axes(
::typeof(contract),
biperm_dest::AbstractBlockPermutation{2},
biperm_a12_to_dest::AbstractBlockPermutation{2},
a1::AbstractArray,
biperm1::AbstractBlockPermutation{2},
a2::AbstractArray,
Expand All @@ -27,23 +27,22 @@ function output_axes(
axes_codomain, axes_contracted = blocks(axes(a1)[biperm1])
axes_contracted2, axes_domain = blocks(axes(a2)[biperm2])
@assert axes_contracted == axes_contracted2
return genperm((axes_codomain..., axes_domain...), invperm(Tuple(biperm_dest)))
# default: flatten biperm_out
return genperm((axes_codomain..., axes_domain...), Tuple(biperm_a12_to_dest))
end

# TODO: Use `ArrayLayouts`-like `MulAdd` object,
# i.e. `ContractAdd`?
function allocate_output(
::typeof(contract),
biperm_dest::AbstractBlockPermutation,
biperm_a12_to_dest::AbstractBlockPermutation,
a1::AbstractArray,
biperm1::AbstractBlockPermutation,
a2::AbstractArray,
biperm2::AbstractBlockPermutation,
α::Number=one(Bool),
)
check_input(contract, a1, biperm1, a2, biperm2)
blocklengths(biperm_dest) == (length(biperm1[Block(1)]), length(biperm2[Block(2)])) ||
throw(ArgumentError("Invalid permutation for destination tensor"))
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
axes_dest = output_axes(contract, biperm_a12_to_dest, a1, biperm1, a2, biperm2, α)
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
end
18 changes: 15 additions & 3 deletions src/contract/blockedperms.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
using .BaseExtensions: BaseExtensions
using BlockArrays: blocklengths

# default: if no bipartion is specified, all axes to domain
invbiperm(perm, ::Any) = invbiperm(perm, Val(0))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit strange to me that it just allows anything as the second argument. Maybe this should be invbiperm(perm)? Is this used anywhere right now?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the rest of the code, I see that it is being used in calls like biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, axes(a_dest)), where axes(a_dest) might output a blocked tuple or a flat tuple.

I think the invbiperm function is trying to do too much and therefore makes the code harder to understand. Instead, maybe we could introduce new functions biperm and length_codomain:

function biperm(perm, blocklength1::Int)
  return biperm(perm, Val(blocklength1))
end
function biperm(perm, ::Val{BlockLength1}) where {BlockLength1}
  # Check: BlockLength1 <= length(perm)
  return blockedperm(Tuple(perm),(BlockLength1, length(perm) - BlockLength1))
end

length_codomain(t::AbstractBlockTuple{2}) = first(blocklength(t))
# Assume all dimensions are in the domain by default
length_codomain(t) = 0

and the use a combination of invperm, biperm, and length_codomain in the contract code, for example:

biperm_a12_to_dest = biperm(invperm(biperm_dest_to_a12), length_codomain(axes(a_dest)))

invbiperm(perm, t::Tuple{Tuple,Tuple}) = invbiperm(perm, tuplemortar(t))
invbiperm(perm, t::AbstractBlockTuple{2}) = invbiperm(perm, Val(first(blocklength(t))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
invbiperm(perm, t::AbstractBlockTuple{2}) = invbiperm(perm, Val(first(blocklength(t))))
invbiperm(perm, t::AbstractBlockTuple{2}) = invbiperm(perm, Val(first(blocklengths(t))))


function invbiperm(perm, ::Val{N1}) where {N1}
perm_out = invperm(Tuple(perm))
length(perm) <= N1 && return blockedpermvcat(perm_out, ())
return blockedpermvcat(perm_out[begin:N1], (perm_out[(N1 + 1):end]))
end

function blockedperms(
f::typeof(contract), alg::Algorithm, dimnames_dest, dimnames1, dimnames2
Expand All @@ -19,18 +31,18 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)

perm_codomain_dest = BaseExtensions.indexin(codomain, dimnames_dest)
perm_domain_dest = BaseExtensions.indexin(domain, dimnames_dest)
biperm_dest_to_a12 = (perm_codomain_dest..., perm_domain_dest...)
biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, dimnames_dest)

perm_codomain1 = BaseExtensions.indexin(codomain, dimnames1)
perm_domain1 = BaseExtensions.indexin(contracted, dimnames1)

perm_codomain2 = BaseExtensions.indexin(contracted, dimnames2)
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)

permblocks_dest = (perm_codomain_dest, perm_domain_dest)
biperm_dest = blockedpermvcat(permblocks_dest...)
permblocks1 = (perm_codomain1, perm_domain1)
biperm1 = blockedpermvcat(permblocks1...)
permblocks2 = (perm_codomain2, perm_domain2)
biperm2 = blockedpermvcat(permblocks2...)
return biperm_dest, biperm1, biperm2
return biperm_a12_to_dest, biperm1, biperm2
end
24 changes: 16 additions & 8 deletions src/contract/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ default_contract_alg() = Matricize()
function contract!(
alg::Algorithm,
a_dest::AbstractArray,
biperm_dest::AbstractBlockPermutation,
biperm_a12_to_dest::AbstractBlockPermutation,
a1::AbstractArray,
biperm1::AbstractBlockPermutation,
a2::AbstractArray,
Expand Down Expand Up @@ -89,8 +89,10 @@ function contract(
kwargs...,
)
check_input(contract, a1, labels1, a2, labels2)
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
return contract(alg, biperm_dest, a1, biperm1, a2, biperm2, α; kwargs...)
biperm_a12_to_dest, biperm1, biperm2 = blockedperms(
contract, labels_dest, labels1, labels2
)
return contract(alg, biperm_a12_to_dest, a1, biperm1, a2, biperm2, α; kwargs...)
end

function contract!(
Expand All @@ -106,13 +108,17 @@ function contract!(
kwargs...,
)
check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2)
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
return contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
biperm_a12_to_dest, biperm1, biperm2 = blockedperms(
contract, labels_dest, labels1, labels2
)
return contract!(
alg, a_dest, biperm_a12_to_dest, a1, biperm1, a2, biperm2, α, β; kwargs...
)
end

function contract(
alg::Algorithm,
biperm_dest::AbstractBlockPermutation,
biperm_a12_to_dest::AbstractBlockPermutation,
a1::AbstractArray,
biperm1::AbstractBlockPermutation,
a2::AbstractArray,
Expand All @@ -121,7 +127,9 @@ function contract(
kwargs...,
)
check_input(contract, a1, biperm1, a2, biperm2)
a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...)
a_dest = allocate_output(contract, biperm_a12_to_dest, a1, biperm1, a2, biperm2, α)
contract!(
alg, a_dest, biperm_a12_to_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...
)
return a_dest
end
10 changes: 5 additions & 5 deletions src/contract/contract_matricize/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ using LinearAlgebra: mul!
function contract!(
::Matricize,
a_dest::AbstractArray,
biperm_dest::AbstractBlockPermutation{2},
biperm_a12_to_dest::AbstractBlockPermutation{2},
a1::AbstractArray,
biperm1::AbstractBlockPermutation{2},
a2::AbstractArray,
biperm2::AbstractBlockPermutation{2},
α::Number,
β::Number,
)
check_input(contract, a_dest, biperm_dest, a1, biperm1, a2, biperm2)
a_dest_mat = matricize(a_dest, biperm_dest)
biperm_dest_to_a12 = invbiperm(biperm_a12_to_dest, Val(first(blocklengths(biperm1))))
check_input(contract, a_dest, biperm_dest_to_a12, a1, biperm1, a2, biperm2)
a1_mat = matricize(a1, biperm1)
a2_mat = matricize(a2, biperm2)
mul!(a_dest_mat, a1_mat, a2_mat, α, β)
unmatricize!(a_dest, a_dest_mat, biperm_dest)
a_dest_mat = a1_mat * a2_mat
unmatricize_add!(a_dest, a_dest_mat, biperm_dest_to_a12, α, β)
return a_dest
end
38 changes: 27 additions & 11 deletions src/matricize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,24 @@ function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple)
end

# ==================================== unmatricize =======================================
function unmatricize(m::AbstractMatrix, axes, biperm::AbstractBlockPermutation{2})
length(axes) == length(biperm) || throw(ArgumentError("axes do not match permutation"))
return unmatricize(FusionStyle(m), m, axes, biperm)
function unmatricize(
m::AbstractMatrix, axes_dest, biperm_dest_to_a12::AbstractBlockPermutation{2}
)
length(axes_dest) == length(biperm_dest_to_a12) ||
throw(ArgumentError("axes do not match permutation"))
return unmatricize(FusionStyle(m), m, axes_dest, biperm_dest_to_a12)
Comment on lines +77 to +82
Copy link
Member

@mtfishman mtfishman Aug 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in the context of this function, the name biperm_dest_to_a12 is more confusing than helpful (that name only makes sense in the context of contract but this function could be called for other purposes). I think it should be clear that biperm means the permutation that should be performed on m after it is reinterpreted as a length(axes_dest)-dimensional array (unless I'm misunderstanding the conventions of this function, in which case we should change it to that convention).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thing with the name axes_dest, I think axes is clear enough.

end

function unmatricize(
::FusionStyle, m::AbstractMatrix, axes, biperm::AbstractBlockPermutation{2}
::FusionStyle,
m::AbstractMatrix,
axes_dest,
biperm_dest_to_a12::AbstractBlockPermutation{2},
Comment on lines +88 to +89
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comments as above about the naming.

)
blocked_axes = axes[biperm]
a_perm = unmatricize(m, blocked_axes)
return permuteblockeddims(a_perm, invperm(biperm))
blocked_axes = axes_dest[biperm_dest_to_a12]
a12 = unmatricize(m, blocked_axes)
biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, axes_dest)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I lost track of the discussions we had about the conventions we want to use in this PR, I thought we had discussed that we would change the convention of unmatricize so that the biperm that gets input would be taken "literally", i.e. it wouldn't need to be inverted.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe as an alternative, we could change the convention of unmatricize so that the axes that get input are the unpermuted axes, i.e. the axes corresponding directly to the memory ordering of the input matrix m. Then, the bipermutation that gets input is just the bipermutation that needs to be done to get the desired output. I.e. it would be equivalent to:

function unmatricize(
  style::FusionStyle,
  m::AbstractMatrix,
  ax,
  biperm::AbstractBlockPermutation{2},
)
  a = unmatricize(style, m, ax)
  return permutedims(a, biperm)
end

return permuteblockeddims(a12, biperm_a12_to_dest)
end

function unmatricize(
Expand All @@ -108,10 +115,19 @@ function unmatricize(
return unmatricize(m, blocked_axes)
end

function unmatricize!(a, m::AbstractMatrix, biperm::AbstractBlockPermutation{2})
ndims(a) == length(biperm) ||
function unmatricize!(
a_dest, m::AbstractMatrix, biperm_dest_to_a12::AbstractBlockPermutation{2}
)
ndims(a_dest) == length(biperm_dest_to_a12) ||
throw(ArgumentError("destination does not match permutation"))
blocked_axes = axes(a)[biperm]
blocked_axes = axes(a_dest)[biperm_dest_to_a12]
a_perm = unmatricize(m, blocked_axes)
return permuteblockeddims!(a, a_perm, invperm(biperm))
biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, axes(a_dest))
return permuteblockeddims!(a_dest, a_perm, biperm_a12_to_dest)
end

function unmatricize_add!(a_dest, a_dest_mat, biperm_dest_to_a12, α, β)
a12 = unmatricize(a_dest_mat, axes(a_dest), biperm_dest_to_a12)
a_dest .= α .* a12 .+ β .* a_dest
return a_dest
end
7 changes: 4 additions & 3 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test a ≈ a0

bp = blockedpermvcat((4, 2), (1, 3))
a = unmatricize(m, map(i -> axes0[i], invperm(Tuple(bp))), bp)
bpinv = blockedpermvcat((3, 2), (4, 1))
a = unmatricize(m, map(i -> axes0[i], bp), bpinv)
@test eltype(a) === elt
@test a ≈ permutedims(a0, invperm(Tuple(bp)))
@test a ≈ permutedims(a0, Tuple(bp))

a = similar(a0)
unmatricize!(a, m, blockedpermvcat((1, 2), (3, 4)))
Expand All @@ -109,7 +110,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

a1 = permutedims(a0, Tuple(bp))
a = similar(a1)
unmatricize!(a, m, invperm(bp))
unmatricize!(a, m, bpinv)
@test a ≈ a1

a = unmatricize(m, (), axes0)
Expand Down
20 changes: 16 additions & 4 deletions test/test_blockarrays_contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,46 +25,58 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

@testset "BlockedArray" begin
# matrix matrix
a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4))
@test_broken a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4))
#=
a_dest_dense, dimnames_dest_dense = contract(
a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4)
)
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockedArray{elt}
@test a_dest ≈ a_dest_dense
=#

# matrix vector
a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
#=
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockedArray{elt}
@test a_dest ≈ a_dest_dense
=#

# vector matrix
a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
@test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
#=
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockedArray{elt}
@test a_dest ≈ a_dest_dense
=#

# vector vector
# worse than broken: infinite recursion
@test_broken false
#=
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockedArray{elt,0}
@test a_dest ≈ a_dest_dense
=#

# outer product
@test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4))
#=
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4))
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockedArray{elt}
@test a_dest ≈ a_dest_dense
=#
end

@testset "BlockArray" begin
Expand Down
Loading