GraphNeuralNetworks.jl provides common graph convolutional layers by which you can assemble arbitrarily deep or complex models. GNN layers are compatible with Flux.jl ones, therefore users already familiar with Flux should be immediately able to define and train their models.

In what follows, we discuss two different styles for model creation: the *explicit modeling* style, more verbose but more flexible, and the *implicit modeling* style based on [`GNNChain`](@ref), more concise but less flexible.

## Explicit modeling

In the explicit modeling style, the model is created according to the following steps:

1. Define a new type for your model (`GNN` in the example below). Layers and submodels are its fields.
2. Apply `Flux.@functor` to the new type to make it compatible with Flux (parameter collection, GPU movement, etc.).
3. Optionally define a convenience constructor for your model.
4. Define the forward pass by implementing the call method for your type.
5. Instantiate the model.

Here is an example of this construction:
```julia
using Flux, LightGraphs, GraphNeuralNetworks
using Flux: @functor

struct GNN                                # step 1
    conv1
    bn
    conv2
    dropout
    dense
end

@functor GNN                              # step 2

function GNN(din::Int, d::Int, dout::Int) # step 3
    GNN(GCNConv(din => d),
        BatchNorm(d),
        GraphConv(d => d, relu),
        Dropout(0.5),
        Dense(d, dout))
end

function (model::GNN)(g::GNNGraph, x)     # step 4
    x = model.conv1(g, x)
    x = relu.(model.bn(x))
    x = model.conv2(g, x)
    x = model.dropout(x)
    x = model.dense(x)
    return x
end

din, d, dout = 3, 4, 2
g = GNNGraph(random_regular_graph(10, 4))
X = randn(Float32, din, 10)
model = GNN(din, d, dout)                 # step 5
y = model(g, X)
```
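The introduction above also mentions training such a model. Continuing from the snippet, a single training step could look like the following sketch; the random regression `target`, the mean-squared-error loss, and the `ADAM` optimiser are illustrative assumptions rather than part of the example above, and Flux's implicit-parameter `params`/`gradient` API is used.

```julia
# Minimal training-step sketch, continuing from the snippet above.
target = randn(Float32, dout, 10)                  # hypothetical regression target, one column per node

loss(g, X, y) = Flux.Losses.mse(model(g, X), y)

ps  = Flux.params(model)                           # collect the trainable parameters
gs  = Flux.gradient(() -> loss(g, X, target), ps)  # gradients with respect to those parameters
opt = ADAM(0.001)
Flux.Optimise.update!(opt, ps, gs)                 # apply one optimisation step
```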
## Implicit modeling with GNNChains

While very flexible, the way in which we defined the `GNN` model in the last section is a bit verbose. In order to simplify things, we provide the [`GNNChain`](@ref) type. It is very similar to Flux's well-known `Chain`: it allows you to compose layers in a sequential fashion just as `Chain` does, propagating the output of each layer to the next one. In addition, `GNNChain` propagates the input graph as well, providing it as a first argument to layers subtyping the [`GNNLayer`](@ref) abstract type.

Using `GNNChain`, the previous example becomes
```julia
using Flux, LightGraphs, GraphNeuralNetworks

din, d, dout = 3, 4, 2
g = GNNGraph(random_regular_graph(10, 4))
X = randn(Float32, din, 10)

model = GNNChain(GCNConv(din => d),
                 BatchNorm(d),
                 x -> relu.(x),
                 GraphConv(d => d, relu),
                 Dropout(0.5),
                 Dense(d, dout))

y = model(g, X)
```
The `GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when edge features are also updated, have to be handled using an explicit definition of the forward pass.
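As an illustration of that point, here is a minimal sketch of such an explicit forward pass that threads edge features through the model. The `EdgeGNN` type, its fields, the edge feature dimension `de`, and the choice of updating edge features with an ordinary `Dense` layer are assumptions made for this example, not an API prescribed by the library.

```julia
using Flux, LightGraphs, GraphNeuralNetworks
using Flux: @functor

struct EdgeGNN
    conv      # graph layer updating node features, e.g. a GCNConv
    edge_mlp  # ordinary Flux layer acting on edge features
    dense
end

@functor EdgeGNN

function (model::EdgeGNN)(g::GNNGraph, x, e)
    x = model.conv(g, x)     # node features are updated using the graph
    e = model.edge_mlp(e)    # edge features are updated with a plain layer
    x = model.dense(x)
    return x, e              # return both updated feature sets
end

din, d, dout, de = 3, 4, 2, 5                 # de: assumed edge feature dimension
g = GNNGraph(random_regular_graph(10, 4))
X = randn(Float32, din, 10)
E = randn(Float32, de, g.num_edges)
edge_model = EdgeGNN(GCNConv(din => d), Dense(de, de, relu), Dense(d, dout))
y, Eout = edge_model(g, X, E)
```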