Commit 831d41b

Add GMMConv docs

1 parent 6b0bf24 commit 831d41b

File tree

1 file changed
+57 -0 lines changed

GNNLux/src/layers/conv.jl

Lines changed: 57 additions & 0 deletions
@@ -1339,6 +1339,63 @@ function Base.show(io::IO, l::GINConv)
    print(io, ")")
end

@doc raw"""
    GMMConv((in, ein) => out, σ=identity; K = 1, residual = false, init_weight = glorot_uniform, init_bias = zeros32, use_bias = true)

Graph mixture model convolution layer from the paper [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402).
Performs the operation
```math
\mathbf{x}_i' = \mathbf{x}_i + \frac{1}{|N(i)|} \sum_{j\in N(i)}\frac{1}{K}\sum_{k=1}^K \mathbf{w}_k(\mathbf{e}_{j\to i}) \odot \Theta_k \mathbf{x}_j
```
where ``w^a_{k}(e^a)`` for feature `a` and kernel `k` is given by
```math
w^a_{k}(e^a) = \exp(-\frac{1}{2}(e^a - \mu^a_k)^T (\Sigma^{-1})^a_k(e^a - \mu^a_k))
```
``\Theta_k, \mu^a_k, (\Sigma^{-1})^a_k`` are learnable parameters.
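
Each ``w^a_k`` above is a one-dimensional Gaussian in the pseudo-coordinate feature ``e^a``. A minimal sketch of that computation for a single edge (the helper `kernel_weights` and the names `e`, `mu`, `sigma_inv` are hypothetical illustrations, not the layer's internals; the layer learns its means and inverse covariances itself):

```julia
# Hedged sketch: w[a, k] = exp(-1/2 * sigma_inv[a, k] * (e[a] - mu[a, k])^2)
# for `ein` pseudo-coordinate features and K kernels.
kernel_weights(e, mu, sigma_inv) = exp.(-0.5f0 .* sigma_inv .* (e .- mu) .^ 2)

ein, K = 10, 8
e = randn(Float32, ein)               # pseudo-coordinates of one edge
mu = randn(Float32, ein, K)           # kernel means
sigma_inv = rand(Float32, ein, K)     # nonnegative inverse variances
w = kernel_weights(e, mu, sigma_inv)  # size (ein, K)
```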
The input to the layer is a node feature array `x` of size `(in, num_nodes)` and an
edge pseudo-coordinate array `e` of size `(ein, num_edges)`.
The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size is the same
as the input size.
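
For instance (a hedged illustration with arbitrary sizes), choosing `out == in` so that the residual is actually applied:

```julia
# hypothetical sizes: out == in == 4, so the residual x_i is applied
l = GMMConv((4, 10) => 4, residual = true)
```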
# Arguments

- `in`: Number of input node features.
- `ein`: Number of input edge features.
- `out`: Number of output features.
- `σ`: Activation function. Default `identity`.
- `K`: Number of kernels. Default `1`.
- `residual`: Residual connection. Default `false`.
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
- `init_bias`: Bias initializer. Default `zeros32`.
- `use_bias`: Add learnable bias. Default `true`.

# Examples

```julia
using GNNLux, Lux, Random

# initialize random number generator
rng = Random.default_rng()

# create data
s = [1,1,2,3]
t = [2,3,1,1]
g = GNNGraph(s,t)
nin, ein, out, K = 4, 10, 7, 8
x = randn(rng, Float32, nin, g.num_nodes)
e = randn(rng, Float32, ein, g.num_edges)

# create layer
l = GMMConv((nin, ein) => out, K=K)

# setup layer
ps, st = LuxCore.setup(rng, l)

# forward pass
y, st = l(g, x, e, ps, st) # size: out × num_nodes
```
"""
@concrete struct GMMConv <: GNNLayer
    σ
    ch::Pair{NTuple{2, Int}, Int}
