|
| 1 | +using Adapt: Adapt, adapt, adapt_structure |
| 2 | +using BackendSelection: @Algorithm_str, Algorithm |
| 3 | +using DataGraphs: |
| 4 | + DataGraphs, |
| 5 | + AbstractDataGraph, |
| 6 | + edge_data, |
| 7 | + underlying_graph, |
| 8 | + underlying_graph_type, |
| 9 | + vertex_data |
| 10 | +using Dictionaries: Dictionary |
| 11 | +using Graphs: |
| 12 | + Graphs, |
| 13 | + AbstractEdge, |
| 14 | + AbstractGraph, |
| 15 | + Graph, |
| 16 | + add_edge!, |
| 17 | + add_vertex!, |
| 18 | + bfs_tree, |
| 19 | + center, |
| 20 | + dst, |
| 21 | + edges, |
| 22 | + edgetype, |
| 23 | + ne, |
| 24 | + neighbors, |
| 25 | + nv, |
| 26 | + rem_edge!, |
| 27 | + src, |
| 28 | + vertices |
| 29 | +using LinearAlgebra: LinearAlgebra, factorize |
| 30 | +using MacroTools: @capture |
| 31 | +using NamedDimsArrays: dimnames |
| 32 | +using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree |
| 33 | +using NamedGraphs.GraphsExtensions: |
| 34 | + ⊔, directed_graph, incident_edges, rem_edges!, rename_vertices, vertextype |
| 35 | +using SplitApplyCombine: flatten |
| 36 | + |
| 37 | +abstract type AbstractTensorNetwork{V,VD} <: AbstractDataGraph{V,VD,Nothing} end |
| 38 | + |
| 39 | +function Graphs.rem_edge!(tn::AbstractTensorNetwork, e) |
| 40 | + rem_edge!(underlying_graph(tn), e) |
| 41 | + return tn |
| 42 | +end |
| 43 | + |
| 44 | +# TODO: Define a generic fallback for `AbstractDataGraph`? |
| 45 | +DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data") |
| 46 | + |
| 47 | +# Graphs.jl overloads |
| 48 | +function Graphs.weights(graph::AbstractTensorNetwork) |
| 49 | + V = vertextype(graph) |
| 50 | + es = Tuple.(edges(graph)) |
| 51 | + ws = Dictionary{Tuple{V,V},Float64}(es, undef) |
| 52 | + for e in edges(graph) |
| 53 | + w = log2(dim(commoninds(graph, e))) |
| 54 | + ws[(src(e), dst(e))] = w |
| 55 | + end |
| 56 | + return ws |
| 57 | +end |
| 58 | + |
| 59 | +# Copy |
| 60 | +Base.copy(tn::AbstractTensorNetwork) = error("Not implemented") |
| 61 | + |
| 62 | +# Iteration |
| 63 | +Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) |
| 64 | + |
| 65 | +# TODO: This contrasts with the `DataGraphs.AbstractDataGraph` definition, |
| 66 | +# where it is defined as the `vertextype`. Does that cause problems or should it be changed? |
| 67 | +Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) |
| 68 | + |
| 69 | +# Overload if needed |
| 70 | +Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false |
| 71 | + |
| 72 | +# Derived interface, may need to be overloaded |
| 73 | +function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork}) |
| 74 | + return underlying_graph_type(data_graph_type(G)) |
| 75 | +end |
| 76 | + |
| 77 | +# AbstractDataGraphs overloads |
| 78 | +function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...) |
| 79 | + return error("Not implemented") |
| 80 | +end |
| 81 | +function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...) |
| 82 | + return error("Not implemented") |
| 83 | +end |
| 84 | + |
| 85 | +DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented") |
| 86 | +function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) |
| 87 | + return NamedGraphs.vertex_positions(underlying_graph(tn)) |
| 88 | +end |
| 89 | +function NamedGraphs.ordered_vertices(tn::AbstractTensorNetwork) |
| 90 | + return NamedGraphs.ordered_vertices(underlying_graph(tn)) |
| 91 | +end |
| 92 | + |
| 93 | +function Adapt.adapt_structure(to, tn::AbstractTensorNetwork) |
| 94 | + # TODO: Define and use: |
| 95 | + # |
| 96 | + # @preserve_graph map_vertex_data(adapt(to), tn) |
| 97 | + # |
| 98 | + # or just: |
| 99 | + # |
| 100 | + # @preserve_graph map(adapt(to), tn) |
| 101 | + return map_vertex_data_preserve_graph(adapt(to), tn) |
| 102 | +end |
| 103 | + |
| 104 | +function linkinds(tn::AbstractTensorNetwork, edge::Pair) |
| 105 | + return linkinds(tn, edgetype(tn)(edge)) |
| 106 | +end |
| 107 | +function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge) |
| 108 | + return nameddimsindices(tn[src(edge)]) ∩ nameddimsindices(tn[dst(edge)]) |
| 109 | +end |
| 110 | +function linkaxes(tn::AbstractTensorNetwork, edge::Pair) |
| 111 | + return linkaxes(tn, edgetype(tn)(edge)) |
| 112 | +end |
| 113 | +function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) |
| 114 | + return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) |
| 115 | +end |
| 116 | +function linknames(tn::AbstractTensorNetwork, edge::Pair) |
| 117 | + return linknames(tn, edgetype(tn)(edge)) |
| 118 | +end |
| 119 | +function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge) |
| 120 | + return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) |
| 121 | +end |
| 122 | + |
| 123 | +function siteinds(tn::AbstractTensorNetwork, v) |
| 124 | + s = nameddimsindices(tn[v]) |
| 125 | + for v′ in neighbors(tn, v) |
| 126 | + s = setdiff(s, nameddimsindices(tn[v′])) |
| 127 | + end |
| 128 | + return s |
| 129 | +end |
| 130 | +function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) |
| 131 | + s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) |
| 132 | + for v′ in neighbors(tn, v) |
| 133 | + s = setdiff(s, axes(tn[v′])) |
| 134 | + end |
| 135 | + return s |
| 136 | +end |
| 137 | +function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) |
| 138 | + s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) |
| 139 | + for v′ in neighbors(tn, v) |
| 140 | + s = setdiff(s, dimnames(tn[v′])) |
| 141 | + end |
| 142 | + return s |
| 143 | +end |
| 144 | + |
| 145 | +function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex) |
| 146 | + vertex_data(tn)[vertex] = value |
| 147 | + return tn |
| 148 | +end |
| 149 | + |
| 150 | +# TODO: Move to `BaseExtensions` module. |
| 151 | +function is_setindex!_expr(expr::Expr) |
| 152 | + return is_assignment_expr(expr) && is_getindex_expr(first(expr.args)) |
| 153 | +end |
| 154 | +is_setindex!_expr(x) = false |
| 155 | +is_getindex_expr(expr::Expr) = (expr.head === :ref) |
| 156 | +is_getindex_expr(x) = false |
| 157 | +is_assignment_expr(expr::Expr) = (expr.head === :(=)) |
| 158 | +is_assignment_expr(expr) = false |
| 159 | + |
| 160 | +# TODO: Define this in terms of a function mapping |
| 161 | +# preserve_graph_function(::typeof(setindex!)) = setindex!_preserve_graph |
| 162 | +# preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph |
| 163 | +# Also allow annotating codeblocks like `@views`. |
| 164 | +macro preserve_graph(expr) |
| 165 | + if !is_setindex!_expr(expr) |
| 166 | + error( |
| 167 | + "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)", |
| 168 | + ) |
| 169 | + end |
| 170 | + @capture(expr, array_[indices__] = value_) |
| 171 | + return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...))) |
| 172 | +end |
| 173 | + |
| 174 | +# Update the graph of the TensorNetwork `tn` to include |
| 175 | +# edges that should exist based on the tensor connectivity. |
| 176 | +function add_missing_edges!(tn::AbstractTensorNetwork) |
| 177 | + foreach(v -> add_missing_edges!(tn, v), vertices(tn)) |
| 178 | + return tn |
| 179 | +end |
| 180 | + |
| 181 | +# Update the graph of the TensorNetwork `tn` to include |
| 182 | +# edges that should be incident to the vertex `v` |
| 183 | +# based on the tensor connectivity. |
| 184 | +function add_missing_edges!(tn::AbstractTensorNetwork, v) |
| 185 | + for v′ in vertices(tn) |
| 186 | + if v ≠ v′ |
| 187 | + e = v => v′ |
| 188 | + if !isempty(linkinds(tn, e)) |
| 189 | + add_edge!(tn, e) |
| 190 | + end |
| 191 | + end |
| 192 | + end |
| 193 | + return tn |
| 194 | +end |
| 195 | + |
| 196 | +# Fix the edges of the TensorNetwork `tn` to match |
| 197 | +# the tensor connectivity. |
| 198 | +function fix_edges!(tn::AbstractTensorNetwork) |
| 199 | + foreach(v -> fix_edges!(tn, v), vertices(tn)) |
| 200 | + return tn |
| 201 | +end |
| 202 | +# Fix the edges of the TensorNetwork `tn` to match |
| 203 | +# the tensor connectivity at vertex `v`. |
| 204 | +function fix_edges!(tn::AbstractTensorNetwork, v) |
| 205 | + rem_incident_edges!(tn, v) |
| 206 | + rem_edges!(tn, incident_edges(tn, v)) |
| 207 | + add_missing_edges!(tn, v) |
| 208 | + return tn |
| 209 | +end |
| 210 | + |
| 211 | +# Customization point. |
| 212 | +using NamedDimsArrays: AbstractNamedUnitRange, namedunitrange, nametype, randname |
| 213 | +function trivial_unitrange(type::Type{<:AbstractUnitRange}) |
| 214 | + return Base.oneto(one(eltype(type))) |
| 215 | +end |
| 216 | +function rand_trivial_namedunitrange( |
| 217 | + ::Type{<:AbstractNamedUnitRange{<:Any,R,N}} |
| 218 | +) where {R,N} |
| 219 | + return namedunitrange(trivial_unitrange(R), randname(N)) |
| 220 | +end |
| 221 | + |
| 222 | +dag(x) = x |
| 223 | + |
| 224 | +using NamedDimsArrays: nameddimsindices |
| 225 | +function insert_trivial_link!(tn, e) |
| 226 | + add_edge!(tn, e) |
| 227 | + l = rand_trivial_namedunitrange(eltype(nameddimsindices(tn[src(e)]))) |
| 228 | + x = similar(tn[src(e)], (l,)) |
| 229 | + x[1] = 1 |
| 230 | + @preserve_graph tn[src(e)] = tn[src(e)] * x |
| 231 | + @preserve_graph tn[dst(e)] = tn[dst(e)] * dag(x) |
| 232 | + return tn |
| 233 | +end |
| 234 | + |
| 235 | +function Base.setindex!(tn::AbstractTensorNetwork, value, v) |
| 236 | + @preserve_graph tn[v] = value |
| 237 | + fix_edges!(tn, v) |
| 238 | + return tn |
| 239 | +end |
| 240 | +using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger |
| 241 | +# Fix ambiguity error. |
| 242 | +function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger) |
| 243 | + graph[vertices(graph)[vertex]] = value |
| 244 | + return graph |
| 245 | +end |
| 246 | +# Fix ambiguity error. |
| 247 | +function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) |
| 248 | + return error("No edge data.") |
| 249 | +end |
| 250 | +# Fix ambiguity error. |
| 251 | +function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) |
| 252 | + return error("No edge data.") |
| 253 | +end |
| 254 | +using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger |
| 255 | +# Fix ambiguity error. |
| 256 | +function Base.setindex!( |
| 257 | + tn::AbstractTensorNetwork, |
| 258 | + value, |
| 259 | + edge::Pair{<:OrdinalSuffixedInteger,<:OrdinalSuffixedInteger}, |
| 260 | +) |
| 261 | + return error("No edge data.") |
| 262 | +end |
| 263 | + |
| 264 | +function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) |
| 265 | + println(io, "$(typeof(graph)) with $(nv(graph)) vertices:") |
| 266 | + show(io, mime, vertices(graph)) |
| 267 | + println(io, "\n") |
| 268 | + println(io, "and $(ne(graph)) edge(s):") |
| 269 | + for e in edges(graph) |
| 270 | + show(io, mime, e) |
| 271 | + println(io) |
| 272 | + end |
| 273 | + println(io) |
| 274 | + println(io, "with vertex data:") |
| 275 | + show(io, mime, axes.(vertex_data(graph))) |
| 276 | + return nothing |
| 277 | +end |
| 278 | + |
| 279 | +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) |
0 commit comments