@@ -11,7 +11,7 @@ const ADJMAT_T = AbstractMatrix
11
11
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
12
12
13
13
"""
14
- GNNGraph(data; [graph_type, dir, num_nodes, nf, ef, gf ])
14
+ GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, num_graphs, graph_indicator, dir ])
15
15
GNNGraph(g::GNNGraph; [nf, ef, gf])
16
16
17
17
A type representing a graph structure and storing also arrays
@@ -43,11 +43,13 @@ from the LightGraphs' graph library can be used on it.
43
43
- `:dense`. A dense adjacency matrix representation.
44
44
Default `:coo`.
45
45
- `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
46
- Possible values are `:out` and `:in`. Defaul `:out`.
47
- - `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default nothing.
48
- - `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default nothing.
49
- - `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default nothing.
50
- - `gf`: Global features. Default nothing.
46
+ Possible values are `:out` and `:in`. Default `:out`.
47
+ - `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
48
+ - `num_graphs`. The number of graphs. Larger than 1 in case of batched graphs. Default `1`.
49
+ - `graph_indicator`. For batched graphs, a vector containeing the graph assigment of each node. Default `nothing`.
50
+ - `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing`.
51
+ - `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing`.
52
+ - `gf`: Global features. Default `nothing`.
51
53
52
54
# Usage.
53
55
@@ -87,6 +89,8 @@ struct GNNGraph{T<:Union{COO_T,ADJMAT_T}}
87
89
graph:: T
88
90
num_nodes:: Int
89
91
num_edges:: Int
92
+ num_graphs:: Int
93
+ graph_indicator
90
94
nf
91
95
ef
92
96
gf
99
103
@functor GNNGraph
100
104
101
105
function GNNGraph (data;
102
- num_nodes = nothing ,
106
+ num_nodes = nothing ,
107
+ num_graphs = 1 ,
108
+ graph_indicator = nothing ,
103
109
graph_type = :coo ,
104
110
dir = :out ,
105
111
nf = nothing ,
@@ -119,6 +125,9 @@ function GNNGraph(data;
119
125
elseif graph_type == :sparse
120
126
g, num_nodes, num_edges = to_sparse (data; dir)
121
127
end
128
+ if num_graphs > 1
129
+ @assert len (graph_indicator) = num_nodes " When batching multiple graphs `graph_indicator` should be filled with the nodes' memberships."
130
+ end
122
131
123
132
# # Possible future implementation of feature maps.
124
133
# # Currently this doesn't play well with zygote due to
@@ -127,8 +136,9 @@ function GNNGraph(data;
127
136
# edata["e"] = ef
128
137
# gdata["g"] = gf
129
138
130
-
131
- GNNGraph (g, num_nodes, num_edges, nf, ef, gf)
139
+ GNNGraph (g, num_nodes, num_edges,
140
+ num_graphs, graph_indicator,
141
+ nf, ef, gf)
132
142
end
133
143
134
144
# COO convenience constructors
@@ -147,7 +157,7 @@ function GNNGraph(g::GNNGraph;
147
157
nf= node_feature (g), ef= edge_feature (g), gf= global_feature (g))
148
158
# ndata=copy(g.ndata), edata=copy(g.edata), gdata=copy(g.gdata), # copy keeps the refs to old data
149
159
150
- GNNGraph (g. graph, g. num_nodes, g. num_edges, nf, ef, gf) # ndata, edata, gdata,
160
+ GNNGraph (g. graph, g. num_nodes, g. num_edges, g . num_graphs, g . graph_indicator, nf, ef, gf) # ndata, edata, gdata,
151
161
end
152
162
153
163
@@ -370,6 +380,7 @@ function add_self_loops(g::GNNGraph{<:COO_T})
370
380
t = [t; nodes]
371
381
372
382
GNNGraph ((s, t, nothing ), g. num_nodes, length (s),
383
+ g. num_graphs, g. graph_indicator,
373
384
node_feature (g), edge_feature (g), global_feature (g))
374
385
end
375
386
@@ -379,6 +390,7 @@ function add_self_loops(g::GNNGraph{<:ADJMAT_T}; add_to_existing=true)
379
390
A += I
380
391
num_edges = g. num_edges + g. num_nodes
381
392
GNNGraph (A, g. num_nodes, num_edges,
393
+ g. num_graphs, g. graph_indicator,
382
394
node_feature (g), edge_feature (g), global_feature (g))
383
395
end
384
396
@@ -392,10 +404,46 @@ function remove_self_loops(g::GNNGraph{<:COO_T})
392
404
s = s[mask_old_loops]
393
405
t = t[mask_old_loops]
394
406
395
- GNNGraph ((s, t, nothing ), g. num_nodes, length (s),
407
+ GNNGraph ((s, t, nothing ), g. num_nodes, length (s),
408
+ g. num_graphs, g. graph_indicator,
396
409
node_feature (g), edge_feature (g), global_feature (g))
397
410
end
398
411
412
+ function _catgraphs (g1:: GNNGraph{<:COO_T} , g2:: GNNGraph{<:COO_T} )
413
+ s1, t1 = edge_index (g1)
414
+ s2, t2 = edge_index (g2)
415
+ nv1, nv2 = g1. num_nodes, g2. num_nodes
416
+ s = vcat (s1, nv1 .+ s2)
417
+ t = vcat (t1, nv1 .+ t2)
418
+ w = cat_features (edge_weight (g1), edge_weight (g2))
419
+
420
+ ind1 = isnothing (g1. graph_indicator) ? fill! (similar (s1, Int, nv1), 1 ) : g1. graph_indicator
421
+ ind2 = isnothing (g2. graph_indicator) ? fill! (similar (s2, Int, nv2), 1 ) : g2. graph_indicator
422
+ graph_indicator = vcat (ind1, g1. num_graphs .+ ind2)
423
+
424
+ GNNGraph (
425
+ (s, t, w),
426
+ nv1 + nv2, g1. num_edges + g2. num_edges,
427
+ g1. num_graphs + g2. num_graphs, graph_indicator,
428
+ cat_features (node_feature (g1), node_feature (g2)),
429
+ cat_features (edge_feature (g1), edge_feature (g2)),
430
+ cat_features (global_feature (g1), global_feature (g2)),
431
+ )
432
+ end
433
+
434
+ # Cat public interfaces
435
+ function SparseArrays. blockdiag (g1:: GNNGraph , gothers:: GNNGraph... )
436
+ @assert length (gothers) >= 1
437
+ g = g1
438
+ for go in gothers
439
+ g = _catgraphs (g, go)
440
+ end
441
+ return g
442
+ end
443
+
444
+ Flux. batch (xs:: Vector{<:GNNGraph} ) = blockdiag (xs... )
445
+ # ########################
446
+
399
447
@non_differentiable normalized_laplacian (x... )
400
448
@non_differentiable normalized_adjacency (x... )
401
449
@non_differentiable scaled_laplacian (x... )
0 commit comments