Commit 08bd968

Small refactor
1 parent cfd425a commit 08bd968

4 files changed: +62 -46 lines changed

Lines changed: 16 additions & 15 deletions
@@ -1,27 +1,28 @@
 module ITensorNetworksNextTensorOperationsExt

 using BackendSelection: @Algorithm_str, Algorithm
-using ITensorNetworksNext: ITensorNetworksNext, contraction_order
-using ITensorNetworksNext.LazyNamedDimsArrays: symnameddims, substitute
+using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, ismul, symnameddims,
+  substitute
+using ITensorNetworksNext.LazyNamedDimsArrays.TermInterface: arguments
 using NamedDimsArrays: inds
 using TensorOperations: TensorOperations, optimaltree

-function contraction_order_to_expr(ord)
-  return ord isa AbstractVector ? prod(contraction_order_to_expr, ord) : symnameddims(ord)
+function contraction_tree_to_expr(f, tree)
+  return if !(tree isa AbstractVector)
+    f(tree)
+  else
+    prod(Base.Fix1(contraction_tree_to_expr, f), tree)
+  end
 end

-function ITensorNetworksNext.contraction_order(alg::Algorithm"optimal", tn)
-  ts = [tn[i] for i in keys(tn)]
-  network = collect.(inds.(ts))
+function LazyNamedDimsArrays.optimize_contraction_order(alg::Algorithm"optimal", a)
+  @assert ismul(a)
+  ts = arguments(a)
+  inds_network = collect.(inds.(ts))
   # Converting dims to Float64 to minimize overflow issues
-  inds_to_dims = Dict(i => Float64(length(i)) for i in unique(reduce(vcat, network)))
-  order, _ = optimaltree(network, inds_to_dims)
-  # TODO: Map the integer indices back to the original tensor network vertices.
-  expr = contraction_order_to_expr(order)
-  verts = collect(keys(tn))
-  sym(i) = symnameddims(verts[i], Tuple(inds(tn[verts[i]])))
-  subs = Dict(symnameddims(i) => sym(i) for i in eachindex(verts))
-  return substitute(expr, subs)
+  inds_to_dims = Dict(i => Float64(length(i)) for i in reduce(∪, inds_network))
+  tree, _ = optimaltree(inds_network, inds_to_dims)
+  return contraction_tree_to_expr(i -> ts[i], tree)
 end

 end

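For orientation, here is a small self-contained sketch (not part of the commit) of how the new `contraction_tree_to_expr` folds a nested contraction tree, such as one returned by `TensorOperations.optimaltree`, into a single product by applying `f` to the leaves. A toy `Node` type stands in for the lazy tensors so the grouping produced by `prod` is visible; in the extension, `f` is `i -> ts[i]` and `*` builds the lazy expression instead:

# Toy stand-in for a lazy tensor: `*` records the grouping as a string.
struct Node
  name::String
end
Base.:*(a::Node, b::Node) = Node("($(a.name) * $(b.name))")

# Local copy of the helper's logic from the diff above.
tree_to_expr(f, tree) =
  tree isa AbstractVector ? prod(Base.Fix1(tree_to_expr, f), tree) : f(tree)

tree = Any[Any[1, 2], 3]  # hypothetical `optimaltree` output: contract 1 and 2 first
tree_to_expr(i -> Node("t$i"), tree).name  # "((t1 * t2) * t3)"
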
src/LazyNamedDimsArrays/evaluation_order.jl

Lines changed: 8 additions & 12 deletions
@@ -61,26 +61,20 @@ function flatten_expression(a)
   if !iscall(a)
     return a
   elseif ismul(a)
-    flattened_arguments = mapreduce(vcat, arguments(a)) do arg
-      return ismul(arg) ? arguments(arg) : [arg]
-    end
+    flattened_arguments = mapreduce(to_mul_arguments, vcat, arguments(a))
     return lazy(Mul(flattened_arguments))
   else
     return error("Variant not supported.")
   end
 end

 function optimize_evaluation_order(alg, a)
-  return optimize_evaluation_order_flattened(alg, flatten_expression(a))
-end
-
-function optimize_evaluation_order_flattened(alg, a)
   if !iscall(a)
     return a
   elseif ismul(a)
-    return optimize_contraction_order_flattened(alg, a)
+    return optimize_contraction_order(alg, a)
   else
-    # TODO: Recurse into other operations, calling `optimize_evaluation_order_flattened`.
+    # TODO: Recurse into other operations, calling `optimize_evaluation_order`.
     return error("Variant not supported.")
   end
 end
@@ -94,19 +88,21 @@ end
 using BackendSelection: @Algorithm_str, Algorithm
 default_optimize_evaluation_order_alg(a) = Algorithm"eager"()

-function optimize_contraction_order_flattened(alg, a)
+function optimize_contraction_order(alg, a)
   return error("`alg = $alg` not supported.")
 end

 using Combinatorics: combinations
-function optimize_contraction_order_flattened(alg::Algorithm"eager", a)
+function optimize_contraction_order(alg::Algorithm"eager", a)
   @assert ismul(a)
   arity(a) in (1, 2) && return a
   a1, a2 = argmin(combinations(arguments(a), 2)) do (a1, a2)
     # Penalize outer product contractions.
+    # TODO: Still order the outer products by time complexity,
+    # say by checking if there are only outer products left.
     isdisjoint(inds(a1), inds(a2)) && return typemax(Int)
     return time_complexity(*, a1, a2)
   end
   contracted_arguments = [filter(∉((a1, a2)), arguments(a)); [a1 * a2]]
-  return optimize_contraction_order_flattened(alg, lazy(Mul(contracted_arguments)))
+  return optimize_contraction_order(alg, lazy(Mul(contracted_arguments)))
 end

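As a rough illustration (not the package code) of the greedy strategy behind the `Algorithm"eager"` method above: repeatedly contract the pair with the lowest estimated cost, pushing outer products (pairs sharing no indices) to the end. Below, sets of index names stand in for tensors, and the product of the dimensions of the union of indices is a stand-in for `time_complexity(*, a1, a2)`:

using Combinatorics: combinations

# Greedy pairwise ordering over index sets (a sketch of the eager algorithm).
function eager_order(tensors::Vector{Set{Symbol}}, dims::Dict{Symbol,Int})
  order = Tuple{Set{Symbol},Set{Symbol}}[]
  while length(tensors) > 1
    a1, a2 = argmin(combinations(tensors, 2)) do (a1, a2)
      isdisjoint(a1, a2) && return typemax(Int)  # penalize outer products
      return prod(i -> dims[i], union(a1, a2))   # rough time-complexity proxy
    end
    push!(order, (a1, a2))
    # Replace the contracted pair by a tensor carrying the uncontracted indices.
    tensors = [filter(t -> t !== a1 && t !== a2, tensors); [symdiff(a1, a2)]]
  end
  return order
end

dims = Dict(:i => 10, :j => 20, :k => 30, :l => 5)
eager_order([Set([:i, :j]), Set([:j, :k]), Set([:k, :l])], dims)
# Contracts the (j, k)-(k, l) pair first, since it has the smallest cost estimate.
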
src/LazyNamedDimsArrays/lazyinterface.jl

Lines changed: 8 additions & 8 deletions
@@ -157,13 +157,8 @@ end
 function show_lazy(io::IO, mime::MIME"text/plain", a)
   summary(io, a)
   println(io, ":")
-  if !iscall(a)
-    show(io, mime, unwrap(a))
-    return nothing
-  else
-    show(io, a)
-    return nothing
-  end
+  !iscall(a) ? show(io, mime, unwrap(a)) : show(io, a)
+  return nothing
 end
 add_lazy(a1, a2) = error("Not implemented.")
 sub_lazy(a) = error("Not implemented.")
@@ -179,7 +174,12 @@ function mul_lazy(a)
   end
 end
 # Note that this is nested by default.
-mul_lazy(a1, a2) = lazy(Mul([a1, a2]))
+function mul_lazy(a1, a2; flatten::Bool = false)
+  return flatten ? mul_lazy_flattened(a1, a2) : mul_lazy_nested(a1, a2)
+end
+mul_lazy_nested(a1, a2) = lazy(Mul([a1, a2]))
+to_mul_arguments(a) = ismul(a) ? arguments(a) : [a]
+mul_lazy_flattened(a1, a2) = lazy(Mul([to_mul_arguments(a1); to_mul_arguments(a2)]))
 mul_lazy(a1::Number, a2) = error("Not implemented.")
 mul_lazy(a1, a2::Number) = error("Not implemented.")
 mul_lazy(a1::Number, a2::Number) = a1 * a2

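A self-contained sketch (mirroring, not using, the package's `Mul` and `to_mul_arguments`) of the distinction the new `flatten` keyword introduces: nested multiplication builds a binary tree of `Mul` nodes, while the flattened form keeps a single n-ary argument list, which is the shape `flatten_expression` and the ordering algorithms operate on:

# Minimal stand-in for the lazy `Mul` expression node.
struct Mul
  arguments::Vector{Any}
end
to_mul_arguments(a) = a isa Mul ? a.arguments : Any[a]

mul_nested(a1, a2) = Mul(Any[a1, a2])
mul_flattened(a1, a2) = Mul([to_mul_arguments(a1); to_mul_arguments(a2)])

mul_nested(mul_nested(:a, :b), :c)        # Mul(Any[Mul(Any[:a, :b]), :c])
mul_flattened(mul_flattened(:a, :b), :c)  # Mul(Any[:a, :b, :c])
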
src/contract_network.jl

Lines changed: 30 additions & 11 deletions
@@ -1,7 +1,7 @@
 using BackendSelection: @Algorithm_str, Algorithm
 using Base.Broadcast: materialize
-using ITensorNetworksNext.LazyNamedDimsArrays: lazy, optimize_evaluation_order, substitute,
-  symnameddims
+using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order,
+  substitute, symnameddims

 # This is related to `MatrixAlgebraKit.select_algorithm`.
 # TODO: Define this in BackendSelection.jl.
@@ -14,7 +14,9 @@ to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...)
 to_algorithm(alg; kwargs...) = Algorithm(alg; kwargs...)

 # `contract_network`
-contract_network(alg::Algorithm, tn) = error("Not implemented.")
+function contract_network(alg::Algorithm, tn)
+  return throw(ArgumentError("`contract_network` algorithm `$(alg)` not implemented."))
+end
 function default_kwargs(::typeof(contract_network), tn)
   return (; alg = Algorithm"exact"(; order_alg = Algorithm"eager"()))
 end
@@ -23,14 +25,25 @@ function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kw
 end

 # `contract_network(::Algorithm"exact", ...)`
-function contract_network(alg::Algorithm"exact", tn)
-  order = @something begin
-    get(alg, :order, nothing)
-    contraction_order(
-      tn; alg = get(alg, :order_alg, default_kwargs(contraction_order, tn).alg)
-    )
+function get_order(alg::Algorithm"exact", tn)
+  # Allow specifying either `order` or `order_alg`.
+  order = get(alg, :order, nothing)
+  order = if !isnothing(order)
+    order
+  else
+    default_order_alg = default_kwargs(contraction_order, tn).alg
+    order_alg = get(alg, :order_alg, default_order_alg)
+    # TODO: Capture other keyword arguments and pass them to `contraction_order`.
+    contraction_order(tn; alg = order_alg)
   end
-  syms_to_ts = Dict(symnameddims(i, Tuple(inds(tn[i]))) => lazy(tn[i]) for i in eachindex(tn))
+  # Contraction order may or may not have indices attached, canonicalize the format
+  # by attaching indices.
+  subs = Dict(symnameddims(i) => symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn))
+  return substitute(order, subs)
+end
+function contract_network(alg::Algorithm"exact", tn)
+  order = get_order(alg, tn)
+  syms_to_ts = Dict(symnameddims(i, Tuple(inds(tn[i]))) => lazy(tn[i]) for i in keys(tn))
   tn_expression = substitute(order, syms_to_ts)
   return materialize(tn_expression)
 end
@@ -41,10 +54,16 @@ default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"())
 function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...)
   return contraction_order(to_algorithm(alg; kwargs...), tn)
 end
+# Convert the tensor network to a flat symbolic multiplication expression.
+function contraction_order(alg::Algorithm"flat", tn)
+  syms = [symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn)]
+  # Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`.
+  return lazy(Mul(syms))
+end
 function contraction_order(alg::Algorithm"left_associative", tn)
   return prod(i -> symnameddims(i, Tuple(inds(tn[i]))), keys(tn))
 end
 function contraction_order(alg::Algorithm, tn)
-  s = contraction_order(tn; alg = Algorithm"left_associative"())
+  s = contraction_order(Algorithm"flat"(), tn)
   return optimize_evaluation_order(s; alg)
 end

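Putting the pieces together, a hypothetical call site (where `tn` is a placeholder for a tensor network that supports `keys`, indexing, and `inds`) might select the contraction ordering through `order_alg`, using either the built-in eager ordering or the TensorOperations-backed optimal ordering added by the extension above:

using BackendSelection: @Algorithm_str, Algorithm
using ITensorNetworksNext: contract_network

# tn = ...  # some tensor network

# Default: exact contraction with the eager pairwise ordering.
contract_network(tn; alg = Algorithm"exact"(; order_alg = Algorithm"eager"()))

# With TensorOperations loaded, route the ordering through `optimaltree` instead.
contract_network(tn; alg = Algorithm"exact"(; order_alg = Algorithm"optimal"()))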