@@ -11,8 +11,8 @@ const ADJMAT_T = AbstractMatrix
11
11
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
12
12
13
13
"""
14
- GNNGraph(data; [graph_type, nf, ef, gf , num_nodes, graph_indicator, dir])
15
- GNNGraph(g::GNNGraph; [nf, ef, gf ])
14
+ GNNGraph(data; [graph_type, ndata, edata, gdata , num_nodes, graph_indicator, dir])
15
+ GNNGraph(g::GNNGraph; [ndata, edata, gdata ])
16
16
17
17
A type representing a graph structure and storing also arrays
18
18
that contain features associated to nodes, edges, and the whole graph.
@@ -50,10 +50,10 @@ from the LightGraphs' graph library can be used on it.
50
50
- `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
51
51
Possible values are `:out` and `:in`. Default `:out`.
52
52
- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
53
- - `graph_indicator`. For batched graphs, a vector containeing the graph assigment of each node. Default `nothing`.
54
- - `nf `: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing` .
55
- - `ef `: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing` .
56
- - `gf `: Global features. Default `nothing` .
53
+ - `graph_indicator`. For batched graphs, a vector containing the graph assigment of each node. Default `nothing`.
54
+ - `ndata `: Node features. A named tuple of arrays whose last dimension has size num_nodes.
55
+ - `edata `: Edge features. A named tuple of arrays whose whose last dimension has size num_edges.
56
+ - `gdata `: Global features. A named tuple of arrays whose has size num_graphs .
57
57
58
58
# Usage.
59
59
@@ -77,7 +77,7 @@ g = GNNGraph(s, t)
77
77
g = GNNGraph(erdos_renyi(100, 20))
78
78
79
79
# Copy graph while also adding node features
80
- g = GNNGraph(g, nf= rand(100, 5 ))
80
+ g = GNNGraph(g, ndata = (x = rand(100, g.num_nodes), ))
81
81
82
82
# Send to gpu
83
83
g = g |> gpu
@@ -86,38 +86,28 @@ g = g |> gpu
86
86
# Both source and target are vectors of length num_edges
87
87
source, target = edge_index(g)
88
88
```
89
-
90
- See also [`graph`](@ref), [`edge_index`](@ref), [`node_feature`](@ref), [`edge_feature`](@ref), and [`global_feature`](@ref)
91
89
"""
92
90
struct GNNGraph{T<: Union{COO_T,ADJMAT_T} }
93
91
graph:: T
94
92
num_nodes:: Int
95
93
num_edges:: Int
96
94
num_graphs:: Int
97
95
graph_indicator
98
- nf
99
- ef
100
- gf
101
- # # possible future property stores
102
- # ndata::Dict{String, Any} # https://github.com/FluxML/Zygote.jl/issues/717
103
- # edata::Dict{String, Any}
104
- # gdata::Dict{String, Any}
96
+ ndata:: NamedTuple
97
+ edata:: NamedTuple
98
+ gdata:: NamedTuple
105
99
end
106
100
107
101
@functor GNNGraph
108
102
109
103
function GNNGraph (data;
110
104
num_nodes = nothing ,
111
- num_graphs = 1 ,
112
105
graph_indicator = nothing ,
113
106
graph_type = :coo ,
114
107
dir = :out ,
115
- nf = nothing ,
116
- ef = nothing ,
117
- gf = nothing ,
118
- # ndata = Dict{String, Any}(),
119
- # edata = Dict{String, Any}(),
120
- # gdata = Dict{String, Any}()
108
+ ndata = (;),
109
+ edata = (;),
110
+ gdata = (;),
121
111
)
122
112
123
113
@assert graph_type ∈ [:coo , :dense , :sparse ] " Invalid graph_type $graph_type requested"
@@ -133,18 +123,20 @@ function GNNGraph(data;
133
123
134
124
num_graphs = ! isnothing (graph_indicator) ? maximum (graph_indicator) : 1
135
125
136
- # # Possible future implementation of feature maps.
137
- # # Currently this doesn't play well with zygote due to
138
- # # https://github.com/FluxML/Zygote.jl/issues/717
139
- # ndata["x"] = nf
140
- # edata["e"] = ef
141
- # gdata["g"] = gf
126
+ ndata = normalize_graphdata (ndata, :X )
127
+ edata = normalize_graphdata (edata, :E )
128
+ gdata = normalize_graphdata (gdata, :U )
142
129
143
- GNNGraph (g, num_nodes, num_edges,
144
- num_graphs, graph_indicator,
145
- nf, ef, gf)
130
+ GNNGraph (g,
131
+ num_nodes, num_edges, num_graphs,
132
+ graph_indicator,
133
+ ndata, edata, gdata)
146
134
end
147
135
136
+ normalize_graphdata (data:: NamedTuple , s) = data
137
+ normalize_graphdata (data:: Nothing , s) = NamedTuple ()
138
+ normalize_graphdata (data, s) = NamedTuple {(s,)} ((data,))
139
+
148
140
# COO convenience constructors
149
141
GNNGraph (s:: AbstractVector , t:: AbstractVector , v = nothing ; kws... ) = GNNGraph ((s, t, v); kws... )
150
142
GNNGraph ((s, t):: NTuple{2} ; kws... ) = GNNGraph ((s, t, nothing ); kws... )
@@ -154,14 +146,19 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
154
146
function GNNGraph (g:: AbstractGraph ; kws... )
155
147
s = LightGraphs. src .(LightGraphs. edges (g))
156
148
t = LightGraphs. dst .(LightGraphs. edges (g))
157
- GNNGraph ((s, t); num_nodes = nv (g), kws... )
149
+ GNNGraph ((s, t); num_nodes = LightGraphs . nv (g), kws... )
158
150
end
159
151
160
- function GNNGraph (g:: GNNGraph ;
161
- nf= node_feature (g), ef= edge_feature (g), gf= global_feature (g))
162
- # ndata=copy(g.ndata), edata=copy(g.edata), gdata=copy(g.gdata), # copy keeps the refs to old data
152
+ function GNNGraph (g:: GNNGraph ; ndata= g. ndata, edata= g. edata, gdata= g. gdata)
153
+
154
+ ndata = normalize_graphdata (ndata, :X )
155
+ edata = normalize_graphdata (edata, :E )
156
+ gdata = normalize_graphdata (gdata, :U )
163
157
164
- GNNGraph (g. graph, g. num_nodes, g. num_edges, g. num_graphs, g. graph_indicator, nf, ef, gf) # ndata, edata, gdata,
158
+ GNNGraph (g. graph,
159
+ g. num_nodes, g. num_edges, g. num_graphs,
160
+ g. graph_indicator,
161
+ ndata, edata, gdata)
165
162
end
166
163
167
164
@@ -266,44 +263,6 @@ function LightGraphs.degree(g::GNNGraph{<:ADJMAT_T}, T=Int; dir=:out)
266
263
return dir == :out ? vec (sum (A, dims= 2 )) : vec (sum (A, dims= 1 ))
267
264
end
268
265
269
- # node_feature(g::GNNGraph) = g.ndata["x"]
270
- # edge_feature(g::GNNGraph) = g.edata["e"]
271
- # global_feature(g::GNNGraph) = g.gdata["g"]
272
-
273
-
274
- """
275
- node_feature(g::GNNGraph)
276
-
277
- Return the node features of `g`.
278
- """
279
- node_feature (g:: GNNGraph ) = g. nf
280
-
281
- """
282
- edge_feature(g::GNNGraph)
283
-
284
- Return the edge features of `g`.
285
- """
286
- edge_feature (g:: GNNGraph ) = g. ef
287
-
288
- """
289
- global_feature(g::GNNGraph)
290
-
291
- Return the global features of `g`.
292
- """
293
- global_feature (g:: GNNGraph ) = g. gf
294
-
295
- # function Base.getproperty(g::GNNGraph, sym::Symbol)
296
- # if sym === :nf
297
- # return g.ndata["x"]
298
- # elseif sym === :ef
299
- # return g.edata["e"]
300
- # elseif sym === :gf
301
- # return g.gdata["g"]
302
- # else # fallback to getfield
303
- # return getfield(g, sym)
304
- # end
305
- # end
306
-
307
266
function LightGraphs. laplacian_matrix (g:: GNNGraph , T:: DataType = Int; dir:: Symbol = :out )
308
267
A = adjacency_matrix (g, T; dir= dir)
309
268
D = Diagonal (vec (sum (A; dims= 2 )))
@@ -376,41 +335,44 @@ self-loops will obtain a second self-loop.
376
335
"""
377
336
function add_self_loops (g:: GNNGraph{<:COO_T} )
378
337
s, t = edge_index (g)
379
- @assert edge_feature (g) === nothing
338
+ @assert g . edata === (;)
380
339
@assert edge_weight (g) === nothing
381
340
n = g. num_nodes
382
341
nodes = convert (typeof (s), [1 : n;])
383
342
s = [s; nodes]
384
343
t = [t; nodes]
385
344
386
- GNNGraph ((s, t, nothing ), g. num_nodes, length (s),
387
- g. num_graphs, g. graph_indicator,
388
- node_feature (g), edge_feature (g), global_feature (g))
345
+ GNNGraph ((s, t, nothing ),
346
+ g. num_nodes, length (s), g. num_graphs,
347
+ g. graph_indicator,
348
+ g. ndata, g. edata, g. gdata)
389
349
end
390
350
391
- function add_self_loops (g:: GNNGraph{<:ADJMAT_T} ; add_to_existing = true )
392
- A = graph (g)
393
- @assert edge_feature (g) === nothing
351
+ function add_self_loops (g:: GNNGraph{<:ADJMAT_T} )
352
+ A = adjaceny_matrix (g)
353
+ @assert g . edata === (;)
394
354
A += I
395
355
num_edges = g. num_edges + g. num_nodes
396
- GNNGraph (A, g. num_nodes, num_edges,
397
- g. num_graphs, g. graph_indicator,
398
- node_feature (g), edge_feature (g), global_feature (g))
356
+ GNNGraph (A,
357
+ g. num_nodes, num_edges, g. num_graphs,
358
+ g. graph_indicator,
359
+ g. ndata, g. edata, g. gdata)
399
360
end
400
361
401
362
function remove_self_loops (g:: GNNGraph{<:COO_T} )
402
363
s, t = edge_index (g)
403
364
# TODO remove these constraints
404
- @assert edge_feature (g) === nothing
365
+ @assert g . edata === (;)
405
366
@assert edge_weight (g) === nothing
406
367
407
368
mask_old_loops = s .!= t
408
369
s = s[mask_old_loops]
409
370
t = t[mask_old_loops]
410
371
411
- GNNGraph ((s, t, nothing ), g. num_nodes, length (s),
412
- g. num_graphs, g. graph_indicator,
413
- node_feature (g), edge_feature (g), global_feature (g))
372
+ GNNGraph ((s, t, nothing ),
373
+ g. num_nodes, length (s), g. num_graphs,
374
+ g. graph_indicator,
375
+ g. ndata, g. edata, g. gdata)
414
376
end
415
377
416
378
function _catgraphs (g1:: GNNGraph{<:COO_T} , g2:: GNNGraph{<:COO_T} )
@@ -425,14 +387,12 @@ function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T})
425
387
ind2 = isnothing (g2. graph_indicator) ? fill! (similar (s2, Int, nv2), 1 ) : g2. graph_indicator
426
388
graph_indicator = vcat (ind1, g1. num_graphs .+ ind2)
427
389
428
- GNNGraph (
429
- (s, t, w),
430
- nv1 + nv2, g1. num_edges + g2. num_edges,
431
- g1. num_graphs + g2. num_graphs, graph_indicator,
432
- cat_features (node_feature (g1), node_feature (g2)),
433
- cat_features (edge_feature (g1), edge_feature (g2)),
434
- cat_features (global_feature (g1), global_feature (g2)),
435
- )
390
+ GNNGraph ((s, t, w),
391
+ nv1 + nv2, g1. num_edges + g2. num_edges, g1. num_graphs + g2. num_graphs,
392
+ graph_indicator,
393
+ cat_features (g1. ndata, g2. ndata),
394
+ cat_features (g1. edata, g2. edata),
395
+ cat_features (g1. gdata, g2. gdata))
436
396
end
437
397
438
398
# ## Cat public interfaces #############
@@ -490,9 +450,9 @@ function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector)
490
450
s = [nodemap[i] for i in s[edge_mask]]
491
451
t = [nodemap[i] for i in t[edge_mask]]
492
452
w = isnothing (w) ? nothing : w[edge_mask]
493
- nf = isnothing (g. nf) ? nothing : g . nf[:, node_mask]
494
- ef = isnothing (g. ef) ? nothing : g . ef[:, edge_mask]
495
- gf = isnothing (g. gf) ? nothing : g . gf[:,i]
453
+ ndata = getobs (g. ndata, node_mask)
454
+ edata = getobs (g. ndata, edge_mask)
455
+ gdata = getobs (g. gdata, i)
496
456
497
457
num_nodes = length (graph_indicator)
498
458
num_edges = length (s)
@@ -501,10 +461,43 @@ function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector)
501
461
gnew = GNNGraph ((s,t,w),
502
462
num_nodes, num_edges, num_graphs,
503
463
graph_indicator,
504
- nf, ef, gf )
464
+ ndata, edata, gdata )
505
465
return gnew, nodes
506
466
end
507
467
468
+ # ## TO DEPRECATE ?? ###
469
+ function node_features (g:: GNNGraph )
470
+ if isempty (g. ndata)
471
+ return nothing
472
+ elseif length (g. ndata) > 1
473
+ @error " Multiple feature arrays, access directly with g.ndata.X"
474
+ else
475
+ return g. ndata[1 ]
476
+ end
477
+ end
478
+
479
+ function edge_features (g:: GNNGraph )
480
+ if isempty (g. edata)
481
+ return nothing
482
+ elseif length (g. edata) > 1
483
+ @error " Multiple feature arrays, access directly with g.edata.E"
484
+ else
485
+ return g. edata[1 ]
486
+ end
487
+ end
488
+
489
+ function global_features (g:: GNNGraph )
490
+ if isempty (g. gdata)
491
+ return nothing
492
+ elseif length (g. gdata) > 1
493
+ @error " Multiple feature arrays, access directly with g.gdata.U"
494
+ else
495
+ return g. gdata[1 ]
496
+ end
497
+ end
498
+ # ########
499
+
500
+
508
501
@non_differentiable normalized_laplacian (x... )
509
502
@non_differentiable normalized_adjacency (x... )
510
503
@non_differentiable scaled_laplacian (x... )
0 commit comments