Skip to content

Commit 8f93ec8

Browse files
author
Jack Dunham
committed
Make abstract tensor network interface more generic.
1 parent 0802355 commit 8f93ec8

File tree

1 file changed

+54
-52
lines changed

1 file changed

+54
-52
lines changed

src/abstracttensornetwork.jl

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,23 @@ using LinearAlgebra: LinearAlgebra, factorize
99
using MacroTools: @capture
1010
using NamedDimsArrays: dimnames, inds
1111
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
12-
using NamedGraphs.GraphsExtensions: , directed_graph, incident_edges, rem_edges!,
13-
rename_vertices, vertextype
12+
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
13+
using NamedGraphs.GraphsExtensions:
14+
,
15+
directed_graph,
16+
incident_edges,
17+
rem_edges!,
18+
rename_vertices,
19+
vertextype
1420
using SplitApplyCombine: flatten
21+
using NamedGraphs.SimilarType: similar_type
1522

1623
abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end
1724

18-
function Graphs.rem_edge!(tn::AbstractTensorNetwork, e)
19-
rem_edge!(underlying_graph(tn), e)
20-
return tn
21-
end
25+
# Need to be careful about removing edges from tensor networks in case there is a bond
26+
Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented()
2227

23-
# TODO: Define a generic fallback for `AbstractDataGraph`?
24-
DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data")
28+
DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = not_implemented()
2529

2630
# Graphs.jl overloads
2731
function Graphs.weights(graph::AbstractTensorNetwork)
@@ -36,7 +40,7 @@ function Graphs.weights(graph::AbstractTensorNetwork)
3640
end
3741

3842
# Copy
39-
Base.copy(tn::AbstractTensorNetwork) = error("Not implemented")
43+
Base.copy(::AbstractTensorNetwork) = not_implemented()
4044

4145
# Iteration
4246
Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...)
@@ -49,20 +53,11 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn))
4953
# Overload if needed
5054
Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false
5155

52-
# Derived interface, may need to be overloaded
53-
function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork})
54-
return underlying_graph_type(data_graph_type(G))
55-
end
56-
5756
# AbstractDataGraphs overloads
58-
function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...)
59-
return error("Not implemented")
60-
end
61-
function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...)
62-
return error("Not implemented")
63-
end
57+
DataGraphs.vertex_data(::AbstractTensorNetwork) = not_implemented()
58+
DataGraphs.edge_data(::AbstractTensorNetwork) = not_implemented()
6459

65-
DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented")
60+
DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented()
6661
function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork)
6762
return NamedGraphs.vertex_positions(underlying_graph(tn))
6863
end
@@ -81,49 +76,46 @@ function Adapt.adapt_structure(to, tn::AbstractTensorNetwork)
8176
return map_vertex_data_preserve_graph(adapt(to), tn)
8277
end
8378

84-
function linkinds(tn::AbstractTensorNetwork, edge::Pair)
85-
return linkinds(tn, edgetype(tn)(edge))
86-
end
87-
function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge)
88-
return inds(tn[src(edge)]) inds(tn[dst(edge)])
89-
end
90-
function linkaxes(tn::AbstractTensorNetwork, edge::Pair)
79+
linkinds(tn::AbstractGraph, edge::Pair) = linkinds(tn, edgetype(tn)(edge))
80+
linkinds(tn::AbstractGraph, edge::AbstractEdge) = inds(tn[src(edge)]) inds(tn[dst(edge)])
81+
82+
function linkaxes(tn::AbstractGraph, edge::Pair)
9183
return linkaxes(tn, edgetype(tn)(edge))
9284
end
93-
function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
85+
function linkaxes(tn::AbstractGraph, edge::AbstractEdge)
9486
return axes(tn[src(edge)]) axes(tn[dst(edge)])
9587
end
96-
function linknames(tn::AbstractTensorNetwork, edge::Pair)
88+
function linknames(tn::AbstractGraph, edge::Pair)
9789
return linknames(tn, edgetype(tn)(edge))
9890
end
99-
function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge)
91+
function linknames(tn::AbstractGraph, edge::AbstractEdge)
10092
return dimnames(tn[src(edge)]) dimnames(tn[dst(edge)])
10193
end
10294

