Skip to content

Commit 60985a3

Browse files
committed
Add GatedGraphConv docs
1 parent 866d20c commit 60985a3

File tree

1 file changed

+48
-1
lines changed

1 file changed

+48
-1
lines changed

GNNLux/src/layers/conv.jl

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1186,6 +1186,54 @@ function (l::SGConv)(g, x, edge_weight, ps, st)
11861186
return y, st
11871187
end
11881188

1189+
@doc raw"""
1190+
GatedGraphConv(out, num_layers;
1191+
aggr = +, init_weight = glorot_uniform)
1192+
1193+
Gated graph convolution layer from [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493).
1194+
1195+
Implements the recursion
1196+
```math
1197+
\begin{aligned}
1198+
\mathbf{h}^{(0)}_i &= [\mathbf{x}_i; \mathbf{0}] \\
1199+
\mathbf{h}^{(l)}_i &= GRU(\mathbf{h}^{(l-1)}_i, \square_{j \in N(i)} W \mathbf{h}^{(l-1)}_j)
1200+
\end{aligned}
1201+
```
1202+
1203+
where ``\mathbf{h}^{(l)}_i`` denotes the ``l``-th hidden variables passing through GRU. The dimension of input ``\mathbf{x}_i`` needs to be less or equal to `out`.
1204+
1205+
# Arguments
1206+
1207+
- `out`: The dimension of output features.
1208+
- `num_layers`: The number of recursion steps.
1209+
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
1210+
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
1211+
1212+
# Examples:
1213+
1214+
```julia
1215+
using GNNLux, Lux, Random
1216+
1217+
# initialize random number generator
1218+
rng = Random.default_rng()
1219+
1220+
# create data
1221+
s = [1,1,2,3]
1222+
t = [2,3,1,1]
1223+
out_channel = 5
1224+
num_layers = 3
1225+
g = GNNGraph(s, t)
1226+
1227+
# create layer
1228+
l = GatedGraphConv(out_channel, num_layers)
1229+
1230+
# setup layer
1231+
ps, st = LuxCore.setup(rng, l)
1232+
1233+
# forward pass
1234+
y, st = l(g, x, ps, st) # size: out_channel × num_nodes
1235+
```
1236+
"""
11891237
@concrete struct GatedGraphConv <: GNNLayer
11901238
gru
11911239
init_weight
@@ -1194,7 +1242,6 @@ end
11941242
aggr
11951243
end
11961244

1197-
11981245
function GatedGraphConv(dims::Int, num_layers::Int;
11991246
aggr = +, init_weight = glorot_uniform)
12001247
gru = GRUCell(dims => dims)

0 commit comments

Comments
 (0)