Skip to content

Commit b3d7462

Browse files
committed
ResGatedGraphConv docs
1 parent 56e5eb6 commit b3d7462

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

GNNLux/src/layers/conv.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,6 +1667,58 @@ function Base.show(io::IO, l::NNConv)
16671667
print(io, ")")
16681668
end
16691669

1670+
@doc raw"""
1671+
ResGatedGraphConv(in => out, act=identity; init_weight = glorot_uniform, init_bias = zeros32, use_bias = true)
1672+
1673+
The residual gated graph convolutional operator from the [Residual Gated Graph ConvNets](https://arxiv.org/abs/1711.07553) paper.
1674+
1675+
The layer's forward pass is given by
1676+
1677+
```math
1678+
\mathbf{x}_i' = act\big(U\mathbf{x}_i + \sum_{j \in N(i)} \eta_{ij} V \mathbf{x}_j\big),
1679+
```
1680+
where the edge gates ``\eta_{ij}`` are given by
1681+
1682+
```math
1683+
\eta_{ij} = sigmoid(A\mathbf{x}_i + B\mathbf{x}_j).
1684+
```
1685+
1686+
# Arguments
1687+
1688+
- `in`: The dimension of input features.
1689+
- `out`: The dimension of output features.
1690+
- `act`: Activation function.
1691+
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
1692+
- `init_bias`: Bias initializer. Default `zeros32`.
1693+
- `use_bias`: Add learnable bias. Default `true`.
1694+
1695+
1696+
# Examples:
1697+
1698+
```julia
1699+
using GNNLux, Lux, Random
1700+
1701+
# initialize random number generator
1702+
rng = Random.default_rng()
1703+
1704+
# create data
1705+
s = [1,1,2,3]
1706+
t = [2,3,1,1]
1707+
in_channel = 3
1708+
out_channel = 5
1709+
g = GNNGraph(s, t)
1710+
x = randn(rng, Float32, in_channel, g.num_nodes)
1711+
1712+
# create layer
1713+
l = ResGatedGraphConv(in_channel => out_channel, tanh, use_bias = true)
1714+
1715+
# setup layer
1716+
ps, st = LuxCore.setup(rng, l)
1717+
1718+
# forward pass
1719+
y, st = l(g, x, ps, st) # size: out_channel × num_nodes
1720+
```
1721+
"""
16701722
@concrete struct ResGatedGraphConv <: GNNLayer
16711723
in_dims::Int
16721724
out_dims::Int

0 commit comments

Comments
 (0)