
Commit 56e5eb6

Add NNConv
1 parent 9ed8049 commit 56e5eb6

File tree

1 file changed: +62 -0 lines changed

GNNLux/src/layers/conv.jl

Lines changed: 62 additions & 0 deletions
@@ -1544,6 +1544,68 @@ function Base.show(io::IO, l::MEGNetConv)
    print(io, ")")
end

@doc raw"""
    NNConv(in => out, f, σ = identity; aggr = +, init_bias = zeros32, use_bias = true, init_weight = glorot_uniform)

The continuous kernel-based convolutional operator from the
[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) paper.
This convolution is also known as the edge-conditioned convolution from the
[Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) paper.

Performs the operation

```math
\mathbf{x}_i' = W \mathbf{x}_i + \square_{j \in N(i)} f_\Theta(\mathbf{e}_{j\to i})\,\mathbf{x}_j
```

where ``f_\Theta`` denotes a learnable function (e.g. a linear layer or a multi-layer perceptron).
Given an input of batched edge features `e` of size `(num_edge_features, num_edges)`,
the function `f` should return a batched array of matrices of size `(out, in, num_edges)`.
For convenience, functions returning a single `(out * in, num_edges)` matrix are also allowed.

# Arguments

- `in`: The dimension of input node features.
- `out`: The dimension of output node features.
- `f`: A (possibly learnable) function acting on edge features.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `σ`: Activation function.
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
- `init_bias`: Bias initializer. Default `zeros32`.
- `use_bias`: Add learnable bias. Default `true`.

# Examples

```julia
using GNNLux, Lux, Random

# initialize random number generator
rng = Random.default_rng()

# create data
n_in = 3
n_in_edge = 10
n_out = 5

s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
g = GNNGraph(s, t)
x = randn(rng, Float32, n_in, g.num_nodes)
e = randn(rng, Float32, n_in_edge, g.num_edges)

# create the edge network: maps edge features to a flattened (n_out * n_in) kernel per edge
nn = Dense(n_in_edge => n_out * n_in)

# create layer
l = NNConv(n_in => n_out, nn, tanh, use_bias = true, aggr = +)

# setup layer
ps, st = LuxCore.setup(rng, l)

# forward pass
y, st = l(g, x, e, ps, st)    # size: n_out × num_nodes
```
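
The edge network `f` can be any Lux layer whose output matches one of the accepted
shapes. As a sketch (the hidden width of 16 and the names `nn_mlp`, `l_mlp` are
arbitrary choices for illustration), a small MLP works in place of the single
`Dense` layer above:

```julia
# a two-layer MLP as the edge network; it still returns the flattened
# (n_out * n_in, num_edges) shape that NNConv accepts
nn_mlp = Chain(Dense(n_in_edge => 16, relu), Dense(16 => n_out * n_in))
l_mlp = NNConv(n_in => n_out, nn_mlp, tanh, aggr = +)
ps, st = LuxCore.setup(rng, l_mlp)
y, st = l_mlp(g, x, e, ps, st)    # same output size: n_out × num_nodes
```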
"""
@concrete struct NNConv <: GNNContainerLayer{(:nn,)}
    nn <: AbstractLuxLayer
    aggr
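
For intuition, here is a minimal sketch (not GNNLux's actual internals) of how the
per-edge kernels produced by `f` could be applied to the source-node features, assuming
the flattened `(out * in, num_edges)` convention; the helper name
`edge_conditioned_messages` is hypothetical:

```julia
using NNlib: batched_mul

# W_e: (n_out * n_in, num_edges) output of the edge network f
# xj:  (n_in, num_edges) features of the source node of each edge
function edge_conditioned_messages(W_e, xj, n_out, n_in)
    num_edges = size(W_e, 2)
    W3 = reshape(W_e, n_out, n_in, num_edges)  # one (n_out × n_in) kernel per edge
    x3 = reshape(xj, n_in, 1, num_edges)       # one column vector per edge
    m = batched_mul(W3, x3)                    # (n_out, 1, num_edges) per-edge messages
    return reshape(m, n_out, num_edges)        # to be combined with `aggr`
end
```

The aggregated messages are then added to ``W \mathbf{x}_i`` and passed through `σ`,
as in the formula in the docstring above.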

0 commit comments
