1+ using Combinatorics: combinations
12using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph
23using Dictionaries: AbstractDictionary, Indices, dictionary
34using Graphs: AbstractSimpleGraph
45using NamedDimsArrays: AbstractNamedDimsArray, dimnames
56using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype
6- using NamedGraphs. GraphsExtensions: arranged_edges, vertextype
7+ using NamedGraphs. GraphsExtensions: add_edges!, arrange_edge, arranged_edges, vertextype
78
89function _TensorNetwork end
910
@@ -18,45 +19,65 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar
1819 return new {V, VD, UG, Tensors} (underlying_graph, tensors)
1920 end
2021end
22+ # This assumes the tensor connectivity matches the graph structure.
23+ function _TensorNetwork (graph:: AbstractGraph , tensors)
24+ return _TensorNetwork (graph, Dictionary (keys (tensors), values (tensors)))
25+ end
2126
2227DataGraphs. underlying_graph (tn:: TensorNetwork ) = getfield (tn, :underlying_graph )
2328DataGraphs. vertex_data (tn:: TensorNetwork ) = getfield (tn, :tensors )
2429function DataGraphs. underlying_graph_type (type:: Type{<:TensorNetwork} )
2530 return fieldtype (type, :underlying_graph )
2631end
2732
28- # Determine the graph structure from the tensors.
29- function TensorNetwork (t:: AbstractDictionary )
30- g = NamedGraph (eachindex (t))
31- for v1 in vertices (g)
32- for v2 in vertices (g)
33- if v1 ≠ v2
34- if ! isdisjoint (dimnames (t[v1]), dimnames (t[v2]))
35- add_edge! (g, v1 => v2)
36- end
33+ # For a collection of tensors, return the edges implied by shared indices
34+ # as a list of `edgetype` edges of keys/vertices.
35+ function tensornetwork_edges (edgetype:: Type , tensors)
36+ # We need to collect the keys since in the case of `tensors::AbstractDictionary`,
37+ # `keys(tensors)::AbstractIndices`, which is indexed by `keys(tensors)` rather
38+ # than `1:length(keys(tensors))`, which is assumed by `combinations`.
39+ verts = collect (keys (tensors))
40+ return filter (
41+ ! isnothing, map (combinations (verts, 2 )) do (v1, v2)
42+ if ! isdisjoint (inds (tensors[v1]), inds (tensors[v2]))
43+ return arrange_edge (edgetype (v1, v2))
3744 end
45+ return nothing
3846 end
39- end
40- return _TensorNetwork (g, t)
47+ )
4148end
42- function TensorNetwork (tensors:: AbstractDict )
43- return TensorNetwork (Dictionary (tensors))
49+ tensornetwork_edges (tensors) = tensornetwork_edges (NamedEdge, tensors)
50+
51+ function TensorNetwork (f:: Base.Callable , graph:: AbstractGraph )
52+ tensors = Dictionary (vertices (graph), f .(vertices (graph)))
53+ return TensorNetwork (graph, tensors)
54+ end
55+ function TensorNetwork (graph:: AbstractGraph , tensors)
56+ tn = _TensorNetwork (graph, tensors)
57+ fix_links! (tn)
58+ return tn
4459end
4560
46- function TensorNetwork (graph:: AbstractGraph , tensors:: AbstractDictionary )
47- tn = TensorNetwork (tensors)
48- arranged_edges (tn) ⊆ arranged_edges (graph) ||
61+ # Insert trivial links for missing edges, and also check
62+ # the vertices and edges are consistent between the graph and tensors.
63+ function fix_links! (tn:: AbstractTensorNetwork )
64+ graph = underlying_graph (tn)
65+ tensors = vertex_data (tn)
66+ @assert issetequal (vertices (graph), keys (tensors)) " Graph vertices and tensor keys must match."
67+ tn_edges = tensornetwork_edges (edgetype (graph), tensors)
68+ tn_edges ⊆ arranged_edges (graph) ||
4969 error (" The edges in the tensors do not match the graph structure." )
50- for e in setdiff (arranged_edges (graph), arranged_edges (tn) )
70+ for e in setdiff (arranged_edges (graph), tn_edges )
5171 insert_trivial_link! (tn, e)
5272 end
5373 return tn
5474end
55- function TensorNetwork (graph:: AbstractGraph , tensors:: AbstractDict )
56- return TensorNetwork (graph, Dictionary (tensors))
57- end
58- function TensorNetwork (f, graph:: AbstractGraph )
59- return TensorNetwork (graph, Dict (v => f (v) for v in vertices (graph)))
75+
76+ # Determine the graph structure from the tensors.
77+ function TensorNetwork (tensors)
78+ graph = NamedGraph (keys (tensors))
79+ add_edges! (graph, tensornetwork_edges (tensors))
80+ return _TensorNetwork (graph, tensors)
6081end
6182
6283function Base. copy (tn:: TensorNetwork )
6586TensorNetwork (tn:: TensorNetwork ) = copy (tn)
6687TensorNetwork {V} (tn:: TensorNetwork{V} ) where {V} = copy (tn)
6788function TensorNetwork {V} (tn:: TensorNetwork ) where {V}
68- g′ = convert_vertextype (V, underlying_graph (tn))
69- d = vertex_data (tn)
70- d′ = dictionary (V (k) => d[k] for k in eachindex (d))
71- return TensorNetwork (g′, d′)
89+ g = convert_vertextype (V, underlying_graph (tn))
90+ d = dictionary (V (k) => tn[k] for k in keys (d))
91+ return TensorNetwork (g, d)
7292end
7393
7494NamedGraphs. convert_vertextype (:: Type{V} , tn:: TensorNetwork{V} ) where {V} = tn
0 commit comments