Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.12"
version = "0.3.13"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand All @@ -20,6 +20,13 @@ EllipsisNotation = "1.8.0"
LinearAlgebra = "1.10"
MatrixAlgebraKit = "0.2"
TensorProducts = "0.1.5"
TensorOperations = "5"
TupleTools = "1.6.0"
TypeParameterAccessors = "0.2.1, 0.3, 0.4"
julia = "1.10"

[weakdeps]
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"

[extensions]
TensorAlgebraTensorOperationsExt = "TensorOperations"
150 changes: 150 additions & 0 deletions ext/TensorAlgebraTensorOperationsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
module TensorAlgebraTensorOperationsExt

using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm
using TupleTools
using TensorOperations
using TensorOperations: AbstractBackend, DefaultBackend

"""
    TensorOperationsAlgorithm(backend::AbstractBackend)

Wrapper type for making a TensorOperations backend work as a TensorAlgebra algorithm.

The wrapped `backend` is forwarded to `tensorcontract`/`tensorcontract!` whenever this
algorithm is passed to `TensorAlgebra.contract`/`TensorAlgebra.contract!`.
"""
struct TensorOperationsAlgorithm{B<:AbstractBackend} <: Algorithm
    # The TensorOperations backend that performs the actual contraction.
    backend::B
end

TensorAlgebra.Algorithm(backend::AbstractBackend) = TensorOperationsAlgorithm(backend)

trivtuple(n) = ntuple(identity, n)

"""
    _index2tuple(p::BlockedPermutation{2})

Convert a two-block `BlockedPermutation` into a TensorOperations-style `Index2Tuple`,
i.e. a pair of tuples holding the indices of the first and second block, respectively.
"""
function _index2tuple(p::BlockedPermutation{2})
    # NOTE(review): `blocklengths` was called unqualified here, but this module only
    # imports `TensorAlgebra, BlockedPermutation, Algorithm` (plus TupleTools and
    # TensorOperations, neither of which exports `blocklengths`), so the bare name
    # would not resolve. Qualify it through TensorAlgebra — confirm it is defined
    # (or imported) there.
    N₁, N₂ = TensorAlgebra.blocklengths(p)
    return (
        TupleTools.getindices(Tuple(p), trivtuple(N₁)),
        TupleTools.getindices(Tuple(p), N₁ .+ trivtuple(N₂)),
    )
end

_blockedpermutation(p::Index2Tuple) = TensorAlgebra.blockedpermvcat(p...)

# Using TensorOperations backends as TensorAlgebra implementations
# ----------------------------------------------------------------

# not in-place
function TensorAlgebra.contract(
    algorithm::TensorOperationsAlgorithm,
    bipermAB::BlockedPermutation,
    A::AbstractArray,
    bipermA::BlockedPermutation,
    B::AbstractArray,
    bipermB::BlockedPermutation,
    α::Number,
)
    # Translate the blocked permutations into TensorOperations' Index2Tuple form and
    # forward to `tensorcontract` with the wrapped backend; no conjugation is applied.
    return tensorcontract(
        A, _index2tuple(bipermA), false,
        B, _index2tuple(bipermB), false,
        _index2tuple(bipermAB), α, algorithm.backend,
    )
end

function TensorAlgebra.contract(
    algorithm::TensorOperationsAlgorithm,
    labelsC,
    A::AbstractArray,
    labelsA,
    B::AbstractArray,
    labelsB,
    α::Number,
)
    # Let TensorOperations derive the contraction permutations from the index labels,
    # then dispatch to its `tensorcontract` with the wrapped backend.
    backend = algorithm.backend
    (pA, pB, pAB) = TensorOperations.contract_indices(labelsA, labelsB, labelsC)
    return tensorcontract(A, pA, false, B, pB, false, pAB, α, backend)
end

