Skip to content

Commit 02d33db

Browse files
committed
Working ContractNetwork
1 parent 8269686 commit 02d33db

File tree

5 files changed

+85
-1
lines changed

5 files changed

+85
-1
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,35 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
99
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
1010
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
1111
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
12+
ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1415
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
1516
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
1617
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
1718
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
19+
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
1820

1921
[compat]
2022
Adapt = "4.3.0"
2123
BackendSelection = "0.1.6"
2224
DataGraphs = "0.2.7"
2325
Dictionaries = "0.4.5"
2426
Graphs = "1.13.1"
27+
ITensorBase = "0.2.13"
2528
LinearAlgebra = "1.10"
2629
MacroTools = "0.5.16"
2730
NamedDimsArrays = "0.7.13"
2831
NamedGraphs = "0.6.9"
2932
SimpleTraits = "0.9.5"
3033
SplitApplyCombine = "1.2.3"
34+
TensorOperations = "5.3.1"
3135
julia = "1.10"

src/ITensorNetworksNext.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ module ITensorNetworksNext
22

33
include("abstracttensornetwork.jl")
44
include("tensornetwork.jl")
5+
include("contractnetwork.jl")
56

67
end

src/contractnetwork.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using TensorOperations: TensorOperations, optimaltree
2+
using ITensorBase: inds, dim
3+
4+
default_sequence_alg = "optimal"
5+
6+
function contraction_sequence(::Algorithm"optimal", tn::Vector{<:AbstractArray})
7+
network = collect.(inds.(tn))
8+
#Converting dims to Float64 to minimize overflow issues
9+
inds_to_dims = Dict(i => Float64(dim(i)) for i in unique(reduce(vcat, network)))
10+
seq, _ = optimaltree(network, inds_to_dims)
11+
return seq
12+
end
13+
14+
function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray})
15+
return Any[i for i in 1:length(tn)]
16+
end
17+
18+
function contraction_sequence(tn::Vector{<:AbstractArray}; alg=default_sequence_alg)
19+
contraction_sequence(Algorithm(alg), tn)
20+
end
21+
22+
# Internal recursive worker
23+
function recursive_contractnetwork(tn::Union{AbstractVector,AbstractArray})
24+
tn isa AbstractVector && return reduce(*, map(recursive_contractnetwork, tn))
25+
return tn
26+
end
27+
28+
# Recursive worker for ordering the tensors according to the sequence
29+
rearrange(tn::Vector{<:AbstractArray}, i::Integer) = tn[i]
30+
rearrange(tn::Vector{<:AbstractArray}, v::AbstractVector) = [rearrange(tn, s) for s in v]
31+
32+
function contractnetwork(tn::Vector{<:AbstractArray}; sequence_alg=default_sequence_alg)
33+
sequence = contraction_sequence(tn; alg=sequence_alg)
34+
return recursive_contractnetwork(rearrange(tn, sequence))
35+
end
36+
37+
function contractnetwork(tn::AbstractTensorNetwork; sequence_alg=default_sequence_alg)
38+
return contractnetwork([tn[v] for v in vertices(tn)]; sequence_alg)
39+
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
88
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
99
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1010
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
11+
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
1112
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1213

1314
[compat]
@@ -20,4 +21,5 @@ NamedDimsArrays = "0.7.14"
2021
NamedGraphs = "0.6.8"
2122
SafeTestsets = "0.1"
2223
Suppressor = "0.2.8"
24+
TensorOperations = "5.3.1"
2325
Test = "1.10"

test/test_contractnetwork.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using Graphs: edges
2+
using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges
3+
using NamedGraphs.NamedGraphGenerators: named_grid
4+
using ITensorBase: Index, ITensor
5+
using ITensorNetworksNext:
6+
TensorNetwork, linkinds, siteinds, contractnetwork, contraction_sequence
7+
using Test: @test, @testset
8+
9+
@testset "ContractNetwork" begin
10+
@testset "Contract Vectors of ITensors" begin
11+
i, j, k = Index(2), Index(2), Index(5)
12+
A = ITensor([1.0 1.0; 0.5 1.0], i, j)
13+
B = ITensor([2.0, 1.0], i)
14+
C = ITensor([5.0, 1.0], j)
15+
D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k)
16+
17+
ABCD_1 = contractnetwork([A, B, C, D]; sequence_alg="leftassociative")
18+
ABCD_2 = contractnetwork([A, B, C, D]; sequence_alg="optimal")
19+
20+
@test ABCD_1 == ABCD_2
21+
end
22+
23+
@testset "Contract One Dimensional Network" begin
24+
dims = (4, 4)
25+
g = named_grid(dims)
26+
l = Dict(e => Index(2) for e in edges(g))
27+
l = merge(l, Dict(reverse(e) => l[e] for e in edges(g)))
28+
tn = TensorNetwork(g) do v
29+
is = map(e -> l[e], incident_edges(g, v))
30+
return randn(Tuple(is))
31+
end
32+
33+
z1 = contractnetwork(tn; sequence_alg="optimal")[]
34+
z2 = contractnetwork(tn; sequence_alg="leftassociative")[]
35+
36+
@test abs(z1 - z2) / abs(z1) <= 1e-14
37+
end
38+
end

0 commit comments

Comments
 (0)