Skip to content

Commit 9944798

Browse files
committed
Add EGNNConv docstring
1 parent 5dc4cf6 commit 9944798

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

GNNLux/src/layers/conv.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,74 @@ function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
582582
return y, stnew
583583
end
584584

585+
@doc raw"""
586+
EGNNConv((in, ein) => out; hidden_size=2in, residual=false)
587+
EGNNConv(in => out; hidden_size=2in, residual=false)
588+
589+
Equivariant Graph Convolutional Layer from [E(n) Equivariant Graph
590+
Neural Networks](https://arxiv.org/abs/2102.09844).
591+
592+
The layer performs the following operation:
593+
594+
```math
595+
\begin{aligned}
596+
\mathbf{m}_{j\to i} &=\phi_e(\mathbf{h}_i, \mathbf{h}_j, \lVert\mathbf{x}_i-\mathbf{x}_j\rVert^2, \mathbf{e}_{j\to i}),\\
597+
\mathbf{x}_i' &= \mathbf{x}_i + C_i\sum_{j\in\mathcal{N}(i)}(\mathbf{x}_i-\mathbf{x}_j)\phi_x(\mathbf{m}_{j\to i}),\\
598+
\mathbf{m}_i &= C_i\sum_{j\in\mathcal{N}(i)} \mathbf{m}_{j\to i},\\
599+
\mathbf{h}_i' &= \mathbf{h}_i + \phi_h(\mathbf{h}_i, \mathbf{m}_i)
600+
\end{aligned}
601+
```
602+
where ``\mathbf{h}_i``, ``\mathbf{x}_i``, ``\mathbf{e}_{j\to i}`` are invariant node features, equivariant node
603+
features, and edge features respectively. ``\phi_e``, ``\phi_h``, and
604+
``\phi_x`` are two-layer MLPs. `C` is a constant for normalization,
605+
computed as ``1/|\mathcal{N}(i)|``.
606+
607+
608+
# Constructor Arguments
609+
610+
- `in`: Number of input features for `h`.
611+
- `out`: Number of output features for `h`.
612+
- `ein`: Number of input edge features.
613+
- `hidden_size`: Hidden representation size.
614+
- `residual`: If `true`, add a residual connection. Only possible if `in == out`. Default `false`.
615+
616+
# Forward Pass
617+
618+
l(g, x, h, e=nothing, ps, st)
619+
620+
## Forward Pass Arguments:
621+
622+
- `g` : The graph.
623+
- `x` : Matrix of equivariant node coordinates.
624+
- `h` : Matrix of invariant node features.
625+
- `e` : Matrix of invariant edge features. Default `nothing`.
626+
- `ps` : Parameters.
627+
- `st` : State.
628+
629+
Returns updated `h` and `x`.
630+
631+
# Examples
632+
633+
```julia
634+
using GNNLux, Lux, Random
635+
636+
# initialize random number generator
637+
rng = Random.default_rng()
638+
639+
# create random graph
640+
g = rand_graph(rng, 10, 10)
641+
h = randn(rng, Float32, 5, g.num_nodes)
642+
x = randn(rng, Float32, 3, g.num_nodes)
643+
644+
egnn = EGNNConv(5 => 6, 10)
645+
646+
# setup layer
647+
ps, st = LuxCore.setup(rng, egnn)
585648
649+
# forward pass
650+
(hnew, xnew), st = egnn(g, h, x, ps, st)
651+
```
652+
"""
586653
@concrete struct EGNNConv <: GNNContainerLayer{(:ϕe, :ϕx, :ϕh)}
587654
ϕe
588655
ϕx

0 commit comments

Comments
 (0)