1- using TensorOperations: TensorOperations, optimaltree
21using ITensorBase: inds, dim
32
4- default_sequence_alg = " optimal"
5-
6- function contraction_sequence (:: Algorithm"optimal" , tn:: Vector{<:AbstractArray} )
7- network = collect .(inds .(tn))
8- # Converting dims to Float64 to minimize overflow issues
9- inds_to_dims = Dict (i => Float64 (dim (i)) for i in unique (reduce (vcat, network)))
10- seq, _ = optimaltree (network, inds_to_dims)
11- return seq
12- end
3+ default_sequence_alg = " leftassociative"
134
145function contraction_sequence (:: Algorithm"leftassociative" , tn:: Vector{<:AbstractArray} )
156 return Any[i for i in 1 : length (tn)]
@@ -20,20 +11,20 @@ function contraction_sequence(tn::Vector{<:AbstractArray}; alg=default_sequence_
2011end
2112
2213# Internal recursive worker
23- function recursive_contractnetwork (tn:: Union{AbstractVector,AbstractArray } )
24- tn isa AbstractVector && return reduce ( * , map ( recursive_contractnetwork, tn) )
14+ function recursive_contractnetwork (tn:: Union{AbstractVector,AbstractNamedDimsArray } )
15+ tn isa AbstractVector && return prod ( recursive_contractnetwork, tn)
2516 return tn
2617end
2718
2819# Recursive worker for ordering the tensors according to the sequence
2920rearrange (tn:: Vector{<:AbstractArray} , i:: Integer ) = tn[i]
3021rearrange (tn:: Vector{<:AbstractArray} , v:: AbstractVector ) = [rearrange (tn, s) for s in v]
3122
32- function contractnetwork (tn:: Vector{<:AbstractArray} ; sequence_alg = default_sequence_alg)
33- sequence = contraction_sequence (tn; alg= sequence_alg)
34- return recursive_contractnetwork (rearrange (tn, sequence ))
23+ function contractnetwork (tn:: Vector{<:AbstractArray} ; sequence = default_sequence_alg)
24+ contract_sequence = isa (sequence, String) ? contraction_sequence (tn; alg= sequence) : sequence
25+ return recursive_contractnetwork (rearrange (tn, contract_sequence ))
3526end
3627
37- function contractnetwork (tn:: AbstractTensorNetwork ; sequence_alg = default_sequence_alg)
38- return contractnetwork ([tn[v] for v in vertices (tn)]; sequence_alg )
28+ function contractnetwork (tn:: AbstractTensorNetwork ; sequence = default_sequence_alg)
29+ return contractnetwork ([tn[v] for v in vertices (tn)]; sequence )
3930end
0 commit comments