11using BackendSelection: @Algorithm_str , Algorithm
22using Base. Broadcast: materialize
3- using ITensorNetworksNext. LazyNamedDimsArrays: lazy, optimize_evaluation_order, substitute ,
4- symnameddims
3+ using ITensorNetworksNext. LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order ,
4+ substitute, symnameddims
55
66# This is related to `MatrixAlgebraKit.select_algorithm`.
77# TODO : Define this in BackendSelection.jl.
@@ -14,7 +14,9 @@ to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...)
1414to_algorithm (alg; kwargs... ) = Algorithm (alg; kwargs... )
1515
1616# `contract_network`
17- contract_network (alg:: Algorithm , tn) = error (" Not implemented." )
17+ function contract_network (alg:: Algorithm , tn)
18+ return throw (ArgumentError (" `contract_network` algorithm `$(alg) ` not implemented." ))
19+ end
1820function default_kwargs (:: typeof (contract_network), tn)
1921 return (; alg = Algorithm " exact" (; order_alg = Algorithm " eager" ()))
2022end
@@ -23,14 +25,25 @@ function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kw
2325end
2426
2527# `contract_network(::Algorithm"exact", ...)`
26- function contract_network (alg:: Algorithm"exact" , tn)
27- order = @something begin
28- get (alg, :order , nothing )
29- contraction_order (
30- tn; alg = get (alg, :order_alg , default_kwargs (contraction_order, tn). alg)
31- )
28+ function get_order (alg:: Algorithm"exact" , tn)
29+ # Allow specifying either `order` or `order_alg`.
30+ order = get (alg, :order , nothing )
31+ order = if ! isnothing (order)
32+ order
33+ else
34+ default_order_alg = default_kwargs (contraction_order, tn). alg
35+ order_alg = get (alg, :order_alg , default_order_alg)
36+ # TODO : Capture other keyword arguments and pass them to `contraction_order`.
37+ contraction_order (tn; alg = order_alg)
3238 end
33- syms_to_ts = Dict (symnameddims (i, Tuple (inds (tn[i]))) => lazy (tn[i]) for i in eachindex (tn))
39+ # Contraction order may or may not have indices attached, canonicalize the format
40+ # by attaching indices.
41+ subs = Dict (symnameddims (i) => symnameddims (i, Tuple (inds (tn[i]))) for i in keys (tn))
42+ return substitute (order, subs)
43+ end
44+ function contract_network (alg:: Algorithm"exact" , tn)
45+ order = get_order (alg, tn)
46+ syms_to_ts = Dict (symnameddims (i, Tuple (inds (tn[i]))) => lazy (tn[i]) for i in keys (tn))
3447 tn_expression = substitute (order, syms_to_ts)
3548 return materialize (tn_expression)
3649end
@@ -41,10 +54,16 @@ default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"())
4154function contraction_order (tn; alg = default_kwargs (contraction_order, tn). alg, kwargs... )
4255 return contraction_order (to_algorithm (alg; kwargs... ), tn)
4356end
57+ # Convert the tensor network to a flat symbolic multiplication expression.
58+ function contraction_order (alg:: Algorithm"flat" , tn)
59+ syms = [symnameddims (i, Tuple (inds (tn[i]))) for i in keys (tn)]
60+ # Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`.
61+ return lazy (Mul (syms))
62+ end
4463function contraction_order (alg:: Algorithm"left_associative" , tn)
4564 return prod (i -> symnameddims (i, Tuple (inds (tn[i]))), keys (tn))
4665end
4766function contraction_order (alg:: Algorithm , tn)
48- s = contraction_order (tn; alg = Algorithm "left_associative " ())
67+ s = contraction_order (Algorithm "flat " (), tn )
4968 return optimize_evaluation_order (s; alg)
5069end
0 commit comments