# in-place
function TensorAlgebra.contract!(
    algorithm::TensorOperationsAlgorithm,
    C::AbstractArray,
    bipermAB::BlockedPermutation,
    A::AbstractArray,
    bipermA::BlockedPermutation,
    B::AbstractArray,
    bipermB::BlockedPermutation,
    α::Number,
    β::Number,
)
    # Overwrites C with β*C + α*(A contracted with B), delegating the actual work to
    # TensorOperations' `tensorcontract!` using the wrapped backend.
    return tensorcontract!(
        C,
        A, _index2tuple(bipermA), false,
        B, _index2tuple(bipermB), false,
        _index2tuple(bipermAB),
        α, β, algorithm.backend,
    )
end

function TensorAlgebra.contract!(
    algorithm::TensorOperationsAlgorithm,
    C::AbstractArray,
    labelsC,
    A::AbstractArray,
    labelsA,
    B::AbstractArray,
    labelsB,
    α::Number,
    β::Number,
)
    # Label-based in-place variant: derive the permutations from the labels first,
    # then perform C = β*C + α*(A * B) via the wrapped TensorOperations backend.
    (pA, pB, pAB) = TensorOperations.contract_indices(labelsA, labelsB, labelsC)
    return TensorOperations.tensorcontract!(
        C, A, pA, false, B, pB, false, pAB, α, β, algorithm.backend
    )
end

# Using TensorAlgebra implementations as TensorOperations backends
# ----------------------------------------------------------------
function TensorOperations.tensorcontract!(
    C::AbstractArray,
    A::AbstractArray,
    pA::Index2Tuple,
    conjA::Bool,
    B::AbstractArray,
    pB::Index2Tuple,
    conjB::Bool,
    pAB::Index2Tuple,
    α::Number,
    β::Number,
    backend::Algorithm,
    allocator,
)
    # TensorAlgebra.contract! takes no conjugation flags, so conjugation is
    # materialized eagerly here (allocating a conjugated copy when requested).
    Aeff = conjA ? conj(A) : A
    Beff = conjB ? conj(B) : B
    bipermAB = _blockedpermutation(pAB)
    bipermA = _blockedpermutation(pA)
    bipermB = _blockedpermutation(pB)
    return TensorAlgebra.contract!(backend, C, bipermAB, Aeff, bipermA, Beff, bipermB, α, β)
end

# For now no trace/add is supported, so simply reselect default backend from TensorOperations
function TensorOperations.tensortrace!(
    C::AbstractArray,
    A::AbstractArray,
    p::Index2Tuple,
    q::Index2Tuple,
    conjA::Bool,
    α::Number,
    β::Number,
    ::Algorithm,
    allocator,
)
    # TensorAlgebra algorithms implement contraction only; traces fall back to
    # TensorOperations' DefaultBackend instead of erroring out.
    return TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, DefaultBackend(), allocator)
end
function TensorOperations.tensoradd!(
    C::AbstractArray,
    A::AbstractArray,
    pA::Index2Tuple,
    conjA::Bool,
    α::Number,
    β::Number,
    ::Algorithm,
    allocator,
)
    # Likewise, permuted addition is not implemented by TensorAlgebra algorithms;
    # delegate to TensorOperations' DefaultBackend.
    return TensorOperations.tensoradd!(C, A, pA, conjA, α, β, DefaultBackend(), allocator)
end

end
31 changes: 19 additions & 12 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using StableRNGs: StableRNG
using TensorOperations: TensorOperations

using TensorAlgebra:
Algorithm,
BlockedTuple,
blockedpermvcat,
contract,
Expand Down Expand Up @@ -141,7 +142,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test_throws ArgumentError unmatricize!(m, m, blockedpermvcat((1, 2), (3,)))
end

