|
1 | 1 | using BackendSelection: @Algorithm_str, Algorithm |
2 | | -using ITensorNetworksNext.LazyNamedDimsArrays: substitute, materialize, lazy, |
3 | | - symnameddims |
4 | | - |
5 | | -#Algorithmic defaults |
6 | | -default_sequence_alg(::Algorithm"exact") = "leftassociative" |
7 | | -default_sequence(::Algorithm"exact") = nothing |
8 | | -function set_default_kwargs(alg::Algorithm"exact") |
9 | | - sequence = get(alg, :sequence, nothing) |
10 | | - sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) |
11 | | - return Algorithm("exact"; sequence, sequence_alg) |
12 | | -end |
13 | | - |
14 | | -function contraction_sequence_to_expr(seq) |
15 | | - if seq isa AbstractVector |
16 | | - return prod(contraction_sequence_to_expr, seq) |
| 2 | +using Base.Broadcast: materialize |
| 3 | +using ITensorNetworksNext.LazyNamedDimsArrays: lazy, substitute, symnameddims |
| 4 | + |
| 5 | +# This is based on `MatrixAlgebraKit.select_algorithm`. |
| 6 | +# TODO: Define this in BackendSelection.jl. |
| 7 | +function select_algorithm(alg; kwargs...) |
| 8 | + if alg isa Algorithm |
| 9 | + @assert isempty(kwargs) "Cannot pass keyword arguments when `alg` is an `Algorithm`." |
| 10 | + return alg |
17 | 11 | else |
18 | | - return symnameddims(seq) |
| 12 | + return Algorithm(alg; kwargs...) |
19 | 13 | end |
20 | 14 | end |
21 | 15 |
|
22 | | -function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray}) |
23 | | - return prod(symnameddims, 1:length(tn)) |
| 16 | +# TODO: Use more general `default_kwargs(...)` here? |
| 17 | +function default_kwargs(::typeof(contract_network), tn) |
| 18 | + return (; alg = Algorithm"exact"(; evaluation_order_alg = Algorithm"eager"())) |
| 19 | +end |
| 20 | +function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...) |
| 21 | + return contract_network(select_algorithm(alg; kwargs...), tn) |
24 | 22 | end |
25 | 23 |
|
26 | | -function contraction_sequence(tn::Vector{<:AbstractArray}; sequence_alg = default_sequence_alg(Algorithm("exact"))) |
27 | | - return contraction_sequence(Algorithm(sequence_alg), tn) |
| 24 | +function contract_network(alg::Algorithm, tn::AbstractTensorNetwork) |
| 25 | + return error("Not implemented.") |
| 26 | +end |
| 27 | +function contract_network(alg::Algorithm"exact", tn) |
| 28 | + evaluation_order = @something begin |
| 29 | + get(alg, :evaluation_order, nothing) |
| 30 | + contraction_order(tn; alg = alg.evaluation_order_alg) |
| 31 | + end |
| 32 | + syms_to_ts = Dict(symnameddims(i) => lazy(tn[i]) for i in eachindex(tn)) |
| 33 | + tn_expression = substitute(evaluation_order, syms_to_ts) |
| 34 | + return materialize(tn_expression) |
28 | 35 | end |
29 | 36 |
|
30 | | -function contract_network(alg::Algorithm"exact", tn::Vector{<:AbstractArray}) |
31 | | - if !isnothing(alg.sequence) |
32 | | - sequence = alg.sequence |
| 37 | +# TODO: Move to TensorOperationsExt. |
| 38 | +function contraction_order_to_expr(seq) |
| 39 | + if seq isa AbstractVector |
| 40 | + return prod(contraction_order_to_expr, seq) |
33 | 41 | else |
34 | | - sequence = contraction_sequence(tn; sequence_alg = alg.sequence_alg) |
| 42 | + return symnameddims(seq) |
35 | 43 | end |
36 | | - |
37 | | - sequence = substitute(sequence, Dict(symnameddims(i) => lazy(tn[i]) for i in 1:length(tn))) |
38 | | - return materialize(sequence) |
39 | 44 | end |
40 | 45 |
|
41 | | -function contract_network(alg::Algorithm"exact", tn::AbstractTensorNetwork) |
42 | | - return contract_network(alg, [tn[v] for v in vertices(tn)]) |
| 46 | +# TODO: Use more general `default_kwargs(...)` here? |
| 47 | +default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"()) |
| 48 | +function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...) |
| 49 | + return contraction_order(select_algorithm(alg; kwargs...), tn) |
43 | 50 | end |
44 | | - |
45 | | -function contract_network(tn; alg, kwargs...) |
46 | | - return contract_network(set_default_kwargs(Algorithm(alg; kwargs...)), tn) |
| 51 | +function contraction_order(alg::Algorithm, tn) |
| 52 | + return error("Not implemented.") |
| 53 | +end |
| 54 | +function contraction_order(alg::Algorithm"eager", tn) |
| 55 | + return error("Eager not implemented.") |
47 | 56 | end |
0 commit comments