Commit b7bdaf4 (parent: 77ff798)

models docs

6 files changed: +64 −8 lines

Project.toml

Lines changed: 1 addition & 0 deletions
````diff
@@ -11,6 +11,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
````

docs/make.jl

Lines changed: 1 addition & 0 deletions
````diff
@@ -16,6 +16,7 @@ makedocs(;
             "Basic Layers" => "api/basic.md",
             "Convolutional Layers" => "api/conv.md",
             "Pooling Layers" => "api/pool.md",
+            "NNlib" => "api/nnlib.md",
         ],
         "Developer Notes" => "dev.md",
     ],
````

docs/src/api/nnlib.md

Lines changed: 10 additions & 0 deletions
````diff
@@ -0,0 +1,10 @@
+# NNlib
+
+Primitive functions implemented in NNlib.jl.
+
+```@docs
+NNlib.gather!
+NNlib.gather
+NNlib.scatter!
+NNlib.scatter
+```
````
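For context on the newly documented primitives: a minimal sketch of `gather`/`scatter` semantics, assuming NNlib's index-the-last-dimension convention (illustrative only, not part of the commit):

```julia
using NNlib

src = Float32[1 2 3 4 5;
              6 7 8 9 10]

# gather: select source slices along the last dimension.
NNlib.gather(src, [1, 3, 3])            # 2×3 matrix made of columns 1, 3, 3

# scatter: the dual operation; source columns sharing a destination
# index are reduced with the given operator (here +).
NNlib.scatter(+, src, [2, 1, 1, 2, 2])  # 2×2; column 1 = src[:,2] + src[:,3]
```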

docs/src/messagepassing.md

Lines changed: 5 additions & 0 deletions
````diff
@@ -1,2 +1,7 @@
 # Message Passing
 
+```@docs
+message
+update
+propagate
+```
````
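These three functions are the customization hooks of the message-passing scheme: `propagate` drives the computation, while layers overload `message` and `update`. A hypothetical sketch, following the `message(l, x_i, x_j, e_ij)` and `update(l, m, x)` signatures visible in the `src/layers/conv.jl` diff below (`MyConv` and its field are invented for illustration):

```julia
using Flux, GraphNeuralNetworks
import GraphNeuralNetworks: message, update

struct MyConv <: GNNLayer
    weight
end

Flux.@functor MyConv

# message sent from node j to node i (e_ij: edge features, unused here)
message(l::MyConv, x_i, x_j, e_ij) = l.weight * x_j

# combine the aggregated message m with the previous node state x
update(l::MyConv, m, x) = m .+ x
```

The layer's forward pass would then invoke `propagate` with an aggregation operator (e.g. `+`) to run these two hooks over the whole graph.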

docs/src/models.md

Lines changed: 44 additions & 8 deletions
````diff
@@ -1,30 +1,47 @@
 # Models
 
+GraphNeuralNetworks.jl provides common graph convolutional layers with which you can
+assemble arbitrarily deep and complex models. GNN layers are compatible with Flux.jl
+layers, so experienced Flux users should be able to define and train their models right away.
+
+In what follows, we discuss two styles of model creation:
+the *explicit modeling* style, more verbose but more flexible,
+and the *implicit modeling* style based on [`GNNChain`](@ref), more concise but less flexible.
+
 ## Explicit modeling
 
+In the explicit modeling style, the model is created according to the following steps:
+
+1. Define a new type for your model (`GNN` in the example below). Layers and submodels are fields.
+2. Apply `Flux.@functor` to the new type to make it Flux-compatible (parameter collection, gpu movement, etc.).
+3. Optionally define a convenience constructor for your model.
+4. Define the forward pass by implementing the call method for your type.
+5. Instantiate the model.
+
+Here is an example of this construction:
 ```julia
-using Flux, GraphNeuralNetworks
+using Flux, LightGraphs, GraphNeuralNetworks
 using Flux: @functor
 
-struct GNN
+struct GNN # step 1
     conv1
     bn
     conv2
     dropout
     dense
 end
 
-@functor GNN
+@functor GNN # step 2
 
-function GNN(din, d, dout)
+function GNN(din::Int, d::Int, dout::Int) # step 3
     GNN(GCNConv(din => d),
         BatchNorm(d),
         GraphConv(d => d, relu),
         Dropout(0.5),
         Dense(d, dout))
 end
 
-function (model::GNN)(g::GNNGraph, x)
+function (model::GNN)(g::GNNGraph, x) # step 4
     x = model.conv1(g, x)
     x = relu.(model.bn(x))
     x = model.conv2(g, x)
````
````diff
@@ -34,8 +51,8 @@ function (model::GNN)(g::GNNGraph, x)
 end
 
 din, d, dout = 3, 4, 2
-g = GNNGraph(random_regular_graph(10, 4), graph_type=GRAPH_T)
+g = GNNGraph(random_regular_graph(10, 4))
 X = randn(Float32, din, 10)
-model = GNN(din, d, dout)
+model = GNN(din, d, dout) # step 5
 y = model(g, X)
 ```
````
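Since the model assembled above is an ordinary Flux-compatible object, it can be trained with a standard Flux loop. A minimal sketch, not part of the commit, reusing `model`, `g`, `X`, and `dout` from the example and assuming made-up targets `y_target` and the implicit-parameter Flux API of this era:

```julia
using Flux

y_target = randn(Float32, dout, 10)          # hypothetical targets
loss() = Flux.Losses.mse(model(g, X), y_target)

ps = Flux.params(model)                      # works thanks to @functor (step 2)
opt = ADAM(1f-3)
for epoch in 1:100
    gs = gradient(() -> loss(), ps)
    Flux.Optimise.update!(opt, ps, gs)
end
```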

````diff
@@ -43,10 +60,29 @@
-## Compact modeling with GNNChains
+## Implicit modeling with GNNChains
+
+While very flexible, the model definition of the last section is somewhat verbose.
+To simplify things, we provide the [`GNNChain`](@ref) type. It is closely
+modeled on Flux's well-known `Chain`: it composes layers sequentially,
+propagating the output of each layer to the next one. In addition, `GNNChain`
+propagates the input graph as well, passing it as the first argument
+to layers subtyping the [`GNNLayer`](@ref) abstract type.
+
+Using `GNNChain`, the previous example becomes
 
 ```julia
+using Flux, LightGraphs, GraphNeuralNetworks
+
+din, d, dout = 3, 4, 2
+g = GNNGraph(random_regular_graph(10, 4))
+X = randn(Float32, din, 10)
+
 model = GNNChain(GCNConv(din => d),
                  BatchNorm(d),
                  x -> relu.(x),
                  GraphConv(d => d, relu),
                  Dropout(0.5),
                  Dense(d, dout))
+
+y = model(g, X)
 ```
+
+A `GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when edge features are also updated, have to be handled with an explicitly defined forward pass.
````
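As a closing illustration of that last point, an explicitly defined forward pass can return updated edge features alongside node features. A hypothetical sketch (the `EdgeGNN` type and its fields are invented for this example):

```julia
using Flux, GraphNeuralNetworks

struct EdgeGNN
    node_conv   # e.g. a GNNLayer such as GCNConv
    edge_mlp    # e.g. a plain Dense operating on edge features
end

Flux.@functor EdgeGNN

function (model::EdgeGNN)(g::GNNGraph, x, e)
    x = model.node_conv(g, x)   # update node features
    e = model.edge_mlp(e)       # update edge features
    return x, e                 # two outputs: not expressible as a GNNChain
end
```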

src/layers/conv.jl

Lines changed: 3 additions & 0 deletions
````diff
@@ -334,6 +334,9 @@ end
 message(l::GatedGraphConv, x_i, x_j, e_ij) = x_j
 update(l::GatedGraphConv, m, x) = m
 
+# remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
+@non_differentiable fill!(x...)
+
 function (ggc::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
     check_num_nodes(g, H)
     m, n = size(H)
````
