@@ -9,19 +9,23 @@ using LinearAlgebra: LinearAlgebra, factorize
99using MacroTools: @capture
1010using NamedDimsArrays: dimnames, inds
1111using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
12- using NamedGraphs. GraphsExtensions: ⊔ , directed_graph, incident_edges, rem_edges!,
13- rename_vertices, vertextype
12+ using NamedGraphs. OrdinalIndexing: OrdinalSuffixedInteger
13+ using NamedGraphs. GraphsExtensions:
14+ ⊔ ,
15+ directed_graph,
16+ incident_edges,
17+ rem_edges!,
18+ rename_vertices,
19+ vertextype
1420using SplitApplyCombine: flatten
21+ using NamedGraphs. SimilarType: similar_type
1522
1623abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end
1724
18- function Graphs. rem_edge! (tn:: AbstractTensorNetwork , e)
19- rem_edge! (underlying_graph (tn), e)
20- return tn
21- end
25+ # Need to be careful about removing edges from tensor networks in case there is a bond
26+ Graphs. rem_edge! (:: AbstractTensorNetwork , edge) = not_implemented ()
2227
23- # TODO : Define a generic fallback for `AbstractDataGraph`?
24- DataGraphs. edge_data_eltype (:: Type{<:AbstractTensorNetwork} ) = error (" No edge data" )
28+ DataGraphs. edge_data_eltype (:: Type{<:AbstractTensorNetwork} ) = not_implemented ()
2529
2630# Graphs.jl overloads
2731function Graphs. weights (graph:: AbstractTensorNetwork )
@@ -36,7 +40,7 @@ function Graphs.weights(graph::AbstractTensorNetwork)
3640end
3741
3842# Copy
39- Base. copy (tn :: AbstractTensorNetwork ) = error ( " Not implemented " )
43+ Base. copy (:: AbstractTensorNetwork ) = not_implemented ( )
4044
4145# Iteration
4246Base. iterate (tn:: AbstractTensorNetwork , args... ) = iterate (vertex_data (tn), args... )
@@ -49,20 +53,11 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn))
4953# Overload if needed
5054Graphs. is_directed (:: Type{<:AbstractTensorNetwork} ) = false
5155
52- # Derived interface, may need to be overloaded
53- function DataGraphs. underlying_graph_type (G:: Type{<:AbstractTensorNetwork} )
54- return underlying_graph_type (data_graph_type (G))
55- end
56-
5756# AbstractDataGraphs overloads
58- function DataGraphs. vertex_data (graph:: AbstractTensorNetwork , args... )
59- return error (" Not implemented" )
60- end
61- function DataGraphs. edge_data (graph:: AbstractTensorNetwork , args... )
62- return error (" Not implemented" )
63- end
57+ DataGraphs. vertex_data (:: AbstractTensorNetwork ) = not_implemented ()
58+ DataGraphs. edge_data (:: AbstractTensorNetwork ) = not_implemented ()
6459
65- DataGraphs. underlying_graph (tn :: AbstractTensorNetwork ) = error ( " Not implemented " )
60+ DataGraphs. underlying_graph (:: AbstractTensorNetwork ) = not_implemented ( )
6661function NamedGraphs. vertex_positions (tn:: AbstractTensorNetwork )
6762 return NamedGraphs. vertex_positions (underlying_graph (tn))
6863end
@@ -81,49 +76,46 @@ function Adapt.adapt_structure(to, tn::AbstractTensorNetwork)
8176 return map_vertex_data_preserve_graph (adapt (to), tn)
8277end
8378
84- function linkinds (tn:: AbstractTensorNetwork , edge:: Pair )
85- return linkinds (tn, edgetype (tn)(edge))
86- end
87- function linkinds (tn:: AbstractTensorNetwork , edge:: AbstractEdge )
88- return inds (tn[src (edge)]) ∩ inds (tn[dst (edge)])
89- end
90- function linkaxes (tn:: AbstractTensorNetwork , edge:: Pair )
79+ linkinds (tn:: AbstractGraph , edge:: Pair ) = linkinds (tn, edgetype (tn)(edge))
80+ linkinds (tn:: AbstractGraph , edge:: AbstractEdge ) = inds (tn[src (edge)]) ∩ inds (tn[dst (edge)])
81+
82+ function linkaxes (tn:: AbstractGraph , edge:: Pair )
9183 return linkaxes (tn, edgetype (tn)(edge))
9284end
93- function linkaxes (tn:: AbstractTensorNetwork , edge:: AbstractEdge )
85+ function linkaxes (tn:: AbstractGraph , edge:: AbstractEdge )
9486 return axes (tn[src (edge)]) ∩ axes (tn[dst (edge)])
9587end
96- function linknames (tn:: AbstractTensorNetwork , edge:: Pair )
88+ function linknames (tn:: AbstractGraph , edge:: Pair )
9789 return linknames (tn, edgetype (tn)(edge))
9890end
99- function linknames (tn:: AbstractTensorNetwork , edge:: AbstractEdge )
91+ function linknames (tn:: AbstractGraph , edge:: AbstractEdge )
10092 return dimnames (tn[src (edge)]) ∩ dimnames (tn[dst (edge)])
10193end
10294
103- function siteinds (tn:: AbstractTensorNetwork , v)
95+ function siteinds (tn:: AbstractGraph , v)
10496 s = inds (tn[v])
10597 for v′ in neighbors (tn, v)
10698 s = setdiff (s, inds (tn[v′]))
10799 end
108100 return s
109101end
110- function siteaxes (tn:: AbstractTensorNetwork , edge:: AbstractEdge )
102+ function siteaxes (tn:: AbstractGraph , edge:: AbstractEdge )
111103 s = axes (tn[src (edge)]) ∩ axes (tn[dst (edge)])
112104 for v′ in neighbors (tn, v)
113105 s = setdiff (s, axes (tn[v′]))
114106 end
115107 return s
116108end
117- function sitenames (tn:: AbstractTensorNetwork , edge:: AbstractEdge )
109+ function sitenames (tn:: AbstractGraph , edge:: AbstractEdge )
118110 s = dimnames (tn[src (edge)]) ∩ dimnames (tn[dst (edge)])
119111 for v′ in neighbors (tn, v)
120112 s = setdiff (s, dimnames (tn[v′]))
121113 end
122114 return s
123115end
124116
125- function setindex_preserve_graph! (tn:: AbstractTensorNetwork , value, vertex)
126- vertex_data (tn)[ vertex] = value
117+ function setindex_preserve_graph! (tn:: AbstractGraph , value, vertex)
118+ set! ( vertex_data (tn), vertex, value)
127119 return tn
128120end
129121
@@ -153,15 +145,15 @@ end
153145
154146# Update the graph of the TensorNetwork `tn` to include
155147# edges that should exist based on the tensor connectivity.
156- function add_missing_edges! (tn:: AbstractTensorNetwork )
148+ function add_missing_edges! (tn:: AbstractGraph )
157149 foreach (v -> add_missing_edges! (tn, v), vertices (tn))
158150 return tn
159151end
160152
161153# Update the graph of the TensorNetwork `tn` to include
162154# edges that should be incident to the vertex `v`
163155# based on the tensor connectivity.
164- function add_missing_edges! (tn:: AbstractTensorNetwork , v)
156+ function add_missing_edges! (tn:: AbstractGraph , v)
165157 for v′ in vertices (tn)
166158 if v ≠ v′
167159 e = v => v′
@@ -175,13 +167,13 @@ end
175167
176168# Fix the edges of the TensorNetwork `tn` to match
177169# the tensor connectivity.
178- function fix_edges! (tn:: AbstractTensorNetwork )
170+ function fix_edges! (tn:: AbstractGraph )
179171 foreach (v -> fix_edges! (tn, v), vertices (tn))
180172 return tn
181173end
182174# Fix the edges of the TensorNetwork `tn` to match
183175# the tensor connectivity at vertex `v`.
184- function fix_edges! (tn:: AbstractTensorNetwork , v)
176+ function fix_edges! (tn:: AbstractGraph , v)
185177 rem_edges! (tn, incident_edges (tn, v))
186178 add_missing_edges! (tn, v)
187179 return tn
@@ -215,28 +207,20 @@ function Base.setindex!(tn::AbstractTensorNetwork, value, v)
215207 fix_edges! (tn, v)
216208 return tn
217209end
218- using NamedGraphs. OrdinalIndexing: OrdinalSuffixedInteger
219210# Fix ambiguity error.
220211function Base. setindex! (graph:: AbstractTensorNetwork , value, vertex:: OrdinalSuffixedInteger )
221212 graph[vertices (graph)[vertex]] = value
222213 return graph
223214end
224- # Fix ambiguity error.
225- function Base. setindex! (tn:: AbstractTensorNetwork , value, edge:: AbstractEdge )
226- return error (" No edge data." )
227- end
228- # Fix ambiguity error.
229- function Base. setindex! (tn:: AbstractTensorNetwork , value, edge:: Pair )
230- return error (" No edge data." )
231- end
232- using NamedGraphs. OrdinalIndexing: OrdinalSuffixedInteger
215+ Base. setindex! (tn:: AbstractTensorNetwork , value, edge:: AbstractEdge ) = not_implemented ()
216+ Base. setindex! (tn:: AbstractTensorNetwork , value, edge:: Pair ) = not_implemented ()
233217# Fix ambiguity error.
234218function Base. setindex! (
235219 tn:: AbstractTensorNetwork ,
236220 value,
237221 edge:: Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger} ,
238222 )
239- return error ( " No edge data. " )
223+ return not_implemented ( )
240224end
241225
242226function Base. show (io:: IO , mime:: MIME"text/plain" , graph:: AbstractTensorNetwork )
@@ -254,4 +238,22 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork)
254238 return nothing
255239end
256240
257- Base. show (io:: IO , graph:: AbstractTensorNetwork ) = show (io, MIME " text/plain" (), graph)
241+ Base. show (io:: IO , graph:: AbstractTensorNetwork ) = show (io, MIME " text/plain" (), graph)
242+
243+ function Graphs. induced_subgraph (graph:: AbstractTensorNetwork , subvertices:: AbstractVector{V} ) where {V <: Int }
244+ return tensornetwork_induced_subgraph (graph, subvertices)
245+ end
246+ function Graphs. induced_subgraph (graph:: AbstractTensorNetwork , subvertices)
247+ return tensornetwork_induced_subgraph (graph, subvertices)
248+ end
249+
250+ function tensornetwork_induced_subgraph (graph, subvertices)
251+ underlying_subgraph, vlist = Graphs. induced_subgraph (underlying_graph (graph), subvertices)
252+ subgraph = similar_type (graph)(underlying_subgraph)
253+ for v in vertices (subgraph)
254+ if isassigned (graph, v)
255+ set! (vertex_data (subgraph), v, graph[v])
256+ end
257+ end
258+ return subgraph, vlist
259+ end
0 commit comments