Skip to content

Commit 0802355

Browse files
author
Jack Dunham
committed
Add DataGraphsPartitionedGraphsExt glue for TensorNetwork type
Also includes some fixes to the way `TensorNetwork` types are constructed based on index structure.
1 parent 4350008 commit 0802355

File tree

1 file changed

+76
-3
lines changed

1 file changed

+76
-3
lines changed

src/tensornetwork.jl

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
using Combinatorics: combinations
22
using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph
33
using Dictionaries: AbstractDictionary, Indices, dictionary
4-
using Graphs: AbstractSimpleGraph
4+
using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge!
55
using NamedDimsArrays: AbstractNamedDimsArray, dimnames
66
using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype
7-
using NamedGraphs.GraphsExtensions: add_edges!, arrange_edge, arranged_edges, vertextype
7+
using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, vertextype
8+
using NamedGraphs.PartitionedGraphs:
9+
AbstractPartitionedGraph,
10+
PartitionedGraphs,
11+
departition,
12+
partitioned_vertices,
13+
partitionedgraph,
14+
quotient_graph,
15+
quotient_graph_type
16+
using .LazyNamedDimsArrays: lazy, Mul
17+
using DataGraphs: vertex_data_eltype, vertex_data, edge_data
18+
using DataGraphs.DataGraphsPartitionedGraphsExt
819

920
function _TensorNetwork end
1021

@@ -24,8 +35,14 @@ function _TensorNetwork(graph::AbstractGraph, tensors)
2435
return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors)))
2536
end
2637

38+
function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors}
39+
return _TensorNetwork(graph, Tensors())
40+
end
41+
2742
DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph)
2843
DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors)
44+
DataGraphs.edge_data(tn::TensorNetwork) = Dictionary{edgetype(tn), Nothing}()
45+
DataGraphs.vertex_data_eltype(T::Type{<:TensorNetwork}) = eltype(fieldtype(T, :tensors))
2946
function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork})
3047
return fieldtype(type, :underlying_graph)
3148
end
@@ -70,7 +87,10 @@ function fix_links!(tn::AbstractTensorNetwork)
7087
for e in setdiff(arranged_edges(graph), tn_edges)
7188
insert_trivial_link!(tn, e)
7289
end
73-
return tn
90+
for edge in setdiff(arranged_edges(graph), arranged_edges(graph_structure))
91+
insert_trivial_link!(network, edge)
92+
end
93+
return network
7494
end
7595

7696
# Determine the graph structure from the tensors.
@@ -93,3 +113,56 @@ end
93113

94114
NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn
95115
NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn)
116+
117+
Graphs.connected_components(tn::TensorNetwork) = Graphs.connected_components(underlying_graph(tn))
118+
119+
function Graphs.rem_edge!(tn::TensorNetwork, e)
120+
if !has_edge(underlying_graph(tn), e)
121+
return false
122+
end
123+
if !isempty(linkinds(tn, e))
124+
throw(ArgumentError("cannot remove edge $e due to tensor indices existing on this edge."))
125+
end
126+
rem_edge!(underlying_graph(tn), e)
127+
return true
128+
end
129+
130+
function GraphsExtensions.graph_from_vertices(type::Type{<:TensorNetwork}, vertices)
131+
DT = fieldtype(type, :tensors)
132+
empty_dict = DT()
133+
return TensorNetwork(similar_graph(underlying_graph_type(type), vertices), empty_dict)
134+
end
135+
136+
## PartitionedGraphs
137+
function PartitionedGraphs.quotient_graph(tn::TensorNetwork)
138+
ug = quotient_graph(underlying_graph(tn))
139+
return TensorNetwork(ug, vertex_data(QuotientView(tn)))
140+
end
141+
function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork})
142+
UG = quotient_graph_type(underlying_graph_type(type))
143+
VD = Vector{vertex_data_eltype(type)}
144+
V = vertextype(UG)
145+
return TensorNetwork{V, VD, UG, Dictionary{V, VD}}
146+
end
147+
148+
function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts)
149+
pg = partitionedgraph(underlying_graph(tn), parts)
150+
return TensorNetwork(pg, vertex_data(tn))
151+
end
152+
153+
PartitionedGraphs.departition(tn::TensorNetwork) = tn
154+
function PartitionedGraphs.departition(
155+
tn::TensorNetwork{<:Any, <:Any, <:AbstractPartitionedGraph}
156+
)
157+
return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn))
158+
end
159+
160+
function DataGraphsPartitionedGraphsExt.to_quotient_vertex_data(::TensorNetwork, data)
161+
return mapreduce(lazy, *, collect(last(data)))
162+
end
163+
164+
function PartitionedGraphs.quotientview(tn::TensorNetwork)
165+
qview = QuotientView(underlying_graph(tn))
166+
tensors = vertex_data(QuotientView(tn))
167+
return TensorNetwork(qview, tensors)
168+
end

0 commit comments

Comments
 (0)