Skip to content

Commit d5ea09d

Browse files
authored
TensorNetwork type (#1)
1 parent eb06f66 commit d5ea09d

File tree

6 files changed

+456
-7
lines changed

6 files changed

+456
-7
lines changed

Project.toml

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,31 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.0"
4+
version = "0.1.1"
5+
6+
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
9+
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
10+
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
11+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
12+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
13+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
14+
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
15+
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
16+
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
17+
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
518

619
[compat]
20+
Adapt = "4.3.0"
21+
BackendSelection = "0.1.6"
22+
DataGraphs = "0.2.7"
23+
Dictionaries = "0.4.5"
24+
Graphs = "1.13.1"
25+
LinearAlgebra = "1.10"
26+
MacroTools = "0.5.16"
27+
NamedDimsArrays = "0.7.13"
28+
NamedGraphs = "0.6.9"
29+
SimpleTraits = "0.9.5"
30+
SplitApplyCombine = "1.2.3"
731
julia = "1.10"

src/ITensorNetworksNext.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ITensorNetworksNext
22

3-
# Write your package code here.
3+
include("abstracttensornetwork.jl")
4+
include("tensornetwork.jl")
45

56
end

src/abstracttensornetwork.jl

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
using Adapt: Adapt, adapt, adapt_structure
2+
using BackendSelection: @Algorithm_str, Algorithm
3+
using DataGraphs:
4+
DataGraphs,
5+
AbstractDataGraph,
6+
edge_data,
7+
underlying_graph,
8+
underlying_graph_type,
9+
vertex_data
10+
using Dictionaries: Dictionary
11+
using Graphs:
12+
Graphs,
13+
AbstractEdge,
14+
AbstractGraph,
15+
Graph,
16+
add_edge!,
17+
add_vertex!,
18+
bfs_tree,
19+
center,
20+
dst,
21+
edges,
22+
edgetype,
23+
ne,
24+
neighbors,
25+
nv,
26+
rem_edge!,
27+
src,
28+
vertices
29+
using LinearAlgebra: LinearAlgebra, factorize
30+
using MacroTools: @capture
31+
using NamedDimsArrays: dimnames
32+
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
33+
using NamedGraphs.GraphsExtensions:
34+
, directed_graph, incident_edges, rem_edges!, rename_vertices, vertextype
35+
using SplitApplyCombine: flatten
36+
37+
abstract type AbstractTensorNetwork{V,VD} <: AbstractDataGraph{V,VD,Nothing} end
38+
39+
function Graphs.rem_edge!(tn::AbstractTensorNetwork, e)
40+
rem_edge!(underlying_graph(tn), e)
41+
return tn
42+
end
43+
44+
# TODO: Define a generic fallback for `AbstractDataGraph`?
45+
DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data")
46+
47+
# Graphs.jl overloads
48+
function Graphs.weights(graph::AbstractTensorNetwork)
49+
V = vertextype(graph)
50+
es = Tuple.(edges(graph))
51+
ws = Dictionary{Tuple{V,V},Float64}(es, undef)
52+
for e in edges(graph)
53+
w = log2(dim(commoninds(graph, e)))
54+
ws[(src(e), dst(e))] = w
55+
end
56+
return ws
57+
end
58+
59+
# Copy
60+
Base.copy(tn::AbstractTensorNetwork) = error("Not implemented")
61+
62+
# Iteration
63+
Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...)
64+
65+
# TODO: This contrasts with the `DataGraphs.AbstractDataGraph` definition,
66+
# where it is defined as the `vertextype`. Does that cause problems or should it be changed?
67+
Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn))
68+
69+
# Overload if needed
70+
Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false
71+
72+
# Derived interface, may need to be overloaded
73+
function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork})
74+
return underlying_graph_type(data_graph_type(G))
75+
end
76+
77+
# AbstractDataGraphs overloads
78+
function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...)
79+
return error("Not implemented")
80+
end
81+
function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...)
82+
return error("Not implemented")
83+
end
84+
85+
DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented")
86+
function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork)
87+
return NamedGraphs.vertex_positions(underlying_graph(tn))
88+
end
89+
function NamedGraphs.ordered_vertices(tn::AbstractTensorNetwork)
90+
return NamedGraphs.ordered_vertices(underlying_graph(tn))
91+
end
92+
93+
function Adapt.adapt_structure(to, tn::AbstractTensorNetwork)
94+
# TODO: Define and use:
95+
#
96+
# @preserve_graph map_vertex_data(adapt(to), tn)
97+
#
98+
# or just:
99+
#
100+
# @preserve_graph map(adapt(to), tn)
101+
return map_vertex_data_preserve_graph(adapt(to), tn)
102+
end
103+
104+
function linkinds(tn::AbstractTensorNetwork, edge::Pair)
105+
return linkinds(tn, edgetype(tn)(edge))
106+
end
107+
function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge)
108+
return nameddimsindices(tn[src(edge)]) nameddimsindices(tn[dst(edge)])
109+
end
110+
function linkaxes(tn::AbstractTensorNetwork, edge::Pair)
111+
return linkaxes(tn, edgetype(tn)(edge))
112+
end
113+
function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
114+
return axes(tn[src(edge)]) axes(tn[dst(edge)])
115+
end
116+
function linknames(tn::AbstractTensorNetwork, edge::Pair)
117+
return linknames(tn, edgetype(tn)(edge))
118+
end
119+
function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge)
120+
return dimnames(tn[src(edge)]) dimnames(tn[dst(edge)])
121+
end
122+
123+
function siteinds(tn::AbstractTensorNetwork, v)
124+
s = nameddimsindices(tn[v])
125+
for v′ in neighbors(tn, v)
126+
s = setdiff(s, nameddimsindices(tn[v′]))
127+
end
128+
return s
129+
end
130+
function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
131+
s = axes(tn[src(edge)]) axes(tn[dst(edge)])
132+
for v′ in neighbors(tn, v)
133+
s = setdiff(s, axes(tn[v′]))
134+
end
135+
return s
136+
end
137+
function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge)
138+
s = dimnames(tn[src(edge)]) dimnames(tn[dst(edge)])
139+
for v′ in neighbors(tn, v)
140+
s = setdiff(s, dimnames(tn[v′]))
141+
end
142+
return s
143+
end
144+
145+
function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex)
146+
vertex_data(tn)[vertex] = value
147+
return tn
148+
end
149+
150+
# TODO: Move to `BaseExtensions` module.
151+
function is_setindex!_expr(expr::Expr)
152+
return is_assignment_expr(expr) && is_getindex_expr(first(expr.args))
153+
end
154+
is_setindex!_expr(x) = false
155+
is_getindex_expr(expr::Expr) = (expr.head === :ref)
156+
is_getindex_expr(x) = false
157+
is_assignment_expr(expr::Expr) = (expr.head === :(=))
158+
is_assignment_expr(expr) = false
159+
160+
# TODO: Define this in terms of a function mapping
161+
# preserve_graph_function(::typeof(setindex!)) = setindex!_preserve_graph
162+
# preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph
163+
# Also allow annotating codeblocks like `@views`.
164+
macro preserve_graph(expr)
165+
if !is_setindex!_expr(expr)
166+
error(
167+
"preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)",
168+
)
169+
end
170+
@capture(expr, array_[indices__] = value_)
171+
return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...)))
172+
end
173+
174+
# Update the graph of the TensorNetwork `tn` to include
175+
# edges that should exist based on the tensor connectivity.
176+
function add_missing_edges!(tn::AbstractTensorNetwork)
177+
foreach(v -> add_missing_edges!(tn, v), vertices(tn))
178+
return tn
179+
end
180+
181+
# Update the graph of the TensorNetwork `tn` to include
182+
# edges that should be incident to the vertex `v`
183+
# based on the tensor connectivity.
184+
function add_missing_edges!(tn::AbstractTensorNetwork, v)
185+
for v′ in vertices(tn)
186+
if v v′
187+
e = v => v′
188+
if !isempty(linkinds(tn, e))
189+
add_edge!(tn, e)
190+
end
191+
end
192+
end
193+
return tn
194+
end
195+
196+
# Fix the edges of the TensorNetwork `tn` to match
197+
# the tensor connectivity.
198+
function fix_edges!(tn::AbstractTensorNetwork)
199+
foreach(v -> fix_edges!(tn, v), vertices(tn))
200+
return tn
201+
end
202+
# Fix the edges of the TensorNetwork `tn` to match
203+
# the tensor connectivity at vertex `v`.
204+
function fix_edges!(tn::AbstractTensorNetwork, v)
205+
rem_incident_edges!(tn, v)
206+
rem_edges!(tn, incident_edges(tn, v))
207+
add_missing_edges!(tn, v)
208+
return tn
209+
end
210+
211+
# Customization point.
212+
using NamedDimsArrays: AbstractNamedUnitRange, namedunitrange, nametype, randname
213+
function trivial_unitrange(type::Type{<:AbstractUnitRange})
214+
return Base.oneto(one(eltype(type)))
215+
end
216+
function rand_trivial_namedunitrange(
217+
::Type{<:AbstractNamedUnitRange{<:Any,R,N}}
218+
) where {R,N}
219+
return namedunitrange(trivial_unitrange(R), randname(N))
220+
end
221+
222+
dag(x) = x
223+
224+
using NamedDimsArrays: nameddimsindices
225+
function insert_trivial_link!(tn, e)
226+
add_edge!(tn, e)
227+
l = rand_trivial_namedunitrange(eltype(nameddimsindices(tn[src(e)])))
228+
x = similar(tn[src(e)], (l,))
229+
x[1] = 1
230+
@preserve_graph tn[src(e)] = tn[src(e)] * x
231+
@preserve_graph tn[dst(e)] = tn[dst(e)] * dag(x)
232+
return tn
233+
end
234+
235+
function Base.setindex!(tn::AbstractTensorNetwork, value, v)
236+
@preserve_graph tn[v] = value
237+
fix_edges!(tn, v)
238+
return tn
239+
end
240+
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
241+
# Fix ambiguity error.
242+
function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger)
243+
graph[vertices(graph)[vertex]] = value
244+
return graph
245+
end
246+
# Fix ambiguity error.
247+
function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge)
248+
return error("No edge data.")
249+
end
250+
# Fix ambiguity error.
251+
function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair)
252+
return error("No edge data.")
253+
end
254+
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
255+
# Fix ambiguity error.
256+
function Base.setindex!(
257+
tn::AbstractTensorNetwork,
258+
value,
259+
edge::Pair{<:OrdinalSuffixedInteger,<:OrdinalSuffixedInteger},
260+
)
261+
return error("No edge data.")
262+
end
263+
264+
function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork)
265+
println(io, "$(typeof(graph)) with $(nv(graph)) vertices:")
266+
show(io, mime, vertices(graph))
267+
println(io, "\n")
268+
println(io, "and $(ne(graph)) edge(s):")
269+
for e in edges(graph)
270+
show(io, mime, e)
271+
println(io)
272+
end
273+
println(io)
274+
println(io, "with vertex data:")
275+
show(io, mime, axes.(vertex_data(graph)))
276+
return nothing
277+
end
278+
279+
Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph)

