
Commit ff91294

Merge pull request #18 from CarloLucibello/cl/tudataset
add graph classification example
2 parents cb515a4 + e4dafbe commit ff91294

File tree: 11 files changed, +250 −24 lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"

README.md

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@ Some of its noticeable features are the following:
* CUDA support.
* Integrated with the JuliaGraphs ecosystem.
* Supports generic graph neural network architectures.
-* Easy to define custom graph convolutional layers.
+* Operation on batched graphs.
+* Easily define your custom graph convolutional layers.

## Installation

examples/Project.toml

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
# An example of graph classification

using Flux
using Flux: @functor, dropout, onecold, onehotbatch, getindex
using Flux.Losses: logitbinarycrossentropy
using Flux.Data: DataLoader
using GraphNeuralNetworks
using MLDatasets: TUDataset
using Statistics, Random
using CUDA
CUDA.allowscalar(false)

function eval_loss_accuracy(model, data_loader, device)
    loss = 0.
    acc = 0.
    ntot = 0
    for (g, X, y) in data_loader
        g, X, y = g |> device, X |> device, y |> device
        n = length(y)
        ŷ = model(g, X) |> vec
        loss += logitbinarycrossentropy(ŷ, y) * n
        acc += mean((2 .* ŷ .- 1) .* (2 .* y .- 1) .> 0) * n
        ntot += n
    end
    return (loss = round(loss/ntot, digits=4), acc = round(acc*100/ntot, digits=2))
end

struct GNNData
    g
    X
    y
end

Base.getindex(data::GNNData, i::Int) = getindex(data, [i])

function Base.getindex(data::GNNData, i::AbstractVector)
    sg, nodemap = subgraph(data.g, i)
    return (sg, data.X[:,nodemap], data.y[i])
end

# Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683
Flux.Data._nobs(data::GNNData) = data.g.num_graphs
Flux.Data._getobs(data::GNNData, i) = data[i]

function process_dataset(data)
    g = GNNGraph(data.source, data.target, num_nodes=data.num_nodes, graph_indicator=data.graph_indicator)
    X = Array{Float32}(onehotbatch(data.node_labels, 0:6))
    # The dataset also has edge features but we won't be using them
    # E = Array{Float32}(onehotbatch(data.edge_labels, sort(unique(data.edge_labels))))
    y = (1 .+ Array{Float32}(data.graph_labels)) ./ 2
    @assert all(∈([0,1]), y) # binary classification
    return GNNData(g, X, y)
end

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3             # learning rate
    batchsize = 64       # batch size (number of graphs in each batch)
    epochs = 200         # number of epochs
    seed = 17            # set seed > 0 for reproducibility
    usecuda = true       # if true use cuda (if available)
    nhidden = 128        # dimension of hidden features
    infotime = 10        # report every `infotime` epochs
end

function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)

    if args.usecuda && CUDA.functional()
        device = gpu
        args.seed > 0 && CUDA.seed!(args.seed)
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    # LOAD DATA

    NUM_TRAIN = 150
    full_data = TUDataset("MUTAG")

    @info "MUTAG DATASET
        num_nodes: $(full_data.num_nodes)
        num_edges: $(full_data.num_edges)
        num_graphs: $(full_data.num_graphs)"

    perm = randperm(full_data.num_graphs)
    dtrain = process_dataset(full_data[perm[1:NUM_TRAIN]])
    dtest = process_dataset(full_data[perm[NUM_TRAIN+1:end]])
    train_loader = DataLoader(dtrain, batchsize=args.batchsize, shuffle=true)
    test_loader = DataLoader(dtest, batchsize=args.batchsize, shuffle=false)

    # DEFINE MODEL

    nin = size(dtrain.X, 1)
    nhidden = args.nhidden

    model = GNNChain(GraphConv(nin => nhidden, relu),
                     Dropout(0.5),
                     GraphConv(nhidden => nhidden, relu),
                     GlobalPool(mean),
                     Dense(nhidden, 1)) |> device

    ps = Flux.params(model)
    opt = ADAM(args.η)


    # LOGGING FUNCTION

    function report(epoch)
        train = eval_loss_accuracy(model, train_loader, device)
        test = eval_loss_accuracy(model, test_loader, device)
        println("Epoch: $epoch   Train: $(train)   Test: $(test)")
    end

    # TRAIN

    report(0)
    for epoch in 1:args.epochs
        for (g, X, y) in train_loader
            g, X, y = g |> device, X |> device, y |> device
            gs = Flux.gradient(ps) do
                ŷ = model(g, X) |> vec
                logitbinarycrossentropy(ŷ, y)
            end
            Flux.Optimise.update!(opt, ps, gs)
        end

        epoch % args.infotime == 0 && report(epoch)
    end
end

train()
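The `train` function accepts keyword arguments that map onto the fields of `Args`, so a run can be customized without editing the file. A hypothetical invocation (not part of the committed example):

```julia
# Override the defaults defined in `Args`; any field name can be passed as a keyword.
train(usecuda = false, epochs = 50, batchsize = 32, infotime = 5)
```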

examples/cora.jl renamed to examples/node_classification_cora.jl

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ Base.@kwdef mutable struct Args
    η = 1f-3             # learning rate
    epochs = 100         # number of epochs
    seed = 17            # set seed > 0 for reproducibility
-   use_cuda = true      # if true use cuda (if available)
+   usecuda = true       # if true use cuda (if available)
    nhidden = 128        # dimension of hidden features
    infotime = 10        # report every `infotime` epochs
end
@@ -33,7 +33,7 @@ function train(; kws...)
        CUDA.seed!(args.seed)
    end

-   if args.use_cuda && CUDA.functional()
+   if args.usecuda && CUDA.functional()
        device = gpu
        @info "Training on GPU"
    else

src/GraphNeuralNetworks.jl

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@ export
    edge_index,
    node_feature, edge_feature, global_feature,
    adjacency_list, normalized_laplacian, scaled_laplacian,
-   add_self_loops,
+   add_self_loops, remove_self_loops,
+   subgraph,

    # from LightGraphs
    adjacency_matrix,

src/gnngraph.jl

Lines changed: 71 additions & 10 deletions
@@ -11,7 +11,7 @@ const ADJMAT_T = AbstractMatrix
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T

"""
-    GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, num_graphs, graph_indicator, dir])
+    GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, graph_indicator, dir])
    GNNGraph(g::GNNGraph; [nf, ef, gf])

A type representing a graph structure and storing also arrays
@@ -23,6 +23,11 @@ is governed by `graph_type`.
When constructed from another graph `g`, the internal graph representation
is preserved and shared.

+A `GNNGraph` can also represent multiple graphs batched together
+(see [`Flux.batch`](@ref) or [`SparseArrays.blockdiag`](@ref)).
+The field `g.graph_indicator` contains the graph membership
+of each node.
+
A `GNNGraph` is a LightGraphs' `AbstractGraph`, therefore any functionality
from the LightGraphs' graph library can be used on it.

@@ -45,7 +50,6 @@ from the LightGraphs' graph library can be used on it.
- `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
  Possible values are `:out` and `:in`. Default `:out`.
- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
-- `num_graphs`. The number of graphs. Larger than 1 in case of batched graphs. Default `1`.
- `graph_indicator`. For batched graphs, a vector containing the graph assignment of each node. Default `nothing`.
- `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing`.
- `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing`.
@@ -118,17 +122,17 @@ function GNNGraph(data;

    @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
    @assert dir ∈ [:in, :out]
+
    if graph_type == :coo
        g, num_nodes, num_edges = to_coo(data; num_nodes, dir)
    elseif graph_type == :dense
        g, num_nodes, num_edges = to_dense(data; dir)
    elseif graph_type == :sparse
        g, num_nodes, num_edges = to_sparse(data; dir)
    end
-    if num_graphs > 1
-        @assert len(graph_indicator) = num_nodes "When batching multiple graphs `graph_indicator` should be filled with the nodes' memberships."
-    end
-
+
+    num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1
+
    ## Possible future implementation of feature maps.
    ## Currently this doesn't play well with zygote due to
    ## https://github.com/FluxML/Zygote.jl/issues/717
@@ -149,8 +153,8 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)

function GNNGraph(g::AbstractGraph; kws...)
    s = LightGraphs.src.(LightGraphs.edges(g))
-    t = LightGraphs.dst.(LightGraphs.edges(g))
-    GNNGraph((s, t); kws...)
+    t = LightGraphs.dst.(LightGraphs.edges(g))
+    GNNGraph((s, t); num_nodes = nv(g), kws...)
end

function GNNGraph(g::GNNGraph;
@@ -431,19 +435,76 @@ function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T})
    )
end

-# Cat public interfaces
+### Cat public interfaces #############
+
+"""
+    blockdiag(xs::GNNGraph...)
+
+Batch together multiple `GNNGraph`s into a single one
+containing the total number of nodes and edges of the original graphs.
+
+Equivalent to [`Flux.batch`](@ref).
+"""
function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
-    @assert length(gothers) >= 1
    g = g1
    for go in gothers
        g = _catgraphs(g, go)
    end
    return g
end

+"""
+    batch(xs::Vector{<:GNNGraph})
+
+Batch together multiple `GNNGraph`s into a single one
+containing the total number of nodes and edges of the original graphs.
+
+Equivalent to [`SparseArrays.blockdiag`](@ref).
+"""
Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
#########################

+"""
+    subgraph(g::GNNGraph, i)
+
+Return the subgraph of `g` induced by those nodes `v`
+for which `g.graph_indicator[v] ∈ i`. In other words, it
+extracts the component graphs from a batched graph.
+
+It also returns a vector `nodes` mapping the new nodes to the old ones.
+The node `i` in the subgraph corresponds to the node `nodes[i]` in `g`.
+"""
+subgraph(g::GNNGraph, i::Int) = subgraph(g::GNNGraph{<:COO_T}, [i])
+
+function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector)
+    node_mask = g.graph_indicator .∈ Ref(i)
+
+    nodes = (1:g.num_nodes)[node_mask]
+    nodemap = Dict(v => vnew for (vnew, v) in enumerate(nodes))
+
+    graphmap = Dict(i => inew for (inew, i) in enumerate(i))
+    graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]]
+
+    s, t, w = g.graph
+    edge_mask = s .∈ Ref(nodes)
+    s = [nodemap[i] for i in s[edge_mask]]
+    t = [nodemap[i] for i in t[edge_mask]]
+    w = isnothing(w) ? nothing : w[edge_mask]
+    nf = isnothing(g.nf) ? nothing : g.nf[:,node_mask]
+    ef = isnothing(g.ef) ? nothing : g.ef[:,edge_mask]
+    gf = isnothing(g.gf) ? nothing : g.gf[:,i]
+
+    num_nodes = length(graph_indicator)
+    num_edges = length(s)
+    num_graphs = length(i)
+
+    gnew = GNNGraph((s,t,w),
+                    num_nodes, num_edges, num_graphs,
+                    graph_indicator,
+                    nf, ef, gf)
+    return gnew, nodes
+end
+
@non_differentiable normalized_laplacian(x...)
@non_differentiable normalized_adjacency(x...)
@non_differentiable scaled_laplacian(x...)
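The new `blockdiag` / `Flux.batch` / `subgraph` functions are what the graph classification example builds on. Below is a minimal sketch of how they compose, assuming the tuple constructor and keyword arguments documented above; the toy edge lists and feature sizes are invented for illustration:

```julia
using GraphNeuralNetworks, Flux

# Two toy graphs with 3 and 4 nodes and 2-dimensional node features.
g1 = GNNGraph(([1, 2, 3], [2, 3, 1]), nf = rand(Float32, 2, 3))
g2 = GNNGraph(([1, 2, 3, 4], [2, 3, 4, 1]), nf = rand(Float32, 2, 4))

# Batch them into one block-diagonal graph (equivalent to SparseArrays.blockdiag(g1, g2));
# per-node graph membership is recorded in gall.graph_indicator.
gall = Flux.batch([g1, g2])
gall.num_graphs               # 2

# Pull the second component back out; `nodes` maps the new node ids
# back to node ids in the batched graph.
sg, nodes = subgraph(gall, 2)
```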

src/layers/conv.jl

Lines changed: 6 additions & 5 deletions
@@ -147,8 +147,9 @@ Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-

Performs:
```math
-\mathbf{x}_i' = W^1 \mathbf{x}_i + \box_{j \in \mathcal{N}(i)} W^2 \mathbf{x}_j)
+\mathbf{x}_i' = W^1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W^2 \mathbf{x}_j)
```
+
where the aggregation type is selected by `aggr`.

# Arguments
@@ -206,7 +207,7 @@ end
            concat=true,
            init=glorot_uniform
            bias=true,
-           negative_slope=0.2)
+           negative_slope=0.2f0)

Graph attentional layer from the paper [Graph Attention Networks](https://arxiv.org/abs/1710.10903).

@@ -216,7 +217,7 @@ Implements the operation
```
where the attention coefficient ``\alpha_{ij}`` is given by
```math
-\alpha_{ij} = \frac{1}{z_i} exp(LeakyReLU(\mathbf{a}^T [W \mathbf{x}_i || W \mathbf{x}_j]))
+\alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W \mathbf{x}_i || W \mathbf{x}_j]))
```
with ``z_i`` a normalization factor.

@@ -301,7 +302,7 @@ Gated graph convolution layer from [Gated Graph Sequence Neural Networks](https:
Implements the recursion
```math
\mathbf{h}^{(0)}_i = \mathbf{x}_i || \mathbf{0} \\
-\mathbf{h}^{(l)}_i = GRU(\mathbf{h}^{(l-1)}_i, \box_{j \in N(i)} W \mathbf{h}^{(l-1)}_j)
+\mathbf{h}^{(l)}_i = GRU(\mathbf{h}^{(l-1)}_i, \square_{j \in N(i)} W \mathbf{h}^{(l-1)}_j)
```

where ``\mathbf{h}^{(l)}_i`` denotes the ``l``-th hidden variables passing through GRU. The dimension of input ``\mathbf{x}_i`` needs to be less or equal to `out`.
@@ -369,7 +370,7 @@ Edge convolutional layer from paper [Dynamic Graph CNN for Learning on Point Clo

Performs the operation
```math
-\mathbf{x}_i' = \box_{j \in N(i)} f(\mathbf{x}_i || \mathbf{x}_j - \mathbf{x}_i)
+\mathbf{x}_i' = \square_{j \in N(i)} f(\mathbf{x}_i || \mathbf{x}_j - \mathbf{x}_i)
```

where `f` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.
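As a quick illustration of the `GraphConv` formula above, here is a sketch that builds a small graph from LightGraphs and applies a single layer. The `layer(g, X)` call style is assumed to match how `GNNChain` applies its layers in the graph classification example; the graph and feature sizes are invented:

```julia
using GraphNeuralNetworks, Flux, LightGraphs

# Random graph with 10 nodes and 3-dimensional node features.
g = GNNGraph(erdos_renyi(10, 0.3), nf = rand(Float32, 3, 10))

# x_i' = W¹ x_i + □_{j ∈ N(i)} W² x_j, with the aggregation □ chosen by `aggr`.
layer = GraphConv(3 => 8, relu)
Y = layer(g, node_feature(g))   # assumed call style; 8 × 10 matrix of updated features
```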

src/layers/pool.jl

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ X = rand(32, 10)
pool(g, X) # => 32x1 matrix
```
"""
-struct GlobalPool{F}
+struct GlobalPool{F} <: GNNLayer
    aggr::F
end
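Subtyping `GNNLayer` is what lets `GlobalPool` be placed inside a `GNNChain`, as the graph classification example above does. A minimal sketch with illustrative layer sizes:

```julia
using Flux, GraphNeuralNetworks
using Statistics: mean

# GlobalPool(mean) reduces node features to one vector per graph,
# so the chain maps a (possibly batched) graph to one prediction per graph.
model = GNNChain(GraphConv(7 => 64, relu),
                 GraphConv(64 => 64, relu),
                 GlobalPool(mean),
                 Dense(64, 1))
```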