using TensorOperations: TensorOperations
alg_tensoroperations = Algorithm(TensorOperations.StridedBLAS())
@testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts
elt_dest = promote_type(elt1, elt2)
a1 = ones(elt1, (1, 1))
Expand Down Expand Up @@ -184,15 +185,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a_dest, labels_dest′ = contract(a1, labels1, a2, labels2)
@test labels_dest′ isa
BlockedTuple{2,(length(setdiff(d1s, d2s)), length(setdiff(d2s, d1s)))}
a_dest_tensoroperations = TensorOperations.tensorcontract(
Tuple(labels_dest′), a1, labels1, a2, labels2
)
a_dest_tensoroperations, = contract(alg_tensoroperations, a1, labels1, a2, labels2)
@test a_dest ≈ a_dest_tensoroperations

# Specify destination labels
a_dest = contract(labels_dest, a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1, labels1, a2, labels2
a_dest_tensoroperations = contract(
alg_tensoroperations, labels_dest, a1, labels1, a2, labels2
)
@test a_dest ≈ a_dest_tensoroperations

Expand All @@ -202,8 +201,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a_dest = contract(tuplemortar(((), labels_dest)), a1, labels1, a2, labels2)
@test a_dest ≈ a_dest_tensoroperations
a_dest = contract(labels_dest′, a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
Tuple(labels_dest′), a1, labels1, a2, labels2
a_dest_tensoroperations = contract(
alg_tensoroperations, labels_dest′, a1, labels1, a2, labels2
)
@test a_dest ≈ a_dest_tensoroperations

Expand All @@ -215,13 +214,21 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests))
a_dest = copy(a_dest_init)
contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1, labels1, a2, labels2
a_dest_tensoroperations = copy(a_dest_init)
contract!(
alg_tensoroperations,
a_dest_tensoroperations,
labels_dest,
a1,
labels1,
a2,
labels2,
α,
β,
)
## Here we loosened the tolerance because of some floating point roundoff issue.
## with Float32 numbers
@test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol =
50 * default_rtol(elt_dest)
@test a_dest ≈ a_dest_tensoroperations rtol = 50 * default_rtol(elt_dest)
end
end
@testset "outer product contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts,
Expand Down
123 changes: 123 additions & 0 deletions test/test_tensoroperations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
using Random: AbstractRNG, default_rng
using Test: @test, @testset, @inferred
using TensorAlgebra: Matricize
using TensorOperations: @tensor, ncon, tensorcontract

@testset "tensorcontract" begin
    A = randn(Float64, (3, 20, 5, 3, 4))
    B = randn(Float64, (5, 6, 20, 3))
    pA = ((1, 4, 5), (2, 3))
    pB = ((3, 1), (2, 4))
    pAB = ((1, 5, 3, 2, 4), ())
    # The Matricize backend must reproduce the default-backend result, and both
    # calls must be type-stable (`@inferred`).
    C1 = @inferred tensorcontract(A, pA, false, B, pB, false, pAB, 1.0)
    C2 = @inferred tensorcontract(A, pA, false, B, pB, false, pAB, 1.0, Matricize())
    @test C1 ≈ C2
end

elts = (Float32, Float64, ComplexF32, ComplexF64)

