Skip to content

Commit 0510d2b

Browse files
ndata, edata, gdata
1 parent cba6565 commit 0510d2b

File tree

13 files changed

+253
-263
lines changed

13 files changed

+253
-263
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1010
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1111
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
12+
LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
1213
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1516
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1617
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
18+
PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337"
1719
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1820
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1921
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

perf/perf.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function run_single_benchmark(N, c, D, CONV; gtype=:lg)
1111
data = erdos_renyi(N, c / (N-1), seed=17)
1212
X = randn(Float32, D, N)
1313

14-
g = GNNGraph(data; nf=X, graph_type=gtype)
14+
g = GNNGraph(data; ndata=X, graph_type=gtype)
1515
g_gpu = g |> gpu
1616

1717
m = CONV(D => D)

src/GraphNeuralNetworks.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
module GraphNeuralNetworks
22

3-
using Core: apply_type
4-
using NNlib: similar
5-
using LinearAlgebra: similar, fill!
63
using Statistics: mean
74
using LinearAlgebra
85
using SparseArrays
@@ -12,17 +9,17 @@ using CUDA
129
using Flux
1310
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
1411
using MacroTools: @forward
12+
using LearnBase: getobs
1513
using NNlib, NNlibCUDA
1614
using ChainRulesCore
1715
import LightGraphs
18-
using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv,
19-
adjacency_matrix, degree
16+
using LightGraphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
2017

2118
export
2219
# gnngraph
2320
GNNGraph,
2421
edge_index,
25-
node_feature, edge_feature, global_feature,
22+
node_features, edge_features, global_features,
2623
adjacency_list, normalized_laplacian, scaled_laplacian,
2724
add_self_loops, remove_self_loops,
2825
subgraph,
@@ -52,7 +49,6 @@ export
5249
topk_index
5350

5451

55-
5652

5753
include("gnngraph.jl")
5854
include("graph_conversions.jl")

src/gnngraph.jl

Lines changed: 94 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ const ADJMAT_T = AbstractMatrix
1111
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
1212

1313
"""
14-
GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, graph_indicator, dir])
15-
GNNGraph(g::GNNGraph; [nf, ef, gf])
14+
GNNGraph(data; [graph_type, ndata, edata, gdata, num_nodes, graph_indicator, dir])
15+
GNNGraph(g::GNNGraph; [ndata, edata, gdata])
1616
1717
A type representing a graph structure and storing also arrays
1818
that contain features associated to nodes, edges, and the whole graph.
@@ -50,10 +50,10 @@ from the LightGraphs' graph library can be used on it.
5050
- `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
5151
Possible values are `:out` and `:in`. Default `:out`.
5252
- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
53-
- `graph_indicator`. For batched graphs, a vector containeing the graph assigment of each node. Default `nothing`.
54-
- `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing`.
55-
- `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing`.
56-
- `gf`: Global features. Default `nothing`.
53+
- `graph_indicator`. For batched graphs, a vector containing the graph assigment of each node. Default `nothing`.
54+
- `ndata`: Node features. A named tuple of arrays whose last dimension has size num_nodes.
55+
- `edata`: Edge features. A named tuple of arrays whose whose last dimension has size num_edges.
56+
- `gdata`: Global features. A named tuple of arrays whose has size num_graphs.
5757
5858
# Usage.
5959
@@ -77,7 +77,7 @@ g = GNNGraph(s, t)
7777
g = GNNGraph(erdos_renyi(100, 20))
7878
7979
# Copy graph while also adding node features
80-
g = GNNGraph(g, nf=rand(100, 5))
80+
g = GNNGraph(g, ndata = (x = rand(100, g.num_nodes),))
8181
8282
# Send to gpu
8383
g = g |> gpu
@@ -86,38 +86,28 @@ g = g |> gpu
8686
# Both source and target are vectors of length num_edges
8787
source, target = edge_index(g)
8888
```
89-
90-
See also [`graph`](@ref), [`edge_index`](@ref), [`node_feature`](@ref), [`edge_feature`](@ref), and [`global_feature`](@ref)
9189
"""
9290
struct GNNGraph{T<:Union{COO_T,ADJMAT_T}}
9391
graph::T
9492
num_nodes::Int
9593
num_edges::Int
9694
num_graphs::Int
9795
graph_indicator
98-
nf
99-
ef
100-
gf
101-
## possible future property stores
102-
# ndata::Dict{String, Any} # https://github.com/FluxML/Zygote.jl/issues/717
103-
# edata::Dict{String, Any}
104-
# gdata::Dict{String, Any}
96+
ndata::NamedTuple
97+
edata::NamedTuple
98+
gdata::NamedTuple
10599
end
106100

