54
54
"""
55
55
add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
56
56
57
- Add to graph `g` the edges with source nodes `s` and target nodes `t`.
58
-
57
+ Add to graph `g` the edges with source nodes `s` and target nodes `t`.
59
58
"""
60
59
function add_edges (g:: GNNGraph{<:COO_T} ,
61
60
snew:: AbstractVector{<:Integer} ,
@@ -79,6 +78,25 @@ function add_edges(g::GNNGraph{<:COO_T},
79
78
g. ndata, edata, g. gdata)
80
79
end
81
80
81
+
82
+ """
83
+ add_nodes(g::GNNGraph, n; [ndata])
84
+
85
+ Add `n` new nodes to graph `g`. In the
86
+ new graph, these nodes will have indexes from `g.num_nodes + 1`
87
+ to `g.num_nodes + n`.
88
+ """
89
+ function add_nodes (g:: GNNGraph{<:COO_T} , n:: Integer ; ndata= (;))
90
+ ndata = normalize_graphdata (ndata, default_name= :x , n= n)
91
+ ndata = cat_features (g. ndata, ndata)
92
+
93
+ GNNGraph (g. graph,
94
+ g. num_nodes + n, g. num_edges, g. num_graphs,
95
+ g. graph_indicator,
96
+ ndata, g. edata, g. gdata)
97
+ end
98
+
99
+
82
100
function SparseArrays. blockdiag (g1:: GNNGraph , g2:: GNNGraph )
83
101
nv1, nv2 = g1. num_nodes, g2. num_nodes
84
102
if g1. graph isa COO_T
@@ -117,8 +135,6 @@ function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix)
117
135
O2 A2]
118
136
end
119
137
120
- # ## Cat public interfaces #############
121
-
122
138
"""
123
139
blockdiag(xs::GNNGraph...)
124
140
@@ -133,14 +149,115 @@ function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
133
149
end
134
150
135
151
"""
136
- batch(xs ::Vector{<:GNNGraph})
152
+ batch(gs ::Vector{<:GNNGraph})
137
153
138
154
Batch together multiple `GNNGraph`s into a single one
139
155
containing the total number of original nodes and edges.
140
156
141
157
Equivalent to [`SparseArrays.blockdiag`](@ref).
158
+ See also [`Flux.unbatch`](@ref).
159
+
160
+ # Usage
161
+
162
+ ```juliarepl
163
+ julia> g1 = rand_graph(4, 6, ndata=ones(8, 4))
164
+ GNNGraph:
165
+ num_nodes = 4
166
+ num_edges = 6
167
+ num_graphs = 1
168
+ ndata:
169
+ x => (8, 4)
170
+ edata:
171
+ gdata:
172
+
173
+
174
+ julia> g2 = rand_graph(7, 4, ndata=zeros(8, 7))
175
+ GNNGraph:
176
+ num_nodes = 7
177
+ num_edges = 4
178
+ num_graphs = 1
179
+ ndata:
180
+ x => (8, 7)
181
+ edata:
182
+ gdata:
183
+
184
+
185
+ julia> g12 = Flux.batch([g1, g2])
186
+ GNNGraph:
187
+ num_nodes = 11
188
+ num_edges = 10
189
+ num_graphs = 2
190
+ ndata:
191
+ x => (8, 11)
192
+ edata:
193
+ gdata:
194
+
195
+
196
+ julia> g12.ndata.x
197
+ 8×11 Matrix{Float64}:
198
+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
199
+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
200
+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
201
+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
202
+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
203
+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
204
+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
205
+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
206
+ ```
207
+ """
208
+ Flux. batch (gs:: Vector{<:GNNGraph} ) = blockdiag (gs... )
209
+
210
+
142
211
"""
143
- Flux. batch (xs:: Vector{<:GNNGraph} ) = blockdiag (xs... )
212
+ unbatch(g::GNNGraph)
213
+
214
+ Opposite of the [`Flux.batch`](@ref) operation, returns
215
+ an array of the individual graphs batched together in `g`.
216
+
217
+ See also [`Flux.batch`](@ref) and [`getgraph`](@ref).
218
+
219
+ # Usage
220
+
221
+ ```juliarepl
222
+ julia> gbatched = Flux.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4,2)])
223
+ GNNGraph:
224
+ num_nodes = 19
225
+ num_edges = 16
226
+ num_graphs = 3
227
+ ndata:
228
+ edata:
229
+ gdata:
230
+
231
+ julia> Flux.unbatch(gbatched)
232
+ 3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
233
+ GNNGraph:
234
+ num_nodes = 5
235
+ num_edges = 6
236
+ num_graphs = 1
237
+ ndata:
238
+ edata:
239
+ gdata:
240
+
241
+ GNNGraph:
242
+ num_nodes = 10
243
+ num_edges = 8
244
+ num_graphs = 1
245
+ ndata:
246
+ edata:
247
+ gdata:
248
+
249
+ GNNGraph:
250
+ num_nodes = 4
251
+ num_edges = 2
252
+ num_graphs = 1
253
+ ndata:
254
+ edata:
255
+ gdata:
256
+ ```
257
+ """
258
+ function Flux. unbatch (g:: GNNGraph )
259
+ [getgraph (g, i) for i in 1 : g. num_graphs]
260
+ end
144
261
145
262
146
263
"""
0 commit comments