@@ -18,4 +18,48 @@ where ``\phi`` is expressed by the [`compute_message`](@ref) function,
18
18
`` \gamma_x `` and `` \gamma_e `` by [ ` update_node ` ] ( @ref ) and [ ` update_edge ` ] ( @ref )
19
19
respectively.
20
20
21
- See [ ` GraphConv ` ] ( ref ) and [ ` GATConv ` ] ( ref ) 's implementations as usage examples.
21
+ ## An example: implementing the GCNConv
22
+
23
+ Let's (re-)implement the [ ` GCNConv ` ] ( @ref ) layer use the message passing framework.
24
+ The convolution reads
25
+ ``` math
26
+
27
+ ```math
28
+ \mathbf{x}'_i = \sum_{j\in N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
29
+ ```
30
+ where `` c_{ij} = \sqrt{N(i)\,N(j)} `` . We will also add a bias and an activation function.
31
+
32
+ ``` julia
33
+ using Flux
34
+ using GraphNeuralNetworks
35
+ import GraphNeuralNetworks: compute_message, update_node, propagate
36
+
37
+ struct GCNConv{A<: AbstractMatrix , B, F} <: GNNLayer
38
+ weight:: A
39
+ bias:: B
40
+ σ:: F
41
+ end
42
+
43
+ Flux. @functor GCNConv
44
+
45
+ function GCNConv (ch:: Pair{Int,Int} , σ= identity; init= glorot_uniform)
46
+ in, out = ch
47
+ W = init (out, in)
48
+ b = zeros (Float32, out)
49
+ GCNConv (W, b, σ)
50
+ end
51
+
52
+ compute_message (l:: GCNConv , xi, xj, eij) = l. weight * xj
53
+ update_node (l:: GCNConv , m, x) = m
54
+
55
+ function (l:: GCNConv )(g:: GNNGraph , x:: AbstractMatrix{T} ) where T
56
+ g = add_self_loops (g)
57
+ c = 1 ./ sqrt .(degree (g, T, dir= :in ))
58
+ x = x .* c'
59
+ x, _ = propagate (l, g, + , x)
60
+ x = x .* c'
61
+ return l. σ .(x .+ l. bias)
62
+ end
63
+ ```
64
+
65
+
0 commit comments