Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/IntegrationTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
matrix:
pkg:
- 'BlockSparseArrays'
- 'FusionTensors'
- 'NamedDimsArrays'
uses: "ITensor/ITensorActions/.github/workflows/IntegrationTest.yml@main"
with:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.16"
version = "0.3.17"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
22 changes: 12 additions & 10 deletions ext/TensorAlgebraTensorOperationsExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
module TensorAlgebraTensorOperationsExt

using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm
using TupleTools
using TensorOperations
using TensorOperations: AbstractBackend, DefaultBackend
using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm, blocklengths
using TupleTools: TupleTools
using TensorOperations: TensorOperations, AbstractBackend, DefaultBackend, Index2Tuple

"""
TensorOperationsAlgorithm(backend::AbstractBackend)
Expand Down Expand Up @@ -44,8 +43,9 @@ function TensorAlgebra.contract(
pA = _index2tuple(bipermA)
pB = _index2tuple(bipermB)
pAB = _index2tuple(bipermAB)

return tensorcontract(A, pA, false, B, pB, false, pAB, α, algorithm.backend)
return TensorOperations.tensorcontract(
A, pA, false, B, pB, false, pAB, α, algorithm.backend
)
end

function TensorAlgebra.contract(
Expand All @@ -62,7 +62,7 @@ function TensorAlgebra.contract(
end

# in-place
function TensorAlgebra.contract!(
function TensorAlgebra.contractadd!(
algorithm::TensorOperationsAlgorithm,
C::AbstractArray,
bipermAB::BlockedPermutation,
Expand All @@ -76,10 +76,12 @@ function TensorAlgebra.contract!(
pA = _index2tuple(bipermA)
pB = _index2tuple(bipermB)
pAB = _index2tuple(bipermAB)
return tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, algorithm.backend)
return TensorOperations.tensorcontract!(
C, A, pA, false, B, pB, false, pAB, α, β, algorithm.backend
)
end

function TensorAlgebra.contract!(
function TensorAlgebra.contractadd!(
algorithm::TensorOperationsAlgorithm,
C::AbstractArray,
labelsC,
Expand Down Expand Up @@ -117,7 +119,7 @@ function TensorOperations.tensorcontract!(
bipermAB = _blockedpermutation(pAB)
A′ = conjA ? conj(A) : A
B′ = conjB ? conj(B) : B
return TensorAlgebra.contract!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β)
return TensorAlgebra.contractadd!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β)
end

# For now no trace/add is supported, so simply reselect default backend from TensorOperations
Expand Down
6 changes: 2 additions & 4 deletions src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ function output_axes(
biperm1::AbstractBlockPermutation{2},
a2::AbstractArray,
biperm2::AbstractBlockPermutation{2},
α::Number=one(Bool),
)
axes_codomain, axes_contracted = blocks(axes(a1)[biperm1])
axes_contracted2, axes_domain = blocks(axes(a2)[biperm2])
Expand All @@ -40,9 +39,8 @@ function allocate_output(
biperm1::AbstractBlockPermutation,
a2::AbstractArray,
biperm2::AbstractBlockPermutation,
α::Number=one(Bool),
)
check_input(contract, a1, biperm1, a2, biperm2)
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2)
return similar(a1, promote_type(eltype(a1), eltype(a2)), axes_dest)
end
71 changes: 45 additions & 26 deletions src/contract/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ default_contract_alg() = Matricize()

# Required interface if not using
# matricized contraction.
function contract!(
function contractadd!(
alg::Algorithm,
a_dest::AbstractArray,
biperm_dest::AbstractBlockPermutation,
Expand All @@ -28,53 +28,59 @@ function contract(
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α::Number=one(Bool);
labels2;
alg=default_contract_alg(),
kwargs...,
)
return contract(Algorithm(alg), a1, labels1, a2, labels2, α; kwargs...)
return contract(Algorithm(alg), a1, labels1, a2, labels2; kwargs...)
end

function contract(
alg::Algorithm,
alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs...
)
labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2; kwargs...)
return contract(alg, labels_dest, a1, labels1, a2, labels2; kwargs...), labels_dest
end

function contract(
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α::Number=one(Bool);
labels2;
alg=default_contract_alg(),
kwargs...,
)
labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2, α; kwargs...)
return contract(alg, labels_dest, a1, labels1, a2, labels2, α; kwargs...), labels_dest
return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2; kwargs...)
end

function contract(
function contract!(
a_dest::AbstractArray,
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α::Number=one(Bool);
alg=default_contract_alg(),
labels2;
kwargs...,
)
return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2, α; kwargs...)
return contractadd!(a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...)
end

function contract!(
function contractadd!(
a_dest::AbstractArray,
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α::Number=one(Bool),
β::Number=zero(Bool);
α::Number,
β::Number;
alg=default_contract_alg(),
kwargs...,
)
contract!(Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...)
contractadd!(
Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...
)
return a_dest
end

Expand All @@ -84,16 +90,30 @@ function contract(
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α::Number=one(Bool);
labels2;
kwargs...,
)
check_input(contract, a1, labels1, a2, labels2)
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
return contract(alg, biperm_dest, a1, biperm1, a2, biperm2, α; kwargs...)
return contract(alg, biperm_dest, a1, biperm1, a2, biperm2; kwargs...)
end

function contract!(
alg::Algorithm,
a_dest::AbstractArray,
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2;
kwargs...,
)
return contractadd!(
alg, a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...
)
end

function contractadd!(
alg::Algorithm,
a_dest::AbstractArray,
labels_dest,
Expand All @@ -107,7 +127,7 @@ function contract!(
)
check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2)
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
return contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
return contractadd!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
end

function contract(
Expand All @@ -116,12 +136,11 @@ function contract(
a1::AbstractArray,
biperm1::AbstractBlockPermutation,
a2::AbstractArray,
biperm2::AbstractBlockPermutation,
α::Number;
biperm2::AbstractBlockPermutation;
kwargs...,
)
check_input(contract, a1, biperm1, a2, biperm2)
a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...)
a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2)
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2; kwargs...)
return a_dest
end
4 changes: 2 additions & 2 deletions src/contract/contract_matricize/contract.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using LinearAlgebra: mul!

function contract!(
function contractadd!(
::Matricize,
a_dest::AbstractArray,
biperm_dest::AbstractBlockPermutation{2},
Expand All @@ -17,6 +17,6 @@ function contract!(
a1_mat = matricize(a1, biperm1)
a2_mat = matricize(a2, biperm2)
a_dest_mat = a1_mat * a2_mat
unmatricize_add!(a_dest, a_dest_mat, invbiperm, α, β)
unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β)
return a_dest
end
1 change: 0 additions & 1 deletion src/contract/output_labels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ function output_labels(
labels1,
a2::AbstractArray,
labels2,
α,
)
return output_labels(f, alg, labels1, labels2)
end
Expand Down
2 changes: 1 addition & 1 deletion src/matricize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function unmatricize!(a_dest, m::AbstractMatrix, invbiperm::AbstractBlockPermuta
return permuteblockeddims!(a_dest, a_perm, biperm_dest)
end

function unmatricize_add!(a_dest, a_dest_mat, invbiperm, α, β)
function unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β)
a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm)
a_dest .= α .* a12 .+ β .* a_dest
return a_dest
Expand Down
5 changes: 3 additions & 2 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using TensorAlgebra:
blockedpermvcat,
contract,
contract!,
contractadd!,
length_codomain,
length_domain,
matricize,
Expand Down Expand Up @@ -213,9 +214,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
β = elt_dest(2.4) # randn(elt_dest)
a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests))
a_dest = copy(a_dest_init)
contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
contractadd!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
a_dest_tensoroperations = copy(a_dest_init)
contract!(
contractadd!(
alg_tensoroperations,
a_dest_tensoroperations,
labels_dest,
Expand Down
Loading