Skip to content

Commit c4085f7

Browse files
JoeyT1994pre-commit-ci[bot]mtfishman
authored
Working contract_network (#7)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Matt Fishman <[email protected]>
1 parent 2819116 commit c4085f7

File tree

6 files changed

+113
-1
lines changed

6 files changed

+113
-1
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.10"
4+
version = "0.1.11"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -20,6 +20,12 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
2020
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2121
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"
2222

23+
[weakdeps]
24+
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
25+
26+
[extensions]
27+
ITensorNetworksNextTensorOperationsExt = "TensorOperations"
28+
2329
[compat]
2430
AbstractTrees = "0.4.5"
2531
Adapt = "4.3"
@@ -33,6 +39,7 @@ NamedDimsArrays = "0.8"
3339
NamedGraphs = "0.6.9, 0.7"
3440
SimpleTraits = "0.9.5"
3541
SplitApplyCombine = "1.2.3"
42+
TensorOperations = "5.3.1"
3643
TermInterface = "2"
3744
TypeParameterAccessors = "0.4.4"
3845
WrappedUnions = "0.3"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
module ITensorNetworksNextTensorOperationsExt
2+
3+
using BackendSelection: @Algorithm_str, Algorithm
4+
using NamedDimsArrays: inds
5+
using ITensorNetworksNext: ITensorNetworksNext, contraction_sequence_to_expr
6+
using TensorOperations: TensorOperations, optimaltree
7+
8+
function ITensorNetworksNext.contraction_sequence(::Algorithm"optimal", tn::Vector{<:AbstractArray})
9+
network = collect.(inds.(tn))
10+
#Converting dims to Float64 to minimize overflow issues
11+
inds_to_dims = Dict(i => Float64(length(i)) for i in unique(reduce(vcat, network)))
12+
seq, _ = optimaltree(network, inds_to_dims)
13+
return contraction_sequence_to_expr(seq)
14+
end
15+
16+
end

src/ITensorNetworksNext.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@ module ITensorNetworksNext
33
include("lazynameddimsarrays.jl")
44
include("abstracttensornetwork.jl")
55
include("tensornetwork.jl")
6+
include("contract_network.jl")
67

78
end

src/contract_network.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
using BackendSelection: @Algorithm_str, Algorithm
2+
using ITensorNetworksNext.LazyNamedDimsArrays: substitute, materialize, lazy,
3+
symnameddims
4+
5+
#Algorithmic defaults
6+
default_sequence_alg(::Algorithm"exact") = "leftassociative"
7+
default_sequence(::Algorithm"exact") = nothing
8+
function set_default_kwargs(alg::Algorithm"exact")
9+
sequence = get(alg, :sequence, nothing)
10+
sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg))
11+
return Algorithm("exact"; sequence, sequence_alg)
12+
end
13+
14+
function contraction_sequence_to_expr(seq)
15+
if seq isa AbstractVector
16+
return prod(contraction_sequence_to_expr, seq)
17+
else
18+
return symnameddims(seq)
19+
end
20+
end
21+
22+
function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray})
23+
return prod(symnameddims, 1:length(tn))
24+
end
25+
26+
function contraction_sequence(tn::Vector{<:AbstractArray}; sequence_alg = default_sequence_alg(Algorithm("exact")))
27+
return contraction_sequence(Algorithm(sequence_alg), tn)
28+
end
29+
30+
function contract_network(alg::Algorithm"exact", tn::Vector{<:AbstractArray})
31+
if !isnothing(alg.sequence)
32+
sequence = alg.sequence
33+
else
34+
sequence = contraction_sequence(tn; sequence_alg = alg.sequence_alg)
35+
end
36+
37+
sequence = substitute(sequence, Dict(symnameddims(i) => lazy(tn[i]) for i in 1:length(tn)))
38+
return materialize(sequence)
39+
end
40+
41+
function contract_network(alg::Algorithm"exact", tn::AbstractTensorNetwork)
42+
return contract_network(alg, [tn[v] for v in vertices(tn)])
43+
end
44+
45+
function contract_network(tn; alg, kwargs...)
46+
return contract_network(set_default_kwargs(Algorithm(alg; kwargs...)), tn)
47+
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
99
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
1010
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1111
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
12+
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
1213
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
1314
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1415
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"
@@ -25,5 +26,6 @@ NamedGraphs = "0.6.8, 0.7"
2526
SafeTestsets = "0.1"
2627
Suppressor = "0.2.8"
2728
TermInterface = "2"
29+
TensorOperations = "5.3.1"
2830
Test = "1.10"
2931
WrappedUnions = "0.3"

test/test_contract_network.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using Graphs: edges
2+
using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges
3+
using NamedGraphs.NamedGraphGenerators: named_grid
4+
using ITensorBase: Index, ITensor
5+
using ITensorNetworksNext:
6+
TensorNetwork, linkinds, siteinds, contract_network
7+
using TensorOperations: TensorOperations
8+
using Test: @test, @testset
9+
10+
@testset "contract_network" begin
11+
@testset "Contract Vectors of ITensors" begin
12+
i, j, k = Index(2), Index(2), Index(5)
13+
A = ITensor([1.0 1.0; 0.5 1.0], i, j)
14+
B = ITensor([2.0, 1.0], i)
15+
C = ITensor([5.0, 1.0], j)
16+
D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k)
17+
18+
ABCD_1 = contract_network([A, B, C, D]; alg = "exact", sequence_alg = "leftassociative")
19+
ABCD_2 = contract_network([A, B, C, D]; alg = "exact", sequence_alg = "optimal")
20+
21+
@test ABCD_1 == ABCD_2
22+
end
23+
24+
@testset "Contract One Dimensional Network" begin
25+
dims = (4, 4)
26+
g = named_grid(dims)
27+
l = Dict(e => Index(2) for e in edges(g))
28+
l = merge(l, Dict(reverse(e) => l[e] for e in edges(g)))
29+
tn = TensorNetwork(g) do v
30+
is = map(e -> l[e], incident_edges(g, v))
31+
return randn(Tuple(is))
32+
end
33+
34+
z1 = contract_network(tn; alg = "exact", sequence_alg = "optimal")[]
35+
z2 = contract_network(tn; alg = "exact", sequence_alg = "leftassociative")[]
36+
37+
@test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64)
38+
end
39+
end

0 commit comments

Comments
 (0)