Skip to content

Commit 5bcd80f

Browse files
committed
Refactor TensorNetwork constructors
1 parent 8a250a0 commit 5bcd80f

11 files changed

+91
-70
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ ITensorNetworksNextTensorOperationsExt = "TensorOperations"
3232
AbstractTrees = "0.4.5"
3333
Adapt = "4.3"
3434
BackendSelection = "0.1.6"
35-
Combinatorics = "1.0.3"
35+
Combinatorics = "1"
3636
DataGraphs = "0.2.7"
3737
DiagonalArrays = "0.3.23"
3838
Dictionaries = "0.4.5"

src/LazyNamedDimsArrays/evaluation_order.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,15 @@ function optimize_evaluation_order(
9191
return optimize_evaluation_order(alg, a)
9292
end
9393

94-
struct Eager end
95-
96-
default_optimize_evaluation_order_alg(a) = Eager()
94+
using BackendSelection: @Algorithm_str, Algorithm
95+
default_optimize_evaluation_order_alg(a) = Algorithm"eager"()
9796

9897
function optimize_contraction_order_flattened(alg, a)
99-
return error("Alg $alg not supported.")
98+
return error("`alg = $alg` not supported.")
10099
end
101100

102101
using Combinatorics: combinations
103-
function optimize_contraction_order_flattened(alg::Eager, a)
102+
function optimize_contraction_order_flattened(alg::Algorithm"eager", a)
104103
@assert ismul(a)
105104
arity(a) in (1, 2) && return a
106105
a1, a2 = argmin(combinations(arguments(a), 2)) do (a1, a2)

src/LazyNamedDimsArrays/lazynameddimsarray.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@ using WrappedUnions: @wrapped
66
} <: AbstractNamedDimsArray{T, Any}
77
union::Union{A, Mul{LazyNamedDimsArray{T, A}}}
88
end
9+
10+
parenttype(::Type{LazyNamedDimsArray{<:Any, A}}) where {A} = A
11+
parenttype(::Type{LazyNamedDimsArray{T}}) where {T} = AbstractNamedDimsArray{T}
12+
parenttype(::Type{LazyNamedDimsArray}) = AbstractNamedDimsArray
13+
914
function LazyNamedDimsArray(a::AbstractNamedDimsArray)
1015
# Use `eltype(typeof(a))` for arrays that have different
1116
# runtime and compile time eltypes, like `ITensor`.
1217
return LazyNamedDimsArray{eltype(typeof(a)), typeof(a)}(a)
1318
end
14-
function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A}
15-
return LazyNamedDimsArray{T, A}(a)
19+
function LazyNamedDimsArray(a::Mul{L}) where {L <: LazyNamedDimsArray}
20+
return LazyNamedDimsArray{eltype(L), parenttype(L)}(a)
1621
end
1722
lazy(a::LazyNamedDimsArray) = a
1823
lazy(a::AbstractNamedDimsArray) = LazyNamedDimsArray(a)

src/LazyNamedDimsArrays/symbolicarray.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,6 @@ end
99
function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}})
1010
return SymbolicArray{Any}(name, ax)
1111
end
12-
function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T}
13-
return SymbolicArray{T}(name, ax)
14-
end
15-
function SymbolicArray(name, ax::AbstractUnitRange...)
16-
return SymbolicArray{Any}(name, ax)
17-
end
1812
symname(a::SymbolicArray) = getfield(a, :name)
1913
Base.axes(a::SymbolicArray) = getfield(a, :axes)
2014
Base.size(a::SymbolicArray) = length.(axes(a))

src/LazyNamedDimsArrays/symbolicnameddimsarray.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ using NamedDimsArrays: NamedDimsArray, dename, inds
22

33
const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} =
44
NamedDimsArray{T, N, Parent, DimNames}
5-
function symnameddims(name)
6-
return lazy(NamedDimsArray(SymbolicArray(name), ()))
5+
function symnameddims(name, dims)
6+
return lazy(NamedDimsArray(SymbolicArray(name, dename.(dims)), dims))
77
end
8+
symnameddims(name) = symnameddims(name, ())
89
using AbstractTrees: AbstractTrees
910
function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray)
1011
print(io, symname(dename(a)))

src/abstracttensornetwork.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Base.copy(tn::AbstractTensorNetwork) = error("Not implemented")
4040

4141
# Iteration
4242
Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...)
43+
Base.keys(tn::AbstractTensorNetwork) = vertices(tn)
4344

4445
# TODO: This contrasts with the `DataGraphs.AbstractDataGraph` definition,
4546
# where it is defined as the `vertextype`. Does that cause problems or should it be changed?

src/contract_network.jl

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,50 @@
11
using BackendSelection: @Algorithm_str, Algorithm
22
using Base.Broadcast: materialize
3-
using ITensorNetworksNext.LazyNamedDimsArrays: lazy, substitute, symnameddims
3+
using ITensorNetworksNext.LazyNamedDimsArrays: lazy, optimize_evaluation_order, substitute,
4+
symnameddims
45

