@@ -1667,6 +1667,58 @@ function Base.show(io::IO, l::NNConv)
     print(io, ")")
 end
 
+@doc raw """
+    ResGatedGraphConv(in => out, act=identity; init_weight = glorot_uniform, init_bias = zeros32, use_bias = true)
+
+The residual gated graph convolutional operator from the [Residual Gated Graph ConvNets](https://arxiv.org/abs/1711.07553) paper.
+
+The layer's forward pass is given by
+
+```math
+\mathbf{x}_i' = act\big(U\mathbf{x}_i + \sum_{j \in N(i)} \eta_{ij} V \mathbf{x}_j\big),
+```
+where the edge gates ``\eta_{ij}`` are given by
+
+```math
+\eta_{ij} = sigmoid(A\mathbf{x}_i + B\mathbf{x}_j).
+```
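+
+A minimal sketch of this gating math on plain dense matrices (illustrative only; the `gated_update` helper is hypothetical, not the layer's internal implementation, and applies the gate elementwise):
+
+```julia
+# Per-node update for node i: x_i' = act.(U*x_i + Σ_j η_ij ⊙ V*x_j)
+σ(z) = 1 / (1 + exp(-z))
+
+function gated_update(U, V, A, B, x, neighbors, i, act)
+    # elementwise gates: η_ij = sigmoid.(A*x_i .+ B*x_j)
+    agg = sum(σ.(A * x[:, i] .+ B * x[:, j]) .* (V * x[:, j]) for j in neighbors)
+    return act.(U * x[:, i] .+ agg)
+end
+```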
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `act`: Activation function.
+- `init_weight`: Weights' initializer. Default `glorot_uniform`.
+- `init_bias`: Bias initializer. Default `zeros32`.
+- `use_bias`: Add learnable bias. Default `true`.
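+
+For instance, the initializers can be overridden at construction time (an illustrative call, using the Lux initializers named above):
+
+```julia
+l = ResGatedGraphConv(3 => 5, relu; init_weight = glorot_uniform, init_bias = zeros32)
+```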
+
+# Examples
+
+```julia
+using GNNLux, Lux, Random
+
+# initialize random number generator
+rng = Random.default_rng()
+
+# create data
+s = [1,1,2,3]
+t = [2,3,1,1]
+in_channel = 3
+out_channel = 5
+g = GNNGraph(s, t)
+x = randn(rng, Float32, in_channel, g.num_nodes)
+
+# create layer
+l = ResGatedGraphConv(in_channel => out_channel, tanh, use_bias = true)
+
+# setup layer
+ps, st = LuxCore.setup(rng, l)
+
+# forward pass
+y, st = l(g, x, ps, st) # size: out_channel × num_nodes
+```
+"""
 @concrete struct ResGatedGraphConv <: GNNLayer
     in_dims::Int
     out_dims::Int