Skip to content

Commit 5dcecfb

Browse files
committed
Contract Network Stuff
1 parent 4d43eb5 commit 5dcecfb

File tree

2 files changed

+13
-21
lines changed

2 files changed

+13
-21
lines changed

src/contractnetwork.jl

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
1-
using TensorOperations: TensorOperations, optimaltree
21
using ITensorBase: inds, dim
32

4-
default_sequence_alg = "optimal"
5-
6-
function contraction_sequence(::Algorithm"optimal", tn::Vector{<:AbstractArray})
7-
network = collect.(inds.(tn))
8-
#Converting dims to Float64 to minimize overflow issues
9-
inds_to_dims = Dict(i => Float64(dim(i)) for i in unique(reduce(vcat, network)))
10-
seq, _ = optimaltree(network, inds_to_dims)
11-
return seq
12-
end
3+
default_sequence_alg = "leftassociative"
134

145
function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray})
156
return Any[i for i in 1:length(tn)]
@@ -20,20 +11,20 @@ function contraction_sequence(tn::Vector{<:AbstractArray}; alg=default_sequence_
2011
end
2112

2213
# Internal recursive worker
23-
function recursive_contractnetwork(tn::Union{AbstractVector,AbstractArray})
24-
tn isa AbstractVector && return reduce(*, map(recursive_contractnetwork, tn))
14+
function recursive_contractnetwork(tn::Union{AbstractVector,AbstractNamedDimsArray})
15+
tn isa AbstractVector && return prod(recursive_contractnetwork, tn)
2516
return tn
2617
end
2718

2819
# Recursive worker for ordering the tensors according to the sequence
2920
rearrange(tn::Vector{<:AbstractArray}, i::Integer) = tn[i]
3021
rearrange(tn::Vector{<:AbstractArray}, v::AbstractVector) = [rearrange(tn, s) for s in v]
3122

32-
function contractnetwork(tn::Vector{<:AbstractArray}; sequence_alg=default_sequence_alg)
33-
sequence = contraction_sequence(tn; alg=sequence_alg)
34-
return recursive_contractnetwork(rearrange(tn, sequence))
23+
function contractnetwork(tn::Vector{<:AbstractArray}; sequence=default_sequence_alg)
24+
contract_sequence = isa(sequence, String) ? contraction_sequence(tn; alg=sequence) : sequence
25+
return recursive_contractnetwork(rearrange(tn, contract_sequence))
3526
end
3627

37-
function contractnetwork(tn::AbstractTensorNetwork; sequence_alg=default_sequence_alg)
38-
return contractnetwork([tn[v] for v in vertices(tn)]; sequence_alg)
28+
function contractnetwork(tn::AbstractTensorNetwork; sequence=default_sequence_alg)
29+
return contractnetwork([tn[v] for v in vertices(tn)]; sequence)
3930
end

test/test_contractnetwork.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using NamedGraphs.NamedGraphGenerators: named_grid
44
using ITensorBase: Index, ITensor
55
using ITensorNetworksNext:
66
TensorNetwork, linkinds, siteinds, contractnetwork, contraction_sequence
7+
using TensorOperations
78
using Test: @test, @testset
89

910
@testset "ContractNetwork" begin
@@ -14,8 +15,8 @@ using Test: @test, @testset
1415
C = ITensor([5.0, 1.0], j)
1516
D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k)
1617

17-
ABCD_1 = contractnetwork([A, B, C, D]; sequence_alg="leftassociative")
18-
ABCD_2 = contractnetwork([A, B, C, D]; sequence_alg="optimal")
18+
ABCD_1 = contractnetwork([A, B, C, D]; sequence="leftassociative")
19+
ABCD_2 = contractnetwork([A, B, C, D]; sequence="optimal")
1920

2021
@test ABCD_1 == ABCD_2
2122
end
@@ -30,8 +31,8 @@ using Test: @test, @testset
3031
return randn(Tuple(is))
3132
end
3233

33-
z1 = contractnetwork(tn; sequence_alg="optimal")[]
34-
z2 = contractnetwork(tn; sequence_alg="leftassociative")[]
34+
z1 = contractnetwork(tn; sequence="optimal")[]
35+
z2 = contractnetwork(tn; sequence="leftassociative")[]
3536

3637
@test abs(z1 - z2) / abs(z1) <= 1e3*eps(Float64)
3738
end

0 commit comments

Comments
 (0)