@@ -1458,6 +1458,50 @@ function Base.show(io::IO, l::GMMConv)
    print(io, ")")
end

@doc raw"""
    MEGNetConv(ϕe, ϕv; aggr=mean)
    MEGNetConv(in => out; aggr=mean)

Convolution from the [Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals](https://arxiv.org/pdf/1812.05055.pdf)
paper. In the forward pass, it takes node features `x` and edge features `e` as input and returns
updated features `x'` and `e'` according to

```math
\begin{aligned}
\mathbf{e}_{i\to j}' = \phi_e([\mathbf{x}_i;\, \mathbf{x}_j;\, \mathbf{e}_{i\to j}]),\\
\mathbf{x}_{i}' = \phi_v([\mathbf{x}_i;\, \square_{j\in \mathcal{N}(i)}\,\mathbf{e}_{j\to i}']).
\end{aligned}
```

`aggr` defines the aggregation to be performed.

If the neural networks `ϕe` and `ϕv` are not provided, they will be constructed from
the `in` and `out` arguments as multi-layer perceptrons with one hidden layer and `relu`
activations.
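For illustration, the sketch below builds the two networks by hand for `nin == nout == 3`; the
input widths follow from the concatenations in the equations above, while the hidden widths
(set to `nout` here) and the variable names are assumptions for the example, not a description
of the default internals.

```julia
using GNNLux, Lux

nin, nout = 3, 3

# ϕe acts on the concatenation [xᵢ; xⱼ; e], so it takes 3 * nin input features
ϕe = Chain(Dense(3 * nin => nout, relu),
           Dense(nout => nout))

# ϕv acts on the concatenation [xᵢ; aggregated e′], so it takes nin + nout input features
ϕv = Chain(Dense(nin + nout => nout, relu),
           Dense(nout => nout))

m = MEGNetConv(ϕe, ϕv)    # aggr defaults to mean
```

Passing `ϕe` and `ϕv` explicitly in this way allows deeper networks or different activations
while keeping the same message-passing scheme.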

# Examples

```julia
using GNNLux, Lux, Random

# initialize random number generator
rng = Random.default_rng()

# create a random graph
g = rand_graph(rng, 10, 30)
x = randn(rng, Float32, 3, 10)
e = randn(rng, Float32, 3, 30)

# create a MEGNetConv layer
m = MEGNetConv(3 => 3)

# setup layer
ps, st = LuxCore.setup(rng, m)

# forward pass
(x′, e′), st = m(g, x, e, ps, st)
```
"""
@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)}
    in_dims::Int
    out_dims::Int