@@ -1544,6 +1544,68 @@ function Base.show(io::IO, l::MEGNetConv)
15441544 print (io, " )" )
15451545end
15461546
1547+ @doc raw """
1548+ NNConv(in => out, f, σ=identity; aggr=+, init_bias = zeros32, use_bias = true, init_weight = glorot_uniform)
1549+
1550+ The continuous kernel-based convolutional operator from the
1551+ [Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) paper.
1552+ This convolution is also known as the edge-conditioned convolution from the
1553+ [Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) paper.
1554+
1555+ Performs the operation
1556+
1557+ ```math
1558+ \m athbf{x}_i' = W \m athbf{x}_i + \s quare_{j \i n N(i)} f_\T heta(\m athbf{e}_{j\t o i})\,\m athbf{x}_j
1559+ ```
1560+
1561+ where ``f_\T heta`` denotes a learnable function (e.g. a linear layer or a multi-layer perceptron).
1562+ Given an input of batched edge features `e` of size `(num_edge_features, num_edges)`,
1563+ the function `f` will return an batched matrices array whose size is `(out, in, num_edges)`.
1564+ For convenience, also functions returning a single `(out*in, num_edges)` matrix are allowed.
1565+
1566+ # Arguments
1567+
1568+ - `in`: The dimension of input node features.
1569+ - `out`: The dimension of output node features.
1570+ - `f`: A (possibly learnable) function acting on edge features.
1571+ - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
1572+ - `σ`: Activation function.
1573+ - `init_weight`: Weights' initializer. Default `glorot_uniform`.
1574+ - `init_bias`: Bias initializer. Default `zeros32`.
1575+ - `use_bias`: Add learnable bias. Default `true`.
1576+
1577+ # Examples:
1578+
1579+ ```julia
1580+ using GNNLux, Lux, Random
1581+
1582+ # initialize random number generator
1583+ rng = Random.default_rng()
1584+
1585+ # create data
1586+ n_in = 3
1587+ n_in_edge = 10
1588+ n_out = 5
1589+
1590+ s = [1,1,2,3]
1591+ t = [2,3,1,1]
1592+ g = GNNGraph(s, t)
1593+ x = randn(rng, Float32, n_in, g.num_nodes)
1594+ e = randn(rng, Float32, n_in_edge, g.num_edges)
1595+
1596+ # create dense layer
1597+ nn = Dense(n_in_edge => n_out * n_in)
1598+
1599+ # create layer
1600+ l = NNConv(n_in => n_out, nn, tanh, use_bias = true, aggr = +)
1601+
1602+ # setup layer
1603+ ps, st = LuxCore.setup(rng, l)
1604+
1605+ # forward pass
1606+ y, st = l(g, x, e, ps, st) # size: n_out × num_nodes
1607+ ```
1608+ """
15471609@concrete struct NNConv <: GNNContainerLayer{(:nn,)}
15481610 nn <: AbstractLuxLayer
15491611 aggr
0 commit comments