5-
# This is based on `MatrixAlgebraKit.select_algorithm`.
6+
# This is related to `MatrixAlgebraKit.select_algorithm`.
67
# TODO: Define this in BackendSelection.jl.
7-
function select_algorithm(alg; kwargs...)
8-
if alg isa Algorithm
9-
@assert isempty(kwargs) "Cannot pass keyword arguments when `alg` is an `Algorithm`."
10-
return alg
11-
else
12-
return Algorithm(alg; kwargs...)
13-
end
8+
backend_value(::Algorithm{alg}) where {alg} = alg
9+
using BackendSelection: parameters
10+
function merge_parameters(alg::Algorithm; kwargs...)
11+
return Algorithm(backend_value(alg); merge(parameters(alg), kwargs)...)
1412
end
13+
to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...)
14+
to_algorithm(alg; kwargs...) = Algorithm(alg; kwargs...)
1515

1616
# `contract_network`
1717
contract_network(alg::Algorithm, tn) = error("Not implemented.")
1818
function default_kwargs(::typeof(contract_network), tn)
1919
return (; alg = Algorithm"exact"(; order_alg = Algorithm"eager"()))
2020
end
2121
function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...)
22-
return contract_network(select_algorithm(alg; kwargs...), tn)
22+
return contract_network(to_algorithm(alg; kwargs...), tn)
2323
end
2424

2525
# `contract_network(::Algorithm"exact", ...)`
2626
function contract_network(alg::Algorithm"exact", tn)
2727
order = @something begin
2828
get(alg, :order, nothing)
29-
contraction_order(tn; alg = get(alg, :order_alg, default_kwargs(contraction_order, tn).alg))
29+
contraction_order(
30+
tn; alg = get(alg, :order_alg, default_kwargs(contraction_order, tn).alg)
31+
)
3032
end
31-
syms_to_ts = Dict(symnameddims(i) => lazy(tn[i]) for i in eachindex(tn))
33+
syms_to_ts = Dict(symnameddims(i, Tuple(inds(tn[i]))) => lazy(tn[i]) for i in eachindex(tn))
3234
tn_expression = substitute(order, syms_to_ts)
3335
return materialize(tn_expression)
3436
end
3537

3638
# `contraction_order`
37-
contraction_order(alg::Algorithm, tn) = error("Not implemented.")
39+
function contraction_order end
3840
default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"())
3941
function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...)
40-
return contraction_order(select_algorithm(alg; kwargs...), tn)
42+
return contraction_order(to_algorithm(alg; kwargs...), tn)
4143
end
42-
43-
# `contraction_order(::Algorithm"eager", ...)`
44-
function contraction_order(alg::Algorithm"eager", tn)
45-
return error("Eager not implemented.")
44+
function contraction_order(alg::Algorithm"left_associative", tn)
45+
return prod(i -> symnameddims(i, Tuple(inds(tn[i]))), eachindex(tn))
46+
end
47+
function contraction_order(alg::Algorithm, tn)
48+
s = contraction_order(tn; alg = Algorithm"left_associative"())
49+
return optimize_evaluation_order(s; alg)
4650
end

src/tensornetwork.jl

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
using Combinatorics: combinations
12
using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph
23
using Dictionaries: AbstractDictionary, Indices, dictionary
34
using Graphs: AbstractSimpleGraph
45
using NamedDimsArrays: AbstractNamedDimsArray, dimnames
56
using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype
6-
using NamedGraphs.GraphsExtensions: arranged_edges, vertextype
7+
using NamedGraphs.GraphsExtensions: add_edges!, arrange_edge, arranged_edges, vertextype
78

89
function _TensorNetwork end
910

@@ -18,45 +19,65 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar
1819
return new{V, VD, UG, Tensors}(underlying_graph, tensors)
1920
end
2021
end
22+
# This assumes the tensor connectivity matches the graph structure.
23+
function _TensorNetwork(graph::AbstractGraph, tensors)
24+
return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors)))
25+
end
2126

2227
DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph)
2328
DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors)
2429
function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork})
2530
return fieldtype(type, :underlying_graph)
2631
end
2732

