Commit edd95b2

docs
1 parent 011c956 commit edd95b2


docs/src/messagepassing.md

Lines changed: 45 additions & 1 deletion
@@ -18,4 +18,48 @@ where ``\phi`` is expressed by the [`compute_message`](@ref) function,
where ``\phi`` is expressed by the [`compute_message`](@ref) function,
``\gamma_x`` and ``\gamma_e`` by [`update_node`](@ref) and [`update_edge`](@ref)
respectively.

## An example: implementing the GCNConv

Let's (re-)implement the [`GCNConv`](@ref) layer using the message passing framework.
The convolution reads

```math
\mathbf{x}'_i = \sum_{j\in N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
```

where ``c_{ij} = \sqrt{|N(i)|\,|N(j)|}`` and ``N(i)`` denotes the neighborhood of node ``i``. We will also add a bias and an activation function.
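As an illustrative numeric check (the numbers are chosen for this example, not taken from the original text): for an edge where node ``i`` has four neighbors and node ``j`` has one, the message from ``j`` to ``i`` is scaled by

```math
\frac{1}{c_{ij}} = \frac{1}{\sqrt{|N(i)|\,|N(j)|}} = \frac{1}{\sqrt{4 \cdot 1}} = \frac{1}{2}
```

so messages from highly connected pairs are damped relative to those from sparsely connected ones.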
```julia
using Flux
using GraphNeuralNetworks
import GraphNeuralNetworks: compute_message, update_node, propagate

struct GCNConv{A<:AbstractMatrix, B, F} <: GNNLayer
    weight::A
    bias::B
    σ::F
end

Flux.@functor GCNConv

function GCNConv(ch::Pair{Int,Int}, σ=identity; init=glorot_uniform)
    in, out = ch
    W = init(out, in)
    b = zeros(Float32, out)
    GCNConv(W, b, σ)
end

# Message from node j to node i: the linearly transformed features of j.
compute_message(l::GCNConv, xi, xj, eij) = l.weight * xj

# The new node state is just the aggregated message.
update_node(l::GCNConv, m, x) = m

function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
    g = add_self_loops(g)
    c = 1 ./ sqrt.(degree(g, T, dir=:in))  # cᵢ = 1/√|N(i)|
    x = x .* c'                            # pre-scale by 1/√|N(j)|
    x, _ = propagate(l, g, +, x)           # sum incoming messages
    x = x .* c'                            # post-scale by 1/√|N(i)|
    return l.σ.(x .+ l.bias)
end
```
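To close the loop, here is a minimal usage sketch. It is not part of the original commit: it assumes `rand_graph` from GraphNeuralNetworks.jl for building a test graph, and that the `GCNConv` defined above is the one in scope (in a fresh session it would clash with the library's exported layer of the same name).

```julia
using Flux, GraphNeuralNetworks

# A random graph with 10 nodes, 30 edges, and 3 features per node.
g = rand_graph(10, 30)
x = randn(Float32, 3, 10)

# 3 input features mapped to 16 output features, with relu activation.
l = GCNConv(3 => 16, relu)
y = l(g, x)

@assert size(y) == (16, 10)  # one 16-dimensional output per node
```

Note that the layer is applied as `l(g, x)`: the graph is passed alongside the feature matrix, following the calling convention defined above.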
