diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index cdf36ec..fdafe6c 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -18,6 +18,7 @@ jobs: matrix: pkg: - 'BlockSparseArrays' + - 'FusionTensors' - 'NamedDimsArrays' uses: "ITensor/ITensorActions/.github/workflows/IntegrationTest.yml@main" with: diff --git a/Project.toml b/Project.toml index 79ae9d2..c0b41e7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.3.16" +version = "0.4.0" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/docs/Project.toml b/docs/Project.toml index 661cea0..db836a0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,4 +6,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [compat] Documenter = "1.8.1" Literate = "2.20.1" -TensorAlgebra = "0.3.0" +TensorAlgebra = "0.4.0" diff --git a/examples/Project.toml b/examples/Project.toml index 0c257a8..bb98e64 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,4 +2,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [compat] -TensorAlgebra = "0.3.0" +TensorAlgebra = "0.4.0" diff --git a/ext/TensorAlgebraTensorOperationsExt.jl b/ext/TensorAlgebraTensorOperationsExt.jl index 2b25095..2edf43a 100644 --- a/ext/TensorAlgebraTensorOperationsExt.jl +++ b/ext/TensorAlgebraTensorOperationsExt.jl @@ -1,9 +1,8 @@ module TensorAlgebraTensorOperationsExt -using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm -using TupleTools -using TensorOperations -using TensorOperations: AbstractBackend, DefaultBackend +using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm, blocklengths +using TupleTools: TupleTools +using TensorOperations: TensorOperations, AbstractBackend, DefaultBackend, Index2Tuple """ TensorOperationsAlgorithm(backend::AbstractBackend) @@ -44,8 +43,9 @@ function TensorAlgebra.contract( pA = _index2tuple(bipermA) pB = _index2tuple(bipermB) pAB = _index2tuple(bipermAB) - - return tensorcontract(A, pA, false, B, pB, false, pAB, α, algorithm.backend) + return TensorOperations.tensorcontract( + A, pA, false, B, pB, false, pAB, α, algorithm.backend + ) end function TensorAlgebra.contract( @@ -62,7 +62,7 @@ function TensorAlgebra.contract( end # in-place -function TensorAlgebra.contract!( +function TensorAlgebra.contractadd!( algorithm::TensorOperationsAlgorithm, C::AbstractArray, bipermAB::BlockedPermutation, @@ -76,10 +76,12 @@ function TensorAlgebra.contract!( pA = _index2tuple(bipermA) pB = _index2tuple(bipermB) pAB = _index2tuple(bipermAB) - return tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, algorithm.backend) + return TensorOperations.tensorcontract!( + C, A, pA, false, B, pB, false, pAB, α, β, algorithm.backend + ) end -function TensorAlgebra.contract!( +function TensorAlgebra.contractadd!( algorithm::TensorOperationsAlgorithm, C::AbstractArray, labelsC, @@ -117,7 +119,7 @@ function TensorOperations.tensorcontract!( bipermAB = _blockedpermutation(pAB) A′ = conjA ? conj(A) : A B′ = conjB ? conj(B) : B - return TensorAlgebra.contract!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β) + return TensorAlgebra.contractadd!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β) end # For now no trace/add is supported, so simply reselect default backend from TensorOperations diff --git a/src/contract/allocate_output.jl b/src/contract/allocate_output.jl index 7284ffa..f1b311d 100644 --- a/src/contract/allocate_output.jl +++ b/src/contract/allocate_output.jl @@ -22,7 +22,6 @@ function output_axes( biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, biperm2::AbstractBlockPermutation{2}, - α::Number=one(Bool), ) axes_codomain, axes_contracted = blocks(axes(a1)[biperm1]) axes_contracted2, axes_domain = blocks(axes(a2)[biperm2]) @@ -40,9 +39,8 @@ function allocate_output( biperm1::AbstractBlockPermutation, a2::AbstractArray, biperm2::AbstractBlockPermutation, - α::Number=one(Bool), ) check_input(contract, a1, biperm1, a2, biperm2) - axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α) - return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest) + axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2) + return similar(a1, promote_type(eltype(a1), eltype(a2)), axes_dest) end diff --git a/src/contract/contract.jl b/src/contract/contract.jl index 02665ca..abdcd85 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -10,7 +10,7 @@ default_contract_alg() = Matricize() # Required interface if not using # matricized contraction. -function contract!( +function contractadd!( alg::Algorithm, a_dest::AbstractArray, biperm_dest::AbstractBlockPermutation, @@ -28,53 +28,59 @@ function contract( a1::AbstractArray, labels1, a2::AbstractArray, - labels2, - α::Number=one(Bool); + labels2; alg=default_contract_alg(), kwargs..., ) - return contract(Algorithm(alg), a1, labels1, a2, labels2, α; kwargs...) + return contract(Algorithm(alg), a1, labels1, a2, labels2; kwargs...) end function contract( - alg::Algorithm, + alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs... +) + labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2; kwargs...) + return contract(alg, labels_dest, a1, labels1, a2, labels2; kwargs...), labels_dest +end + +function contract( + labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, - labels2, - α::Number=one(Bool); + labels2; + alg=default_contract_alg(), kwargs..., ) - labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2, α; kwargs...) - return contract(alg, labels_dest, a1, labels1, a2, labels2, α; kwargs...), labels_dest + return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2; kwargs...) end -function contract( +function contract!( + a_dest::AbstractArray, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, - labels2, - α::Number=one(Bool); - alg=default_contract_alg(), + labels2; kwargs..., ) - return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2, α; kwargs...) + return contractadd!(a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...) end -function contract!( +function contractadd!( a_dest::AbstractArray, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2, - α::Number=one(Bool), - β::Number=zero(Bool); + α::Number, + β::Number; alg=default_contract_alg(), kwargs..., ) - contract!(Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...) + contractadd!( + Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs... + ) return a_dest end @@ -84,16 +90,30 @@ function contract( a1::AbstractArray, labels1, a2::AbstractArray, - labels2, - α::Number=one(Bool); + labels2; kwargs..., ) check_input(contract, a1, labels1, a2, labels2) biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) - return contract(alg, biperm_dest, a1, biperm1, a2, biperm2, α; kwargs...) + return contract(alg, biperm_dest, a1, biperm1, a2, biperm2; kwargs...) end function contract!( + alg::Algorithm, + a_dest::AbstractArray, + labels_dest, + a1::AbstractArray, + labels1, + a2::AbstractArray, + labels2; + kwargs..., +) + return contractadd!( + alg, a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs... + ) +end + +function contractadd!( alg::Algorithm, a_dest::AbstractArray, labels_dest, @@ -107,7 +127,7 @@ function contract!( ) check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2) biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) - return contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...) + return contractadd!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...) end function contract( @@ -116,12 +136,11 @@ function contract( a1::AbstractArray, biperm1::AbstractBlockPermutation, a2::AbstractArray, - biperm2::AbstractBlockPermutation, - α::Number; + biperm2::AbstractBlockPermutation; kwargs..., ) check_input(contract, a1, biperm1, a2, biperm2) - a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2, α) - contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...) + a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2) + contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2; kwargs...) return a_dest end diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index 77db7e8..be4f26c 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -1,6 +1,6 @@ using LinearAlgebra: mul! -function contract!( +function contractadd!( ::Matricize, a_dest::AbstractArray, biperm_dest::AbstractBlockPermutation{2}, @@ -17,6 +17,6 @@ function contract!( a1_mat = matricize(a1, biperm1) a2_mat = matricize(a2, biperm2) a_dest_mat = a1_mat * a2_mat - unmatricize_add!(a_dest, a_dest_mat, invbiperm, α, β) + unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β) return a_dest end diff --git a/src/contract/output_labels.jl b/src/contract/output_labels.jl index 68525c5..c9c18bf 100644 --- a/src/contract/output_labels.jl +++ b/src/contract/output_labels.jl @@ -5,7 +5,6 @@ function output_labels( labels1, a2::AbstractArray, labels2, - α, ) return output_labels(f, alg, labels1, labels2) end diff --git a/src/matricize.jl b/src/matricize.jl index c67ff9a..d812d92 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -122,7 +122,7 @@ function unmatricize!(a_dest, m::AbstractMatrix, invbiperm::AbstractBlockPermuta return permuteblockeddims!(a_dest, a_perm, biperm_dest) end -function unmatricize_add!(a_dest, a_dest_mat, invbiperm, α, β) +function unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β) a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm) a_dest .= α .* a12 .+ β .* a_dest return a_dest diff --git a/test/Project.toml b/test/Project.toml index c6f8f73..656853a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -23,7 +23,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.3.0" +TensorAlgebra = "0.4.0" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" diff --git a/test/test_basics.jl b/test/test_basics.jl index 26fe98b..3a75cec 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -10,6 +10,7 @@ using TensorAlgebra: blockedpermvcat, contract, contract!, + contractadd!, length_codomain, length_domain, matricize, @@ -213,9 +214,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) β = elt_dest(2.4) # randn(elt_dest) a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests)) a_dest = copy(a_dest_init) - contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β) + contractadd!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β) a_dest_tensoroperations = copy(a_dest_init) - contract!( + contractadd!( alg_tensoroperations, a_dest_tensoroperations, labels_dest,