28-
# Determine the graph structure from the tensors.
29-
function TensorNetwork(t::AbstractDictionary)
30-
g = NamedGraph(eachindex(t))
31-
for v1 in vertices(g)
32-
for v2 in vertices(g)
33-
if v1 v2
34-
if !isdisjoint(dimnames(t[v1]), dimnames(t[v2]))
35-
add_edge!(g, v1 => v2)
36-
end
33+
# For a collection of tensors, return the edges implied by shared indices
34+
# as a list of `edgetype` edges of keys/vertices.
35+
function tensornetwork_edges(edgetype::Type, tensors)
36+
# We need to collect the keys since in the case of `tensors::AbstractDictionary`,
37+
# `keys(tensors)::AbstractIndices`, which is indexed by `keys(tensors)` rather
38+
# than `1:length(keys(tensors))`, which is assumed by `combinations`.
39+
verts = collect(keys(tensors))
40+
return filter(
41+
!isnothing, map(combinations(verts, 2)) do (v1, v2)
42+
if !isdisjoint(inds(tensors[v1]), inds(tensors[v2]))
43+
return arrange_edge(edgetype(v1, v2))
3744
end
45+
return nothing
3846
end
39-
end
40-
return _TensorNetwork(g, t)
47+
)
4148
end
42-
function TensorNetwork(tensors::AbstractDict)
43-
return TensorNetwork(Dictionary(tensors))
49+
tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors)
50+
51+
function TensorNetwork(f::Base.Callable, graph::AbstractGraph)
52+
tensors = Dictionary(vertices(graph), f.(vertices(graph)))
53+
return TensorNetwork(graph, tensors)
54+
end
55+
function TensorNetwork(graph::AbstractGraph, tensors)
56+
tn = _TensorNetwork(graph, tensors)
57+
fix_links!(tn)
58+
return tn
4459
end
4560

46-
function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary)
47-
tn = TensorNetwork(tensors)
48-
arranged_edges(tn) arranged_edges(graph) ||
61+
# Insert trivial links for missing edges, and also check
62+
# the vertices and edges are consistent between the graph and tensors.
63+
function fix_links!(tn::AbstractTensorNetwork)
64+
graph = underlying_graph(tn)
65+
tensors = vertex_data(tn)
66+
@assert issetequal(vertices(graph), keys(tensors)) "Graph vertices and tensor keys must match."
67+
tn_edges = tensornetwork_edges(edgetype(graph), tensors)
68+
tn_edges arranged_edges(graph) ||
4969
error("The edges in the tensors do not match the graph structure.")
50-
for e in setdiff(arranged_edges(graph), arranged_edges(tn))
70+
for e in setdiff(arranged_edges(graph), tn_edges)
5171
insert_trivial_link!(tn, e)
5272
end
5373
return tn
5474
end
55-
function TensorNetwork(graph::AbstractGraph, tensors::AbstractDict)
56-
return TensorNetwork(graph, Dictionary(tensors))
57-
end
58-
function TensorNetwork(f, graph::AbstractGraph)
59-
return TensorNetwork(graph, Dict(v => f(v) for v in vertices(graph)))
75+
76+
# Determine the graph structure from the tensors.
77+
function TensorNetwork(tensors)
78+
graph = NamedGraph(keys(tensors))
79+
add_edges!(graph, tensornetwork_edges(tensors))
80+
return _TensorNetwork(graph, tensors)
6081
end
6182

6283
function Base.copy(tn::TensorNetwork)
@@ -65,10 +86,9 @@ end
6586
TensorNetwork(tn::TensorNetwork) = copy(tn)
6687
TensorNetwork{V}(tn::TensorNetwork{V}) where {V} = copy(tn)
6788
function TensorNetwork{V}(tn::TensorNetwork) where {V}
68-
g′ = convert_vertextype(V, underlying_graph(tn))
69-
d = vertex_data(tn)
70-
d′ = dictionary(V(k) => d[k] for k in eachindex(d))
71-
return TensorNetwork(g′, d′)
89+
g = convert_vertextype(V, underlying_graph(tn))
90+
d = dictionary(V(k) => tn[k] for k in keys(d))
91+
return TensorNetwork(g, d)
7292
end
7393

7494
NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn

test/test_contract_network.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ using Graphs: edges
22
using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges
33
using NamedGraphs.NamedGraphGenerators: named_grid
44
using ITensorBase: Index, ITensor
5-
using ITensorNetworksNext:
6-
TensorNetwork, linkinds, siteinds, contract_network
5+
using ITensorNetworksNext: TensorNetwork, linkinds, siteinds, contract_network
76
using TensorOperations: TensorOperations
87
using Test: @test, @testset
98

@@ -15,8 +14,8 @@ using Test: @test, @testset
1514
C = ITensor([5.0, 1.0], j)
1615
D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k)
1716

18-
ABCD_1 = contract_network([A, B, C, D]; alg = "exact", sequence_alg = "leftassociative")
19-
ABCD_2 = contract_network([A, B, C, D]; alg = "exact", sequence_alg = "optimal")
17+
ABCD_1 = contract_network([A, B, C, D]; alg = "exact", order_alg = "leftassociative")
18+
ABCD_2 = contract_network([A, B, C, D]; alg = "exact", order_alg = "optimal")
2019

2120
@test ABCD_1 == ABCD_2
2221
end

test/test_lazynameddimsarrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ using WrappedUnions: unwrap
9595
@test unwrap(a1) isa NamedDimsArray
9696
@test dename(a1) isa SymbolicArray
9797
@test dename(unwrap(a1)) isa SymbolicArray
98-
@test dename(unwrap(a1)) == SymbolicArray(:a1)
98+
@test dename(unwrap(a1)) == SymbolicArray(:a1, ())
9999
@test inds(a1) == ()
100100
@test dimnames(a1) == ()
101101

0 commit comments

Comments
 (0)