103-
function siteinds(tn::AbstractTensorNetwork, v)
95+
function siteinds(tn::AbstractGraph, v)
10496
s = inds(tn[v])
10597
for v′ in neighbors(tn, v)
10698
s = setdiff(s, inds(tn[v′]))
10799
end
108100
return s
109101
end
110-
function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
102+
function siteaxes(tn::AbstractGraph, edge::AbstractEdge)
111103
s = axes(tn[src(edge)]) axes(tn[dst(edge)])
112104
for v′ in neighbors(tn, v)
113105
s = setdiff(s, axes(tn[v′]))
114106
end
115107
return s
116108
end
117-
function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge)
109+
function sitenames(tn::AbstractGraph, edge::AbstractEdge)
118110
s = dimnames(tn[src(edge)]) dimnames(tn[dst(edge)])
119111
for v′ in neighbors(tn, v)
120112
s = setdiff(s, dimnames(tn[v′]))
121113
end
122114
return s
123115
end
124116

125-
function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex)
126-
vertex_data(tn)[vertex] = value
117+
function setindex_preserve_graph!(tn::AbstractGraph, value, vertex)
118+
set!(vertex_data(tn), vertex, value)
127119
return tn
128120
end
129121

@@ -153,15 +145,15 @@ end
153145

154146
# Update the graph of the TensorNetwork `tn` to include
155147
# edges that should exist based on the tensor connectivity.
156-
function add_missing_edges!(tn::AbstractTensorNetwork)
148+
function add_missing_edges!(tn::AbstractGraph)
157149
foreach(v -> add_missing_edges!(tn, v), vertices(tn))
158150
return tn
159151
end
160152

161153
# Update the graph of the TensorNetwork `tn` to include
162154
# edges that should be incident to the vertex `v`
163155
# based on the tensor connectivity.
164-
function add_missing_edges!(tn::AbstractTensorNetwork, v)
156+
function add_missing_edges!(tn::AbstractGraph, v)
165157
for v′ in vertices(tn)
166158
if v v′
167159
e = v => v′
@@ -175,13 +167,13 @@ end
175167

176168
# Fix the edges of the TensorNetwork `tn` to match
177169
# the tensor connectivity.
178-
function fix_edges!(tn::AbstractTensorNetwork)
170+
function fix_edges!(tn::AbstractGraph)
179171
foreach(v -> fix_edges!(tn, v), vertices(tn))
180172
return tn
181173
end
182174
# Fix the edges of the TensorNetwork `tn` to match
183175
# the tensor connectivity at vertex `v`.
184-
function fix_edges!(tn::AbstractTensorNetwork, v)
176+
function fix_edges!(tn::AbstractGraph, v)
185177
rem_edges!(tn, incident_edges(tn, v))
186178
add_missing_edges!(tn, v)
187179
return tn
@@ -215,28 +207,20 @@ function Base.setindex!(tn::AbstractTensorNetwork, value, v)
215207
fix_edges!(tn, v)
216208
return tn
217209
end
218-
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
219210
# Fix ambiguity error.
220211
function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger)
221212
graph[vertices(graph)[vertex]] = value
222213
return graph
223214
end
224-
# Fix ambiguity error.
225-
function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge)
226-
return error("No edge data.")
227-
end
228-
# Fix ambiguity error.
229-
function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair)
230-
return error("No edge data.")
231-
end
232-
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
215+
Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) = not_implemented()
216+
Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented()
233217
# Fix ambiguity error.
234218
function Base.setindex!(
235219
tn::AbstractTensorNetwork,
236220
value,
237221
edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger},
238222
)
239-
return error("No edge data.")
223+
return not_implemented()
240224
end
241225

242226
function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork)
@@ -254,4 +238,22 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork)
254238
return nothing
255239
end
256240

257-
Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph)
241+
Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph)
242+
243+
function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices::AbstractVector{V}) where {V <: Int}
244+
return tensornetwork_induced_subgraph(graph, subvertices)
245+
end
246+
function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices)
247+
return tensornetwork_induced_subgraph(graph, subvertices)
248+
end
249+
250+
function tensornetwork_induced_subgraph(graph, subvertices)
251+
underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices)
252+
subgraph = similar_type(graph)(underlying_subgraph)
253+
for v in vertices(subgraph)
254+
if isassigned(graph, v)
255+
set!(vertex_data(subgraph), v, graph[v])
256+
end
257+
end
258+
return subgraph, vlist
259+
end

0 commit comments

Comments
 (0)