Skip to content

Commit 1eeac5c

Browse files
committed
Progress using eager evaluation in contract_network
1 parent 84e7454 commit 1eeac5c

File tree

1 file changed

+41
-32
lines changed

1 file changed

+41
-32
lines changed

src/contract_network.jl

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,56 @@
11
using BackendSelection: @Algorithm_str, Algorithm
2-
using ITensorNetworksNext.LazyNamedDimsArrays: substitute, materialize, lazy,
3-
symnameddims
4-
5-
#Algorithmic defaults
6-
default_sequence_alg(::Algorithm"exact") = "leftassociative"
7-
default_sequence(::Algorithm"exact") = nothing
8-
function set_default_kwargs(alg::Algorithm"exact")
9-
sequence = get(alg, :sequence, nothing)
10-
sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg))
11-
return Algorithm("exact"; sequence, sequence_alg)
12-
end
13-
14-
function contraction_sequence_to_expr(seq)
15-
if seq isa AbstractVector
16-
return prod(contraction_sequence_to_expr, seq)
2+
using Base.Broadcast: materialize
3+
using ITensorNetworksNext.LazyNamedDimsArrays: lazy, substitute, symnameddims
4+
5+
# This is based on `MatrixAlgebraKit.select_algorithm`.
6+
# TODO: Define this in BackendSelection.jl.
7+
function select_algorithm(alg; kwargs...)
8+
if alg isa Algorithm
9+
@assert isempty(kwargs) "Cannot pass keyword arguments when `alg` is an `Algorithm`."
10+
return alg
1711
else
18-
return symnameddims(seq)
12+
return Algorithm(alg; kwargs...)
1913
end
2014
end
2115

22-
function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray})
23-
return prod(symnameddims, 1:length(tn))
16+
# TODO: Use more general `default_kwargs(...)` here?
17+
function default_kwargs(::typeof(contract_network), tn)
18+
return (; alg = Algorithm"exact"(; evaluation_order_alg = Algorithm"eager"()))
19+
end
20+
function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...)
21+
return contract_network(select_algorithm(alg; kwargs...), tn)
2422
end
2523

26-
function contraction_sequence(tn::Vector{<:AbstractArray}; sequence_alg = default_sequence_alg(Algorithm("exact")))
27-
return contraction_sequence(Algorithm(sequence_alg), tn)
24+
function contract_network(alg::Algorithm, tn::AbstractTensorNetwork)
25+
return error("Not implemented.")
26+
end
27+
function contract_network(alg::Algorithm"exact", tn)
28+
evaluation_order = @something begin
29+
get(alg, :evaluation_order, nothing)
30+
contraction_order(tn; alg = alg.evaluation_order_alg)
31+
end
32+
syms_to_ts = Dict(symnameddims(i) => lazy(tn[i]) for i in eachindex(tn))
33+
tn_expression = substitute(evaluation_order, syms_to_ts)
34+
return materialize(tn_expression)
2835
end
2936

30-
function contract_network(alg::Algorithm"exact", tn::Vector{<:AbstractArray})
31-
if !isnothing(alg.sequence)
32-
sequence = alg.sequence
37+
# TODO: Move to TensorOperationsExt.
38+
function contraction_order_to_expr(seq)
39+
if seq isa AbstractVector
40+
return prod(contraction_order_to_expr, seq)
3341
else
34-
sequence = contraction_sequence(tn; sequence_alg = alg.sequence_alg)
42+
return symnameddims(seq)
3543
end
36-
37-
sequence = substitute(sequence, Dict(symnameddims(i) => lazy(tn[i]) for i in 1:length(tn)))
38-
return materialize(sequence)
3944
end
4045

41-
function contract_network(alg::Algorithm"exact", tn::AbstractTensorNetwork)
42-
return contract_network(alg, [tn[v] for v in vertices(tn)])
46+
# TODO: Use more general `default_kwargs(...)` here?
47+
default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"())
48+
function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...)
49+
return contraction_order(select_algorithm(alg; kwargs...), tn)
4350
end
44-
45-
function contract_network(tn; alg, kwargs...)
46-
return contract_network(set_default_kwargs(Algorithm(alg; kwargs...)), tn)
51+
function contraction_order(alg::Algorithm, tn)
52+
return error("Not implemented.")
53+
end
54+
function contraction_order(alg::Algorithm"eager", tn)
55+
return error("Eager not implemented.")
4756
end

0 commit comments

Comments
 (0)