107101
@functor GNNGraph
108102

109103
function GNNGraph(data;
110104
num_nodes = nothing,
111-
num_graphs = 1,
112105
graph_indicator = nothing,
113106
graph_type = :coo,
114107
dir = :out,
115-
nf = nothing,
116-
ef = nothing,
117-
gf = nothing,
118-
# ndata = Dict{String, Any}(),
119-
# edata = Dict{String, Any}(),
120-
# gdata = Dict{String, Any}()
108+
ndata = (;),
109+
edata = (;),
110+
gdata = (;),
121111
)
122112

123113
@assert graph_type [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
@@ -133,18 +123,20 @@ function GNNGraph(data;
133123

134124
num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1
135125

136-
## Possible future implementation of feature maps.
137-
## Currently this doesn't play well with zygote due to
138-
## https://github.com/FluxML/Zygote.jl/issues/717
139-
# ndata["x"] = nf
140-
# edata["e"] = ef
141-
# gdata["g"] = gf
126+
ndata = normalize_graphdata(ndata, :X)
127+
edata = normalize_graphdata(edata, :E)
128+
gdata = normalize_graphdata(gdata, :U)
142129

143-
GNNGraph(g, num_nodes, num_edges,
144-
num_graphs, graph_indicator,
145-
nf, ef, gf)
130+
GNNGraph(g,
131+
num_nodes, num_edges, num_graphs,
132+
graph_indicator,
133+
ndata, edata, gdata)
146134
end
147135

136+
normalize_graphdata(data::NamedTuple, s) = data
137+
normalize_graphdata(data::Nothing, s) = NamedTuple()
138+
normalize_graphdata(data, s) = NamedTuple{(s,)}((data,))
139+
148140
# COO convenience constructors
149141
GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) = GNNGraph((s, t, v); kws...)
150142
GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
@@ -154,14 +146,19 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
154146
function GNNGraph(g::AbstractGraph; kws...)
155147
s = LightGraphs.src.(LightGraphs.edges(g))
156148
t = LightGraphs.dst.(LightGraphs.edges(g))
157-
GNNGraph((s, t); num_nodes = nv(g), kws...)
149+
GNNGraph((s, t); num_nodes = LightGraphs.nv(g), kws...)
158150
end
159151

160-
function GNNGraph(g::GNNGraph;
161-
nf=node_feature(g), ef=edge_feature(g), gf=global_feature(g))
162-
# ndata=copy(g.ndata), edata=copy(g.edata), gdata=copy(g.gdata), # copy keeps the refs to old data
152+
function GNNGraph(g::GNNGraph; ndata=g.ndata, edata=g.edata, gdata=g.gdata)
153+
154+
ndata = normalize_graphdata(ndata, :X)
155+
edata = normalize_graphdata(edata, :E)
156+
gdata = normalize_graphdata(gdata, :U)
163157

164-
GNNGraph(g.graph, g.num_nodes, g.num_edges, g.num_graphs, g.graph_indicator, nf, ef, gf) # ndata, edata, gdata,
158+
GNNGraph(g.graph,
159+
g.num_nodes, g.num_edges, g.num_graphs,
160+
g.graph_indicator,
161+
ndata, edata, gdata)
165162
end
166163

167164

@@ -266,44 +263,6 @@ function LightGraphs.degree(g::GNNGraph{<:ADJMAT_T}, T=Int; dir=:out)
266263
return dir == :out ? vec(sum(A, dims=2)) : vec(sum(A, dims=1))
267264
end
268265

269-
# node_feature(g::GNNGraph) = g.ndata["x"]
270-
# edge_feature(g::GNNGraph) = g.edata["e"]
271-
# global_feature(g::GNNGraph) = g.gdata["g"]
272-
273-
274-
"""
275-
node_feature(g::GNNGraph)
276-
277-
Return the node features of `g`.
278-
"""
279-
node_feature(g::GNNGraph) = g.nf
280-
281-
"""
282-
edge_feature(g::GNNGraph)
283-
284-
Return the edge features of `g`.
285-
"""
286-
edge_feature(g::GNNGraph) = g.ef
287-
288-
"""
289-
global_feature(g::GNNGraph)
290-
291-
Return the global features of `g`.
292-
"""
293-
global_feature(g::GNNGraph) = g.gf
294-
295-
# function Base.getproperty(g::GNNGraph, sym::Symbol)
296-
# if sym === :nf
297-
# return g.ndata["x"]
298-
# elseif sym === :ef
299-
# return g.edata["e"]
300-
# elseif sym === :gf
301-
# return g.gdata["g"]
302-
# else # fallback to getfield
303-
# return getfield(g, sym)
304-
# end
305-
# end
306-
307266
function LightGraphs.laplacian_matrix(g::GNNGraph, T::DataType=Int; dir::Symbol=:out)
308267
A = adjacency_matrix(g, T; dir=dir)
309268
D = Diagonal(vec(sum(A; dims=2)))
@@ -376,41 +335,44 @@ self-loops will obtain a second self-loop.
376335
"""
377336
function add_self_loops(g::GNNGraph{<:COO_T})
378337
s, t = edge_index(g)
379-
@assert edge_feature(g) === nothing
338+
@assert g.edata === (;)
380339
@assert edge_weight(g) === nothing
381340
n = g.num_nodes
382341
nodes = convert(typeof(s), [1:n;])
383342
s = [s; nodes]
384343
t = [t; nodes]
385344

386-
GNNGraph((s, t, nothing), g.num_nodes, length(s),
387-
g.num_graphs, g.graph_indicator,
388-
node_feature(g), edge_feature(g), global_feature(g))
345+
GNNGraph((s, t, nothing),
346+
g.num_nodes, length(s), g.num_graphs,
347+
g.graph_indicator,
348+
g.ndata, g.edata, g.gdata)
389349
end
390350

391-
function add_self_loops(g::GNNGraph{<:ADJMAT_T}; add_to_existing=true)
392-
A = graph(g)
393-
@assert edge_feature(g) === nothing
351+
function add_self_loops(g::GNNGraph{<:ADJMAT_T})
352+
A = adjaceny_matrix(g)
353+
@assert g.edata === (;)
394354
A += I
395355
num_edges = g.num_edges + g.num_nodes
396-
GNNGraph(A, g.num_nodes, num_edges,
397-
g.num_graphs, g.graph_indicator,
398-
node_feature(g), edge_feature(g), global_feature(g))
356+
GNNGraph(A,
357+
g.num_nodes, num_edges, g.num_graphs,
358+
g.graph_indicator,
359+
g.ndata, g.edata, g.gdata)
399360
end
400361

401362
function remove_self_loops(g::GNNGraph{<:COO_T})
402363
s, t = edge_index(g)
403364
# TODO remove these constraints
404-
@assert edge_feature(g) === nothing
365+
@assert g.edata === (;)
405366
@assert edge_weight(g) === nothing
406367

407368
mask_old_loops = s .!= t
408369
s = s[mask_old_loops]
409370
t = t[mask_old_loops]
410371

411-
GNNGraph((s, t, nothing), g.num_nodes, length(s),
412-
g.num_graphs, g.graph_indicator,
413-
node_feature(g), edge_feature(g), global_feature(g))
372+
GNNGraph((s, t, nothing),
373+
g.num_nodes, length(s), g.num_graphs,
374+
g.graph_indicator,
375+
g.ndata, g.edata, g.gdata)
414376
end
415377

416378
function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T})
@@ -425,14 +387,12 @@ function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T})
425387
ind2 = isnothing(g2.graph_indicator) ? fill!(similar(s2, Int, nv2), 1) : g2.graph_indicator
426388
graph_indicator = vcat(ind1, g1.num_graphs .+ ind2)
427389

428-
GNNGraph(
429-
(s, t, w),
430-
nv1 + nv2, g1.num_edges + g2.num_edges,
431-
g1.num_graphs + g2.num_graphs, graph_indicator,
432-
cat_features(node_feature(g1), node_feature(g2)),
433-
cat_features(edge_feature(g1), edge_feature(g2)),
434-
cat_features(global_feature(g1), global_feature(g2)),
435-
)
390+
GNNGraph((s, t, w),
391+
nv1 + nv2, g1.num_edges + g2.num_edges, g1.num_graphs + g2.num_graphs,
392+
graph_indicator,
393+
cat_features(g1.ndata, g2.ndata),
394+
cat_features(g1.edata, g2.edata),
395+
cat_features(g1.gdata, g2.gdata))
436396
end
437397

438398
### Cat public interfaces #############
@@ -490,9 +450,9 @@ function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector)
490450
s = [nodemap[i] for i in s[edge_mask]]
491451
t = [nodemap[i] for i in t[edge_mask]]
492452
w = isnothing(w) ? nothing : w[edge_mask]
493-
nf = isnothing(g.nf) ? nothing : g.nf[:,node_mask]
494-
ef = isnothing(g.ef) ? nothing : g.ef[:,edge_mask]
495-
gf = isnothing(g.gf) ? nothing : g.gf[:,i]
453+
ndata = getobs(g.ndata, node_mask)
454+
edata = getobs(g.ndata, edge_mask)
455+
gdata = getobs(g.gdata, i)
496456

497457
num_nodes = length(graph_indicator)
498458
num_edges = length(s)
@@ -501,10 +461,43 @@ function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector)
501461
gnew = GNNGraph((s,t,w),
502462
num_nodes, num_edges, num_graphs,
503463
graph_indicator,
504-
nf, ef, gf)
464+
ndata, edata, gdata)
505465
return gnew, nodes
506466
end
507467

468+
### TO DEPRECATE ?? ###
469+
function node_features(g::GNNGraph)
470+
if isempty(g.ndata)
471+
return nothing
472+
elseif length(g.ndata) > 1
473+
@error "Multiple feature arrays, access directly with g.ndata.X"
474+
else
475+
return g.ndata[1]
476+
end
477+
end
478+
479+
function edge_features(g::GNNGraph)
480+
if isempty(g.edata)
481+
return nothing
482+
elseif length(g.edata) > 1
483+
@error "Multiple feature arrays, access directly with g.edata.E"
484+
else
485+
return g.edata[1]
486+
end
487+
end
488+
489+
function global_features(g::GNNGraph)
490+
if isempty(g.gdata)
491+
return nothing
492+
elseif length(g.gdata) > 1
493+
@error "Multiple feature arrays, access directly with g.gdata.U"
494+
else
495+
return g.gdata[1]
496+
end
497+
end
498+
#########
499+
500+
508501
@non_differentiable normalized_laplacian(x...)
509502
@non_differentiable normalized_adjacency(x...)
510503
@non_differentiable scaled_laplacian(x...)

src/layers/basic.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ See also [`GNNChain`](@ref).
77
"""
88
abstract type GNNLayer end
99

10+
#TODO extend to store also edge and global features
11+
(l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g)))
12+
1013
"""
1114
GNNChain(layers...)
1215
GNNChain(name = layer, ...)
@@ -39,7 +42,7 @@ julia> m(g, x)
3942
-0.0134364 -0.0120716 -0.0172505
4043
```
4144
"""
42-
struct GNNChain{T}
45+
struct GNNChain{T} <: GNNLayer
4346
layers::T
4447

4548
GNNChain(xs...) = new{typeof(xs)}(xs)

0 commit comments

Comments
 (0)