Commit ae58c08
Docs GCNConv
1 parent ebd7929

1 file changed

GNNLux/src/layers/conv.jl
Lines changed: 76 additions & 3 deletions

@@ -5,7 +5,80 @@ _getstate(s::StatefulLuxLayer{Static.True}) = s.st
 _getstate(s::StatefulLuxLayer{false}) = s.st_any
 _getstate(s::StatefulLuxLayer{Static.False}) = s.st_any
 
-
+@doc raw"""
+    GCNConv(in => out, σ=identity; [init_weight, init_bias, use_bias, add_self_loops, use_edge_weight])
+
+Graph convolutional layer from the paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907).
+
+Performs the operation
+```math
+\mathbf{x}'_i = \sum_{j\in N(i)} a_{ij} W \mathbf{x}_j
+```
+where ``a_{ij} = 1 / \sqrt{|N(i)||N(j)|}`` is a normalization factor computed from the node degrees.
+
+If the input graph has weighted edges and `use_edge_weight=true`, then ``a_{ij}`` will be computed as
+```math
+a_{ij} = \frac{e_{j\to i}}{\sqrt{\sum_{j \in N(i)} e_{j\to i}} \sqrt{\sum_{i \in N(j)} e_{i\to j}}}
+```
+
+# Arguments
+
+- `in`: Number of input features.
+- `out`: Number of output features.
+- `σ`: Activation function. Default `identity`.
+- `init_weight`: Weights' initializer. Default `glorot_uniform`.
+- `init_bias`: Bias initializer. Default `zeros32`.
+- `use_bias`: Add learnable bias. Default `true`.
+- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
+- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
+  If `add_self_loops=true`, the new weights will be set to 1.
+  This option is ignored if `edge_weight` is explicitly provided in the forward pass.
+  Default `false`.
+
+# Forward
+
+    (::GCNConv)(g, x, [edge_weight], ps, st; norm_fn = d -> 1 ./ sqrt.(d), conv_weight = nothing)
+
+Takes as input a graph `g`, a node feature matrix `x` of size `[in, num_nodes]`, optionally an edge weight vector, and the parameters and state of the layer. Returns a node feature matrix of size `[out, num_nodes]`.
+
+The `norm_fn` keyword argument allows for custom normalization of the graph convolution operation by passing a function as argument.
+By default, it computes ``\frac{1}{\sqrt{d}}``, i.e. the inverse square root of the degree `d` of each node in the graph.
+If `conv_weight` is an `AbstractMatrix` of size `[out, in]`, then the convolution is performed using that weight matrix.
+
+# Examples
+
+```julia
+using GNNLux, Lux, Random
+
+# initialize random number generator
+rng = Random.default_rng()
+Random.seed!(rng, 0)
+
+# create data
+s = [1, 1, 2, 3]
+t = [2, 3, 1, 1]
+g = GNNGraph(s, t)
+x = randn(rng, Float32, 3, g.num_nodes)
+
+# create layer
+l = GCNConv(3 => 5)
+
+# setup layer
+ps, st = LuxCore.setup(rng, l)
+
+# forward pass
+y = l(g, x, ps, st)  # the first entry of the output has size 5 × num_nodes
+
+# convolution with edge weights and a custom normalization function
+w = [1.1, 0.1, 2.3, 0.5]
+custom_norm_fn(d) = 1 ./ sqrt.(d .+ 1)  # custom normalization function
+y = l(g, x, w, ps, st; norm_fn = custom_norm_fn)
+
+# Edge weights can also be embedded in the graph.
+g = GNNGraph(s, t, w)
+l = GCNConv(3 => 5, use_edge_weight = true)
+y = l(g, x, ps, st)  # same as l(g, x, w, ps, st)
+```
+"""
 @concrete struct GCNConv <: GNNLayer
     in_dims::Int
     out_dims::Int
@@ -18,7 +91,7 @@ _getstate(s::StatefulLuxLayer{Static.False}) = s.st_any
 end
 
 function GCNConv(ch::Pair{Int, Int}, σ = identity;
-                init_weight = glorot_uniform,
+                 init_weight = glorot_uniform,
                  init_bias = zeros32,
                  use_bias::Bool = true,
                  add_self_loops::Bool = true,
@@ -55,7 +128,7 @@ end
 
 function (l::GCNConv)(g, x, edge_weight, ps, st;
                       norm_fn = d -> 1 ./ sqrt.(d),
-                      conv_weight=nothing, )
+                      conv_weight=nothing)
 
     m = (; ps.weight, bias = _getbias(ps),
          l.add_self_loops, l.use_edge_weight, l.σ)
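
For orientation, here is a minimal sketch, not part of the commit, of the unweighted update rule documented above, ``\mathbf{x}'_i = \sum_{j\in N(i)} a_{ij} W \mathbf{x}_j`` with ``a_{ij} = 1/\sqrt{|N(i)||N(j)|}``, written as a dense matrix product in plain Julia. All names below (`A`, `Anorm`, `X`, `W`, `Y`) are illustrative and independent of the GNNLux API.

```julia
using LinearAlgebra

# Adjacency structure of the docstring's example graph.
s = [1, 1, 2, 3]                 # edge sources
t = [2, 3, 1, 1]                 # edge targets
n = 3                            # number of nodes
A = zeros(Float32, n, n)
for (j, i) in zip(s, t)
    A[i, j] = 1                  # A[i, j] = 1 iff there is an edge j → i
end

d = vec(sum(A, dims = 2))        # node degrees |N(i)|
Dinv = Diagonal(1 ./ sqrt.(d))   # the default norm_fn, as a diagonal matrix
Anorm = Dinv * A * Dinv          # Anorm[i, j] = a_ij = 1 / sqrt(|N(i)||N(j)|)

X = randn(Float32, 3, n)         # [in, num_nodes] node features
W = randn(Float32, 5, 3)         # [out, in] weights
Y = W * X * Anorm'               # Y[:, i] = Σ_j a_ij * W * X[:, j], no bias/σ
```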

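The weighted normalization used when `use_edge_weight=true` can be sketched the same way. The snippet below follows the docstring formula literally: each coefficient divides the edge weight ``e_{j\to i}`` by the square roots of the two denominator sums, which are the weighted in-degrees of `i` and of `j` (they coincide with the weighted degree on symmetric graphs). Again plain Julia with illustrative names only.

```julia
# One coefficient per edge j → i, per the formula in the docstring.
s = [1, 1, 2, 3]                 # edge sources (j)
t = [2, 3, 1, 1]                 # edge targets (i)
w = [1.1, 0.1, 2.3, 0.5]         # edge weights e_{j→i}, as in the example

n = 3
din = zeros(n)                   # weighted in-degree: Σ e over incoming edges
for (j, i, e) in zip(s, t, w)
    din[i] += e
end

a = [e / (sqrt(din[i]) * sqrt(din[j])) for (j, i, e) in zip(s, t, w)]
```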
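Lastly, the `conv_weight` keyword documented in the forward section can be exercised as follows, continuing the docstring example (`l`, `g`, `x`, `ps`, `st` as defined there); the matrix `W` is an illustrative name, not part of the API.

```julia
# Pass an external [out, in] matrix via `conv_weight`; per the docs, the
# convolution then uses this matrix in place of the weights stored in `ps`.
W = randn(Float32, 5, 3)
y, st = l(g, x, ps, st; conv_weight = W)  # Lux convention: (output, state)
```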