# Cross-check the Matricize backend against TensorOperations' default backend on a
# small tensor-network contraction, for several element types.
@testset "tensor network examples ($T)" for T in elts
    D1, D2, D3 = 30, 40, 20
    d1, d2 = 2, 3
    # Subtract 1//2 so entries are centered around zero for every element type.
    A1 = rand(T, D1, d1, D2) .- 1//2
    A2 = rand(T, D2, d2, D3) .- 1//2
    rhoL = rand(T, D1, D1) .- 1//2
    rhoR = rand(T, D3, D3) .- 1//2
    H = rand(T, d1, d2, d1, d2) .- 1//2

    # Same contraction evaluated with the default backend and with Matricize().
    @tensor HrA12[a, s1, s2, c] :=
        rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2]
    @tensor backend = Matricize() HrA12′[a, s1, s2, c] :=
        rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2]

    @test HrA12 ≈ HrA12′
    # The ncon entry point must agree as well (same network, ncon-style labels).
    @test HrA12 ≈ ncon(
        [rhoL, H, A2, rhoR, A1],
        [[-1, 1], [-2, -3, 4, 5], [2, 5, 3], [3, -4], [1, 4, 2]];
        backend=Matricize(),
    )
    # Full scalar contraction: both backends must give the same number.
    E = @tensor rhoL[a', a] *
        A1[a, s, b] *
        A2[b, s', c] *
        rhoR[c, c'] *
        H[t, t', s, s'] *
        conj(A1[a', t, b']) *
        conj(A2[b', t', c'])
    @test E ≈ @tensor backend = Matricize() rhoL[a', a] *
        A1[a, s, b] *
        A2[b, s', c] *
        rhoR[c, c'] *
        H[t, t', s, s'] *
        conj(A1[a', t, b']) *
        conj(A2[b', t', c'])
end

"""
    generate_random_network(num_contracted_inds, num_open_inds, max_dim,
                            max_ind_per_tensor; rng=default_rng())

Generate a random `ncon`-style tensor-network specification.

Returns `(sizes, indices)`, where `sizes[i]` holds the dimensions of tensor `i` and
`indices[i]` its labels: each positive label `1:num_contracted_inds` appears exactly
twice in the network (a contracted pair with matching dimensions) and each negative
label `-(1:num_open_inds)` exactly once (an open index). Every tensor carries between
1 and `max_ind_per_tensor` indices, each of dimension `1:max_dim`.

Pass an explicit `rng` to make the generated network reproducible.
"""
function generate_random_network(
    num_contracted_inds,
    num_open_inds,
    max_dim,
    max_ind_per_tensor;
    rng::AbstractRNG=default_rng(),
)
    # Two copies of every contracted label: each must be placed on exactly two tensors.
    contracted_indices = repeat(collect(1:num_contracted_inds), 2)
    open_indices = collect(1:num_open_inds)
    # Dimensions are kept in lockstep with the pool [contracted; contracted; open];
    # the contracted dimensions are repeated so paired labels share a dimension.
    dimensions = [
        repeat(rand(rng, 1:max_dim, num_contracted_inds), 2)
        rand(rng, 1:max_dim, num_open_inds)
    ]

    sizes = Vector{Int64}[]
    indices = Vector{Int64}[]

    # Greedily draw labels from the pool until everything has been assigned.
    while !isempty(contracted_indices) || !isempty(open_indices)
        num_inds = rand(
            rng, 1:min(max_ind_per_tensor, length(contracted_indices) + length(open_indices))
        )

        cur_inds = Int64[]
        cur_dims = Int64[]

        for _ in 1:num_inds
            # Uniform draw over the concatenated pool [contracted; open].
            curind_index = rand(rng, 1:(length(contracted_indices) + length(open_indices)))

            if curind_index <= length(contracted_indices)
                push!(cur_inds, contracted_indices[curind_index])
                push!(cur_dims, dimensions[curind_index])
                deleteat!(contracted_indices, curind_index)
                deleteat!(dimensions, curind_index)
            else
                # Open indices get negative ncon labels.
                tind = curind_index - length(contracted_indices)
                push!(cur_inds, -open_indices[tind])
                push!(cur_dims, dimensions[curind_index])
                deleteat!(open_indices, tind)
                deleteat!(dimensions, curind_index)
            end
        end

        push!(sizes, cur_dims)
        push!(indices, cur_inds)
    end
    return sizes, indices
end

@testset "random contractions" begin
    # Size limits for the randomly generated networks.
    max_contracted_indices = 10
    max_open_indices = 5
    max_dim = 5
    max_inds_per_tensor = 3
    num_tests = 10

    for _ in 1:num_tests
        sizes, indices = generate_random_network(
            rand(1:max_contracted_indices), rand(1:max_open_indices), max_dim,
            max_inds_per_tensor,
        )
        tensors = map(splat(randn), sizes)
        # The Matricize backend must reproduce the default ncon result.
        reference = ncon(tensors, indices)
        matricized = ncon(tensors, indices; backend=Matricize())
        @test reference ≈ matricized
    end
end
Loading