
Commit b4a8675

Merge pull request #15 from CarloLucibello/cl/gnnchain
add GNNLayer, GNNChain, GNNGraph
2 parents 275b39f + f40e050 commit b4a8675

31 files changed: +858 / -690 lines

Project.toml

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,8 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -24,6 +26,7 @@ DataStructures = "0.18"
 Flux = "0.12"
 KrylovKit = "0.5"
 LightGraphs = "1.3"
+MacroTools = "0.5"
 NNlib = "0.7"
 NNlibCUDA = "0.1"
 julia = "1.6"

docs/Project.toml

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 [deps]
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

docs/make.jl

Lines changed: 8 additions & 5 deletions
@@ -1,4 +1,4 @@
-using GraphNeuralNetworks
+using Flux, NNlib, GraphNeuralNetworks
 using Documenter

 DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup, :(using GraphNeuralNetworks); recursive=true)
@@ -7,14 +7,17 @@ makedocs(;
     modules=[GraphNeuralNetworks],
     sitename = "GraphNeuralNetworks.jl",
     pages = ["Home" => "index.md",
-             "Graphs" => "graphs.md",
-             "Message passing" => "messagepassing.md",
-             "Building models" => "models.md",
+             "GNNGraph" => "gnngraph.md",
+             "Message Passing" => "messagepassing.md",
+             "Model Building" => "models.md",
              "API Reference" =>
                 [
-                 "Graphs" => "api/graphs.md",
+                 "GNNGraph" => "api/gnngraph.md",
+                 "Basic Layers" => "api/basic.md",
                  "Convolutional Layers" => "api/conv.md",
                  "Pooling Layers" => "api/pool.md",
+                 "Message Passing" => "api/messagepassing.md",
+                 "NNlib" => "api/nnlib.md",
                 ],
              "Developer Notes" => "dev.md",
             ],

docs/src/api/basic.md

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+# Basic Layers
+
+```@index
+Order = [:type, :function]
+Pages = ["api/basics.md"]
+```
+
+```@autodocs
+Modules = [GraphNeuralNetworks]
+Pages = ["layers/basic.jl"]
+Private = false
+```

docs/src/api/graphs.md renamed to docs/src/api/gnngraph.md

Lines changed: 1 addition & 1 deletion
@@ -7,6 +7,6 @@ Pages = ["api/graphs.md"]

 ```@autodocs
 Modules = [GraphNeuralNetworks]
-Pages = ["featuredgraph.jl"]
+Pages = ["gnngraph.jl"]
 Private = false
 ```

docs/src/api/messagepassing.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# Message Passing
+
+```@docs
+GraphNeuralNetworks.message
+GraphNeuralNetworks.update
+GraphNeuralNetworks.propagate
+```

docs/src/api/nnlib.md

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# NNlib
+
+Primitive functions implemented in NNlib.jl.
+
+```@docs
+NNlib.gather!
+NNlib.gather
+NNlib.scatter!
+NNlib.scatter
+```
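The page added above only lists the docstrings, but gather/scatter are the low-level building blocks of message passing: gather pulls per-edge copies of node features, scatter reduces per-edge values back onto nodes. Below is a minimal sketch of what they compute, assuming the standard matrix/vector methods of `NNlib.gather` and `NNlib.scatter`; the feature matrix and index vectors are made up for illustration and are not part of this commit.

```julia
using NNlib

# Node features: one column per node (3 features, 4 nodes).
x = Float32[1 2 3 4;
            10 20 30 40;
            100 200 300 400]

# Source node of each of 5 edges.
src_ids = [1, 2, 2, 3, 4]

# gather: pick the column of `x` for every edge, giving a 3×5 matrix
# of per-edge "messages" (here simply the source-node features).
msgs = NNlib.gather(x, src_ids)

# Destination node of each edge.
dst_ids = [2, 1, 3, 3, 1]

# scatter: sum all messages arriving at the same destination node,
# producing a 3×maximum(dst_ids) matrix of aggregated messages.
agg = NNlib.scatter(+, msgs, dst_ids)
```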
File renamed without changes.

docs/src/models.md

Lines changed: 87 additions & 0 deletions
@@ -1 +1,88 @@
 # Models
+
+GraphNeuralNetworks.jl provides common graph convolutional layers with which you can assemble arbitrarily deep or complex models. GNN layers are compatible with
+Flux.jl layers, therefore experienced Flux users should immediately be able to define and train
+their models.
+
+In what follows, we discuss two different styles for model creation:
+the *explicit modeling* style, more verbose but more flexible,
+and the *implicit modeling* style based on [`GNNChain`](@ref), more concise but less flexible.
+
+## Explicit modeling
+
+In the explicit modeling style, the model is created according to the following steps:
+
+1. Define a new type for your model (`GNN` in the example below). Layers and submodels are fields.
+2. Apply `Flux.@functor` to the new type to make it Flux compatible (parameter collection, gpu movement, etc.).
+3. Optionally define a convenience constructor for your model.
+4. Define the forward pass by implementing the call method for your type.
+5. Instantiate the model.
+
+Here is an example of this construction:
+```julia
+using Flux, LightGraphs, GraphNeuralNetworks
+using Flux: @functor
+
+struct GNN                                # step 1
+    conv1
+    bn
+    conv2
+    dropout
+    dense
+end
+
+@functor GNN                              # step 2
+
+function GNN(din::Int, d::Int, dout::Int) # step 3
+    GNN(GCNConv(din => d),
+        BatchNorm(d),
+        GraphConv(d => d, relu),
+        Dropout(0.5),
+        Dense(d, dout))
+end
+
+function (model::GNN)(g::GNNGraph, x)     # step 4
+    x = model.conv1(g, x)
+    x = relu.(model.bn(x))
+    x = model.conv2(g, x)
+    x = model.dropout(x)
+    x = model.dense(x)
+    return x
+end
+
+din, d, dout = 3, 4, 2
+g = GNNGraph(random_regular_graph(10, 4))
+X = randn(Float32, din, 10)
+model = GNN(din, d, dout)                 # step 5
+y = model(g, X)
+```
+
+## Implicit modeling with GNNChains
+
+While very flexible, the way we defined the `GNN` model in the last section is a bit verbose.
+To simplify things, we provide the [`GNNChain`](@ref) type. It is very similar
+to Flux's well-known `Chain`: it composes layers sequentially, propagating
+the output of each layer to the next one. In addition, `GNNChain`
+also propagates the input graph, providing it as the first argument
+to layers subtyping the [`GNNLayer`](@ref) abstract type.
+
+Using `GNNChain`, the previous example becomes
+
+```julia
+using Flux, LightGraphs, GraphNeuralNetworks
+
+din, d, dout = 3, 4, 2
+g = GNNGraph(random_regular_graph(10, 4))
+X = randn(Float32, din, 10)
+
+model = GNNChain(GCNConv(din => d),
+                 BatchNorm(d),
+                 x -> relu.(x),
+                 GraphConv(d => d, relu),
+                 Dropout(0.5),
+                 Dense(d, dout))
+
+y = model(g, X)
+```
+
+`GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when edge features are also updated, have to be handled with an explicit definition of the forward pass.
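As a rough illustration of that last point, here is a hedged sketch of an explicit model that threads edge features through the forward pass alongside node features. It only reuses calls that appear earlier on this page (`GCNConv`, `Dense`, `@functor`, `GNNGraph`); the model name and constructor are hypothetical, and the per-edge update is a plain MLP applied column-wise to the edge-feature matrix, since a richer update that reads the graph's edge endpoints would need APIs not shown in this diff.

```julia
using Flux, LightGraphs, GraphNeuralNetworks
using Flux: @functor

# Hypothetical model: updates node features with a graph convolution
# and edge features with an independent per-edge MLP.
struct NodeEdgeGNN
    conv      # node update, graph-aware
    edge_mlp  # edge update, applied column-wise to the edge-feature matrix
end

@functor NodeEdgeGNN

NodeEdgeGNN(din_n, din_e, d) =
    NodeEdgeGNN(GCNConv(din_n => d, relu), Dense(din_e, d, relu))

# Explicit forward pass: returns both updated node and edge features.
function (m::NodeEdgeGNN)(g::GNNGraph, x, e)
    x = m.conv(g, x)      # node features, one column per node
    e = m.edge_mlp(e)     # edge features, one column per edge
    return x, e
end

g = GNNGraph(random_regular_graph(10, 4))   # random graph with 10 nodes
x = randn(Float32, 3, 10)                    # node features
e = randn(Float32, 2, g.num_edges)           # edge features
model = NodeEdgeGNN(3, 2, 4)
xnew, enew = model(g, x, e)
```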

examples/cora.jl

Lines changed: 17 additions & 32 deletions
@@ -9,30 +9,8 @@ using Statistics, Random
 using CUDA
 CUDA.allowscalar(false)

-struct GNN
-    conv1
-    conv2
-    dense
-end
-
-@functor GNN
-
-function GNN(; nin, nhidden, nout)
-    GNN(GCNConv(nin => nhidden, relu),
-        GCNConv(nhidden => nhidden, relu),
-        Dense(nhidden, nout))
-end
-
-function (net::GNN)(fg, x)
-    x = net.conv1(fg, x)
-    x = dropout(x, 0.5)
-    x = net.conv2(fg, x)
-    x = net.dense(x)
-    return x
-end
-
-function eval_loss_accuracy(X, y, ids, model, fg)
-    ŷ = model(fg, X)
+function eval_loss_accuracy(X, y, ids, model, g)
+    ŷ = model(g, X)
     l = logitcrossentropy(ŷ[:,ids], y[:,ids])
     acc = mean(onecold(ŷ[:,ids] |> cpu) .== onecold(y[:,ids] |> cpu))
     return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
@@ -63,34 +41,41 @@ function train(; kws...)
         @info "Training on CPU"
     end

+    # LOAD DATA
     data = Cora.dataset()
-    fg = FeaturedGraph(data.adjacency_list) |> device
+    g = GNNGraph(data.adjacency_list) |> device
     X = data.node_features |> device
     y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
     train_ids = data.train_indices |> device
     val_ids = data.val_indices |> device
     test_ids = data.test_indices |> device
     ytrain = y[:,train_ids]

-    model = GNN(nin=size(X,1),
-                nhidden=args.nhidden,
-                nout=data.num_classes) |> device
+    nin, nhidden, nout = size(X,1), args.nhidden, data.num_classes
+
+    ## DEFINE MODEL
+    model = GNNChain(GCNConv(nin => nhidden, relu),
+                     Dropout(0.5),
+                     GCNConv(nhidden => nhidden, relu),
+                     Dense(nhidden, nout)) |> device
+
     ps = Flux.params(model)
     opt = ADAM(args.η)

-    @info "NUM NODES: $(fg.num_nodes) NUM EDGES: $(fg.num_edges)"
+    @info "NUM NODES: $(g.num_nodes) NUM EDGES: $(g.num_edges)"

+    ## LOGGING FUNCTION
     function report(epoch)
-        train = eval_loss_accuracy(X, y, train_ids, model, fg)
-        test = eval_loss_accuracy(X, y, test_ids, model, fg)
+        train = eval_loss_accuracy(X, y, train_ids, model, g)
+        test = eval_loss_accuracy(X, y, test_ids, model, g)
         println("Epoch: $epoch Train: $(train) Test: $(test)")
     end

     ## TRAINING
     report(0)
     for epoch in 1:args.epochs
         gs = Flux.gradient(ps) do
-            ŷ = model(fg, X)
+            ŷ = model(g, X)
             logitcrossentropy(ŷ[:,train_ids], ytrain)
         end
