Skip to content

Commit f1cce56

Browse files
authored
Contraction sequence finding (#218)
1 parent 62d30e0 commit f1cce56

17 files changed

+67
-19
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.11.27"
4+
version = "0.12.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -41,13 +41,15 @@ EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
4141
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
4242
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
4343
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
44+
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
4445

4546
[extensions]
4647
ITensorNetworksAdaptExt = "Adapt"
4748
ITensorNetworksEinExprsExt = "EinExprs"
4849
ITensorNetworksGraphsFlowsExt = "GraphsFlows"
4950
ITensorNetworksOMEinsumContractionOrdersExt = "OMEinsumContractionOrders"
5051
ITensorNetworksObserversExt = "Observers"
52+
ITensorNetworksTensorOperationsExt = "TensorOperations"
5153

5254
[compat]
5355
AbstractTrees = "0.4.4"
@@ -80,6 +82,7 @@ SplitApplyCombine = "1.2"
8082
StaticArrays = "1.5.12"
8183
StructWalk = "0.2"
8284
Suppressor = "0.2"
85+
TensorOperations = "5.1.4"
8386
TimerOutputs = "0.5.22"
8487
TupleTools = "1.4"
8588
julia = "1.10"
@@ -90,6 +93,7 @@ EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
9093
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
9194
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
9295
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
96+
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
9397
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9498

9599
[targets]

ext/ITensorNetworksEinExprsExt/ITensorNetworksEinExprsExt.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@ module ITensorNetworksEinExprsExt
22

33
using ITensors: Index, ITensor, @Algorithm_str, inds, noncommoninds
44
using ITensorNetworks:
5-
ITensorNetworks, ITensorNetwork, vertextype, vertex_data, contraction_sequence
5+
ITensorNetworks,
6+
ITensorList,
7+
ITensorNetwork,
8+
vertextype,
9+
vertex_data,
10+
contraction_sequence
611
using EinExprs: EinExprs, EinExpr, einexpr, SizedEinExpr
712

8-
function to_einexpr(ts::Vector{ITensor})
13+
function to_einexpr(ts::ITensorList)
914
IndexType = Any
1015

1116
tensor_exprs = EinExpr{IndexType}[]
@@ -21,7 +26,7 @@ function to_einexpr(ts::Vector{ITensor})
2126
return SizedEinExpr(sum(tensor_exprs; skip=externalinds_tn), inds_dims)
2227
end
2328

24-
function tensor_inds_to_vertex(ts::Vector{ITensor})
29+
function tensor_inds_to_vertex(ts::ITensorList)
2530
IndexType = Any
2631
VertexType = Int
2732

@@ -36,7 +41,7 @@ function tensor_inds_to_vertex(ts::Vector{ITensor})
3641
end
3742

3843
function ITensorNetworks.contraction_sequence(
39-
::Algorithm"einexpr", tn::Vector{ITensor}; optimizer=EinExprs.Exhaustive()
44+
::Algorithm"einexpr", tn::ITensorList; optimizer=EinExprs.Exhaustive()
4045
)
4146
expr = to_einexpr(tn)
4247
path = einexpr(optimizer, expr)

ext/ITensorNetworksOMEinsumContractionOrdersExt/ITensorNetworksOMEinsumContractionOrdersExt.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module ITensorNetworksOMEinsumContractionOrdersExt
22
using DocStringExtensions: TYPEDSIGNATURES
3-
using ITensorNetworks: ITensorNetworks
3+
using ITensorNetworks: ITensorNetworks, ITensorList
44
using ITensors: ITensors, Index, ITensor, inds
55
using NDTensors: dim
66
using NDTensors.AlgorithmSelection: @Algorithm_str
@@ -9,8 +9,6 @@ using OMEinsumContractionOrders: OMEinsumContractionOrders
99
# OMEinsumContractionOrders wrapper for ITensors
1010
# Slicing is not supported, because it might require extra work to slice an `ITensor` correctly.
1111

12-
const ITensorList = Union{Vector{ITensor},Tuple{Vararg{ITensor}}}
13-
1412
# infer the output tensor labels
1513
# TODO: Use `symdiff` instead.
1614
function infer_output(inputs::AbstractVector{<:AbstractVector{<:Index}})
@@ -126,7 +124,9 @@ Optimize the einsum contraction pattern using the simulated annealing on tensor
126124
### References
127125
* [Recursive Multi-Tensor Contraction for XEB Verification of Quantum Circuits](https://arxiv.org/abs/2108.05665)
128126
"""
129-
function ITensorNetworks.contraction_sequence(::Algorithm"tree_sa", tn; kwargs...)
127+
function ITensorNetworks.contraction_sequence(
128+
::Algorithm"tree_sa", tn::ITensorList; kwargs...
129+
)
130130
return optimize_contraction_sequence(
131131
tn; optimizer=OMEinsumContractionOrders.TreeSA(; kwargs...)
132132
)
@@ -153,7 +153,9 @@ Then finds the contraction order inside each group with the greedy search algori
153153
### References
154154
* [Hyper-optimized tensor network contraction](https://arxiv.org/abs/2002.01935)
155155
"""
156-
function ITensorNetworks.contraction_sequence(::Algorithm"sa_bipartite", tn; kwargs...)
156+
function ITensorNetworks.contraction_sequence(
157+
::Algorithm"sa_bipartite", tn::ITensorList; kwargs...
158+
)
157159
return optimize_contraction_sequence(
158160
tn; optimizer=OMEinsumContractionOrders.SABipartite(; kwargs...)
159161
)
@@ -177,7 +179,9 @@ Then finds the contraction order inside each group with the greedy search algori
177179
* [Hyper-optimized tensor network contraction](https://arxiv.org/abs/2002.01935)
178180
* [Simulating the Sycamore quantum supremacy circuits](https://arxiv.org/abs/2103.03074)
179181
"""
180-
function ITensorNetworks.contraction_sequence(::Algorithm"kahypar_bipartite", tn; kwargs...)
182+
function ITensorNetworks.contraction_sequence(
183+
::Algorithm"kahypar_bipartite", tn::ITensorList; kwargs...
184+
)
181185
return optimize_contraction_sequence(
182186
tn; optimizer=OMEinsumContractionOrders.KaHyParBipartite(; kwargs...)
183187
)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
module ITensorNetworksTensorOperationsExt

using ITensors: ITensors, ITensor, dim, inds
using ITensorNetworks: ITensorNetworks, ITensorList
using NDTensors.AlgorithmSelection: @Algorithm_str
using TensorOperations: TensorOperations, optimaltree

"""
    contraction_sequence(::Algorithm"optimal", tn::ITensorList)

Find an optimal contraction sequence for the tensors `tn` by delegating to
`TensorOperations.optimaltree`, which searches over contraction trees using
the index dimensions as costs. Returns the nested-pair sequence produced by
`optimaltree`.
"""
function ITensorNetworks.contraction_sequence(::Algorithm"optimal", tn::ITensorList)
  # Represent each tensor by the collection of its index labels.
  network = collect.(inds.(tn))
  # Map every distinct index to its dimension. `Iterators.flatten` visits the
  # labels lazily, avoiding the O(n^2) copying of `reduce(vcat, network)` and
  # handling an empty tensor list without erroring.
  inds_to_dims = Dict(i => dim(i) for i in unique(Iterators.flatten(network)))
  seq, _ = optimaltree(network, inds_to_dims)
  return seq
end

end

src/contract.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using ITensors: ITensor, scalar
2-
using ITensors.ContractionSequenceOptimization: deepmap
32
using ITensors.NDTensors: NDTensors, Algorithm, @Algorithm_str, contract
43
using LinearAlgebra: normalize!
54
using NamedGraphs: NamedGraphs

src/contraction_sequences.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,32 @@
11
using Graphs: vertices
2-
using ITensors: ITensor, contract
3-
using ITensors.ContractionSequenceOptimization: deepmap, optimal_contraction_sequence
2+
using ITensors: ITensor
43
using ITensors.NDTensors: Algorithm, @Algorithm_str
54
using NamedGraphs.Keys: Key
65
using NamedGraphs.OrdinalIndexing: th
76

8-
function contraction_sequence(tn::Vector{ITensor}; alg="optimal", kwargs...)
# Accepted containers of tensors for sequence finding: a vector or tuple of ITensors.
const ITensorList = Union{Vector{ITensor},Tuple{Vararg{ITensor}}}

"""
    contraction_sequence(tn::ITensorList; alg="optimal", kwargs...)

Compute a contraction sequence for the tensors `tn`, dispatching to the
backend selected by `alg` (default `"optimal"`). Remaining keyword arguments
are forwarded to the backend method.
"""
function contraction_sequence(tn::ITensorList; alg="optimal", kwargs...)
  # Lift the algorithm name into an `Algorithm` type so backends can dispatch on it.
  selected = Algorithm(alg)
  return contraction_sequence(selected, tn; kwargs...)
end
1112

"""
    contraction_sequence(alg::Algorithm, tn::ITensorList)

Fallback for contraction-sequence algorithms that have no loaded backend:
fail loudly with guidance on which extension package to load.

# Throws
- `ArgumentError` always; backends (e.g. the TensorOperations.jl or
  OMEinsumContractionOrders.jl extensions) define the real methods.
"""
function contraction_sequence(alg::Algorithm, tn::ITensorList)
  # `throw` never returns, so a `return` in front of it is redundant. The
  # message is a single line so no source newline/indentation leaks into it.
  throw(
    ArgumentError(
      "Algorithm $alg isn't defined for contraction sequence finding. Try loading a backend package like TensorOperations.jl or OMEinsumContractionOrders.jl.",
    ),
  )
end
21+
"""
    deepmap(f, tree; filter=(x -> x isa AbstractArray))

Apply `f` to every leaf of the nested structure `tree`, recursing into any
node for which `filter` returns `true` and preserving the container shape.
"""
function deepmap(f, tree; filter=(x -> x isa AbstractArray))
  # Recurse into containers; apply `f` only at the leaves.
  if filter(tree)
    return map(subtree -> deepmap(f, subtree; filter=filter), tree)
  end
  return f(tree)
end
25+
1226
"""
    contraction_sequence(tn::AbstractITensorNetwork; kwargs...)

Find a contraction sequence for the network `tn`, returning the sequence in
terms of the network's vertex keys rather than linear indices.
"""
function contraction_sequence(tn::AbstractITensorNetwork; kwargs...)
  # Collect the tensors in ordinal-vertex order.
  # TODO: Use `token_vertex` and/or `token_vertices` here.
  tensors = map(v -> tn[v], (1:nv(tn))th)
  linear_sequence = contraction_sequence(tensors; kwargs...)
  # Translate each linear index in the sequence back to a vertex `Key`.
  # TODO: Use `Functors.fmap` or `StructWalk`?
  return deepmap(i -> Key(vertices(tn)[i * th]), linear_sequence)
end
19-
20-
function contraction_sequence(::Algorithm"optimal", tn::Vector{ITensor})
21-
return optimal_contraction_sequence(tn)
22-
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2727
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
2828
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2929
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
30+
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
3031
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3132
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
3233
Weave = "44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9"

test/test_additensornetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using ITensorNetworks: ITensorNetwork, inner_network, random_tensornetwork, site
66
using ITensors: ITensors, apply, op, scalar, inner
77
using LinearAlgebra: norm_sqr
88
using StableRNGs: StableRNG
9+
using TensorOperations: TensorOperations
910
using Test: @test, @testset
1011
@testset "add_itensornetworks" begin
1112
g = named_grid((2, 2))

test/test_apply.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using NamedGraphs.NamedGraphGenerators: named_grid
1616
using NamedGraphs.PartitionedGraphs: PartitionVertex
1717
using SplitApplyCombine: group
1818
using StableRNGs: StableRNG
19+
using TensorOperations: TensorOperations
1920
using Test: @test, @testset
2021
@testset "apply" begin
2122
g_dims = (2, 2)

test/test_belief_propagation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid
3434
using NamedGraphs.PartitionedGraphs: PartitionVertex, partitionedges
3535
using SplitApplyCombine: group
3636
using StableRNGs: StableRNG
37+
using TensorOperations: TensorOperations
3738
using Test: @test, @testset
3839

3940
@testset "belief_propagation (eltype=$elt)" for elt in (

0 commit comments

Comments
 (0)