src/tensornetwork.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph
2+
using Dictionaries: AbstractDictionary, Indices, dictionary
3+
using Graphs: AbstractSimpleGraph
4+
using NamedDimsArrays: AbstractNamedDimsArray, dimnames, nameddimsarray
5+
using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype
6+
using NamedGraphs.GraphsExtensions: arranged_edges, vertextype
7+
8+
function _TensorNetwork end
9+
10+
struct TensorNetwork{V,VD,UG<:AbstractGraph{V},Tensors<:AbstractDictionary{V,VD}} <:
11+
AbstractTensorNetwork{V,VD}
12+
underlying_graph::UG
13+
tensors::Tensors
14+
global @inline function _TensorNetwork(
15+
underlying_graph::UG, tensors::Tensors
16+
) where {V,VD,UG<:AbstractGraph{V},Tensors<:AbstractDictionary{V,VD}}
17+
# This assumes the tensor connectivity matches the graph structure.
18+
return new{V,VD,UG,Tensors}(underlying_graph, tensors)
19+
end
20+
end
21+
22+
DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph)
23+
DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors)
24+
function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork})
25+
return fieldtype(type, :underlying_graph)
26+
end
27+
28+
# Determine the graph structure from the tensors.
29+
function TensorNetwork(t::AbstractDictionary)
30+
g = NamedGraph(eachindex(t))
31+
for v1 in vertices(g)
32+
for v2 in vertices(g)
33+
if v1 v2
34+
if !isdisjoint(dimnames(t[v1]), dimnames(t[v2]))
35+
add_edge!(g, v1 => v2)
36+
end
37+
end
38+
end
39+
end
40+
return _TensorNetwork(g, t)
41+
end
42+
function TensorNetwork(tensors::AbstractDict)
43+
return TensorNetwork(Dictionary(tensors))
44+
end
45+
46+
function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary)
47+
tn = TensorNetwork(tensors)
48+
arranged_edges(tn) arranged_edges(graph) ||
49+
error("The edges in the tensors do not match the graph structure.")
50+
for e in setdiff(arranged_edges(graph), arranged_edges(tn))
51+
insert_trivial_link!(tn, e)
52+
end
53+
return tn
54+
end
55+
function TensorNetwork(graph::AbstractGraph, tensors::AbstractDict)
56+
return TensorNetwork(graph, Dictionary(tensors))
57+
end
58+
function TensorNetwork(f, graph::AbstractGraph)
59+
return TensorNetwork(graph, Dict(v => f(v) for v in vertices(graph)))
60+
end
61+
62+
function Base.copy(tn::TensorNetwork)
63+
TensorNetwork(copy(underlying_graph(tn)), copy(vertex_data(tn)))
64+
end
65+
TensorNetwork(tn::TensorNetwork) = copy(tn)
66+
TensorNetwork{V}(tn::TensorNetwork{V}) where {V} = copy(tn)
67+
function TensorNetwork{V}(tn::TensorNetwork) where {V}
68+
g′ = convert_vertextype(V, underlying_graph(tn))
69+
d = vertex_data(tn)
70+
d′ = dictionary(V(k) => d[k] for k in eachindex(d))
71+
return TensorNetwork(g′, d′)
72+
end
73+
74+
NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn
75+
NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn)

0 commit comments

Comments
 (0)