@@ -2,22 +2,25 @@ module ITensorNetworksNextTensorOperationsExt
22
33using BackendSelection: @Algorithm_str , Algorithm
44using ITensorNetworksNext: ITensorNetworksNext, contraction_order
5+ using ITensorNetworksNext. LazyNamedDimsArrays: symnameddims, substitute
56using NamedDimsArrays: inds
67using TensorOperations: TensorOperations, optimaltree
78
8- function contraction_order_to_expr (seq )
9- return seq isa AbstractVector ? prod (contraction_order_to_expr, seq ) : symnameddims (seq )
9+ function contraction_order_to_expr (ord )
10+ return ord isa AbstractVector ? prod (contraction_order_to_expr, ord ) : symnameddims (ord )
1011end
1112
1213function ITensorNetworksNext. contraction_order (alg:: Algorithm"optimal" , tn)
13- ts = [tn[i] for i in eachindex (tn)]
14+ ts = [tn[i] for i in keys (tn)]
1415 network = collect .(inds .(ts))
1516 # Converting dims to Float64 to minimize overflow issues
1617 inds_to_dims = Dict (i => Float64 (length (i)) for i in unique (reduce (vcat, network)))
17- seq , _ = optimaltree (network, inds_to_dims)
18+ order , _ = optimaltree (network, inds_to_dims)
1819 # TODO : Map the integer indices back to the original tensor network vertices.
19- expr = contraction_order_to_expr (seq)
20- subs = Dict (symnameddims (i) => symnameddims (eachindex (tn)[i]) for i in eachindex (ts))
20+ expr = contraction_order_to_expr (order)
21+ verts = collect (keys (tn))
22+ sym (i) = symnameddims (verts[i], Tuple (inds (tn[verts[i]])))
23+ subs = Dict (symnameddims (i) => sym (i) for i in eachindex (verts))
2124 return substitute (expr, subs)
2225end
2326
0 commit comments