|
1 | 1 | using BackendSelection: @Algorithm_str, Algorithm |
2 | | -using ITensorNetworksNext.LazyNamedDimsArrays: nested_array_to_lazy_multiply, substitute_lazy, materialize, lazy, |
| 2 | +using ITensorNetworksNext.LazyNamedDimsArrays: substitute, materialize, lazy, |
3 | 3 | symnameddims |
4 | 4 |
|
5 | | -default_contract_alg = nothing |
6 | | - |
7 | 5 | #Algorithmic defaults |
8 | 6 | default_sequence(::Algorithm"exact") = "leftassociative" |
9 | 7 | function set_default_kwargs(alg::Algorithm"exact") |
10 | 8 | sequence = get(alg, :sequence, default_sequence(alg)) |
11 | 9 | return Algorithm("exact"; sequence) |
12 | 10 | end |
13 | 11 |
|
| 12 | +function contraction_sequence_to_expr(seq) |
| 13 | + if seq isa AbstractVector |
| 14 | + return prod(contraction_sequence_to_expr, seq) |
| 15 | + else |
| 16 | + return symnameddims(seq) |
| 17 | + end |
| 18 | +end |
| 19 | + |
14 | 20 | function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray}) |
15 | | - return nested_array_to_lazy_multiply(collect.(1:length(tn))) |
| 21 | + return contraction_sequence_to_expr(collect.(1:length(tn))) |
16 | 22 | end |
17 | 23 |
|
18 | 24 | function contraction_sequence(tn::Vector{<:AbstractArray}; alg = default_sequence_alg) |
|
21 | 27 |
|
22 | 28 | function contractnetwork(alg::Algorithm"exact", tn::Vector{<:AbstractArray}) |
23 | 29 | contract_sequence = isa(alg.sequence, String) ? contraction_sequence(tn; alg = alg.sequence) : sequence |
24 | | - contract_sequence = substitute_lazy(contract_sequence, Dict(symnameddims(i) => lazy(tn[i]) for i in 1:length(tn))) |
| 30 | + contract_sequence = substitute(contract_sequence, Dict(symnameddims(i) => lazy(tn[i]) for i in 1:length(tn))) |
25 | 31 | return materialize(contract_sequence) |
26 | 32 | end |
27 | 33 |
|
28 | 34 | function contractnetwork(alg::Algorithm"exact", tn::AbstractTensorNetwork) |
29 | 35 | return contractnetwork(alg, [tn[v] for v in vertices(tn)]) |
30 | 36 | end |
31 | 37 |
|
32 | | -function contractnetwork(tn::Union{AbstractTensorNetwork, Vector{<:AbstractArray}}; alg = default_contract_alg, kwargs...) |
33 | | - alg == nothing && error("Must specify an algorithm to contract the network with") |
| 38 | +function contractnetwork(tn; alg, kwargs...) |
34 | 39 | return contractnetwork(set_default_kwargs(Algorithm(alg; kwargs...)), tn) |
35 | 40 | end |
0 commit comments