@@ -582,7 +582,74 @@ function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
582582 return y, stnew
583583end
584584
585+ @doc raw """
586+ EGNNConv((in, ein) => out; hidden_size=2in, residual=false)
587+ EGNNConv(in => out; hidden_size=2in, residual=false)
588+
589+ Equivariant Graph Convolutional Layer from [E(n) Equivariant Graph
590+ Neural Networks](https://arxiv.org/abs/2102.09844).
591+
592+ The layer performs the following operation:
593+
594+ ```math
595+ \begin{aligned}
596+ \mathbf{m}_{j\to i} &= \phi_e(\mathbf{h}_i, \mathbf{h}_j, \lVert\mathbf{x}_i-\mathbf{x}_j\rVert^2, \mathbf{e}_{j\to i}),\\
597+ \mathbf{x}_i' &= \mathbf{x}_i + C_i\sum_{j\in\mathcal{N}(i)}(\mathbf{x}_i-\mathbf{x}_j)\phi_x(\mathbf{m}_{j\to i}),\\
598+ \mathbf{m}_i &= C_i\sum_{j\in\mathcal{N}(i)} \mathbf{m}_{j\to i},\\
599+ \mathbf{h}_i' &= \mathbf{h}_i + \phi_h(\mathbf{h}_i, \mathbf{m}_i)
600+ \end{aligned}
601+ ```
602+ where ``\mathbf{h}_i``, ``\mathbf{x}_i``, and ``\mathbf{e}_{j\to i}`` are the invariant node features, the equivariant node
603+ features (coordinates), and the edge features, respectively. ``\phi_e``, ``\phi_h``, and
604+ ``\phi_x`` are two-layer MLPs. ``C_i`` is a normalization constant,
605+ computed as ``1/|\mathcal{N}(i)|``.
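
As a purely illustrative sketch of the equations above (not the layer's internal implementation), the update for a single node `i` can be written in plain Julia; here `ϕe`, `ϕx`, and `ϕh` stand in for the layer's MLPs, `neighbors` is assumed to hold the in-neighbors of `i`, and edge features are omitted for brevity:

```julia
# Illustrative only: per-node form of the update equations, not the layer's code.
# `h` is a (num_features × num_nodes) matrix of invariant features,
# `x` is a (dims × num_nodes) matrix of equivariant coordinates.
function egnn_update_node(i, neighbors, h, x, ϕe, ϕx, ϕh)
    C = 1 / length(neighbors)                              # Cᵢ = 1 / |N(i)|
    msgs = [ϕe(vcat(h[:, i], h[:, j], sum(abs2, x[:, i] - x[:, j]))) for j in neighbors]
    Δx = sum((x[:, i] - x[:, j]) .* ϕx(m) for (j, m) in zip(neighbors, msgs))
    x′ = x[:, i] + C * Δx                                  # equivariant coordinate update
    m = C * sum(msgs)                                      # aggregated message mᵢ
    h′ = h[:, i] + ϕh(vcat(h[:, i], m))                    # invariant feature update
    return h′, x′
end
```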
606+
607+
608+ # Constructor Arguments
609+
610+ - `in`: Number of input features for `h`.
611+ - `out`: Number of output features for `h`.
612+ - `ein`: Number of input edge features.
613+ - `hidden_size`: Hidden representation size.
614+ - `residual`: If `true`, add a residual connection. Only possible if `in == out`. Default `false`.
615+
616+ # Forward Pass
617+
618+     l(g, h, x, e=nothing, ps, st)
619+
620+ ## Forward Pass Arguments:
621+
622+ - `g` : The graph.
623+ - `h` : Matrix of invariant node features.
624+ - `x` : Matrix of equivariant node coordinates.
625+ - `e` : Matrix of invariant edge features. Default `nothing`.
626+ - `ps` : Parameters.
627+ - `st` : State.
628+
629+ Returns the updated `h` and `x`, along with the updated layer state.
630+
631+ # Examples
632+
633+ ```julia
634+ using GNNLux, Lux, Random
635+
636+ # initialize random number generator
637+ rng = Random.default_rng()
638+
639+ # create random graph
640+ g = rand_graph(rng, 10, 10)
641+ h = randn(rng, Float32, 5, g.num_nodes)
642+ x = randn(rng, Float32, 3, g.num_nodes)
643+
644+ egnn = EGNNConv(5 => 6, 10)
645+
646+ # setup layer
647+ ps, st = LuxCore.setup(rng, egnn)
585648
649+ # forward pass
650+ (hnew, xnew), st = egnn(g, h, x, ps, st)
651+ ```
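
The example above uses the `in => out` form without edge features. With the `(in, ein) => out` form, the forward pass documented above also takes an edge-feature matrix `e`. A hypothetical continuation of the example (reusing `g`, `h`, `x`, and `rng` from above, and assuming 2 edge features per edge):

```julia
# hypothetical: same graph and node features as above, now with edge features
e = randn(rng, Float32, 2, g.num_edges)

egnn_e = EGNNConv((5, 2) => 6)
ps, st = LuxCore.setup(rng, egnn_e)
(hnew, xnew), st = egnn_e(g, h, x, e, ps, st)
```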
652+ """
586653@concrete struct EGNNConv <: GNNContainerLayer{(:ϕe, :ϕx, :ϕh)}
587654 ϕe
588655 ϕx