
Commit 24a22aa

Merge pull request #223 from CarloLucibello/cl/egnn
equivariant gnn
2 parents 377fcf1 + f228781 commit 24a22aa

5 files changed: +164 −5 lines


docs/src/api/messagepassing.md

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,8 @@ propagate
 copy_xi
 copy_xj
 xi_dot_xj
+xi_sub_xj
+xj_sub_xi
 e_mul_xj
 w_mul_xj
 ```
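
The two new message functions documented here compute coordinate differences along edges. A minimal usage sketch (not part of this commit; graph and feature sizes are arbitrary):

```julia
using GraphNeuralNetworks

g = rand_graph(4, 8)                    # 4 nodes, 8 directed edges
x = randn(Float32, 3, g.num_nodes)      # 3-dimensional node coordinates

# For each edge (j -> i), xi_sub_xj returns x_i - x_j and xj_sub_xi returns x_j - x_i.
dxi = apply_edges(xi_sub_xj, g, xi=x, xj=x)   # size (3, g.num_edges)
dxj = apply_edges(xj_sub_xi, g, xi=x, xj=x)
@assert dxi ≈ -dxj
```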

src/GraphNeuralNetworks.jl

Lines changed: 4 additions & 1 deletion
@@ -36,6 +36,8 @@ export
 copy_xj,
 copy_xi,
 xi_dot_xj,
+xi_sub_xj,
+xj_sub_xi,
 e_mul_xj,
 w_mul_xj,
 
@@ -50,17 +52,18 @@ export
 CGConv,
 ChebConv,
 EdgeConv,
+EGNNConv,
 GATConv,
 GATv2Conv,
 GatedGraphConv,
 GCNConv,
 GINConv,
+GMMConv,
 GraphConv,
 MEGNetConv,
 NNConv,
 ResGatedGraphConv,
 SAGEConv,
-GMMConv,
 SGConv,
 
 # layers/pool

src/layers/conv.jl

Lines changed: 129 additions & 0 deletions
@@ -1083,6 +1083,7 @@ The input to the layer is a node feature array `x` of size `(num_features, num_n
 edge pseudo-coordinate array `e` of size `(num_features, num_edges)`
 The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size is the same
 as the input size.
+
 # Arguments
 
 - `in`: Number of input node features.
@@ -1298,3 +1299,131 @@ function Base.show(io::IO, l::SGConv)
     l.k == 1 || print(io, ", ", l.k)
     print(io, ")")
 end
+
+
+@doc raw"""
+    EGNNConv((in, ein) => out, hidden_size)
+    EGNNConv(in => out, hidden_size=2*in)
+
+Equivariant Graph Convolutional Layer from [E(n) Equivariant Graph
+Neural Networks](https://arxiv.org/abs/2102.09844).
+
+The layer performs the following operation:
+
+```math
+\mathbf{m}_{j\to i}=\phi_e(\mathbf{h}_i, \mathbf{h}_j, \lVert\mathbf{x}_i-\mathbf{x}_j\rVert^2, \mathbf{e}_{j\to i}),\\
+\mathbf{x}_i' = \mathbf{x}_i + C_i\sum_{j\in\mathcal{N}(i)}(\mathbf{x}_i-\mathbf{x}_j)\phi_x(\mathbf{m}_{j\to i}),\\
+\mathbf{m}_i = C_i\sum_{j\in\mathcal{N}(i)} \mathbf{m}_{j\to i},\\
+\mathbf{h}_i' = \mathbf{h}_i + \phi_h(\mathbf{h}_i, \mathbf{m}_i)
+```
+where ``h_i``, ``x_i``, ``e_{ij}`` are the invariant node features, the equivariant node
+coordinates, and the edge features respectively. ``\phi_e``, ``\phi_h``, and
+``\phi_x`` are two-layer MLPs. ``C_i`` is a normalization constant,
+computed as ``1/|\mathcal{N}(i)|``.
+
+# Constructor Arguments
+
+- `in`: Number of input features for `h`.
+- `out`: Number of output features for `h`.
+- `ein`: Number of input edge features.
+- `hidden_size`: Hidden representation size.
+- `residual`: If `true`, add a residual connection. Only possible if `in == out`. Default `false`.
+
+# Forward Pass
+
+    l(g, h, x, e=nothing)
+
+## Forward Pass Arguments:
+
+- `g` : The graph.
+- `h` : Matrix of invariant node features.
+- `x` : Matrix of equivariant node coordinates.
+- `e` : Matrix of invariant edge features. Default `nothing`.
+
+Returns updated `h` and `x`.
+
+# Examples
+
+```julia
+g = rand_graph(10, 10)
+h = randn(Float32, 5, g.num_nodes)
+x = randn(Float32, 3, g.num_nodes)
+egnn = EGNNConv(5 => 6, 10)
+hnew, xnew = egnn(g, h, x)
+```
+"""
+struct EGNNConv <: GNNLayer
+    ϕe::Chain
+    ϕx::Chain
+    ϕh::Chain
+    num_features::NamedTuple
+    residual::Bool
+end
+
+@functor EGNNConv
+
+EGNNConv(ch::Pair{Int,Int}, hidden_size=2*ch[1]) = EGNNConv((ch[1], 0) => ch[2], hidden_size)
+
+# Follows the reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
+function EGNNConv(ch::Pair{NTuple{2, Int}, Int}, hidden_size::Int, residual=false)
+    (in_size, edge_feat_size), out_size = ch
+    act_fn = swish
+
+    # +1 for the radial feature: ||x_i - x_j||^2
+    ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn),
+               Dense(hidden_size => hidden_size, act_fn))
+
+    ϕh = Chain(Dense(in_size + hidden_size => hidden_size, act_fn),
+               Dense(hidden_size => out_size))
+
+    ϕx = Chain(Dense(hidden_size => hidden_size, act_fn),
+               Dense(hidden_size => 1, bias=false))
+
+    num_features = (in=in_size, edge=edge_feat_size, out=out_size)
+    if residual
+        @assert in_size == out_size "Residual connection only possible if in_size == out_size"
+    end
+    return EGNNConv(ϕe, ϕx, ϕh, num_features, residual)
+end
+
+function (l::EGNNConv)(g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e=nothing)
+    if l.num_features.edge > 0
+        @assert e !== nothing "Edge features must be provided."
+    end
+    @assert size(h, 1) == l.num_features.in "Input features must match layer input size."
+
+    function message(xi, xj, e)
+        if l.num_features.edge > 0
+            f = vcat(xi.h, xj.h, e.sqnorm_xdiff, e.e)
+        else
+            f = vcat(xi.h, xj.h, e.sqnorm_xdiff)
+        end
+
+        msg_h = l.ϕe(f)
+        msg_x = l.ϕx(msg_h) .* e.x_diff
+        return (; x=msg_x, h=msg_h)
+    end
+
+    x_diff = apply_edges(xi_sub_xj, g, x, x)
+    sqnorm_xdiff = sum(x_diff .^ 2, dims=1)
+    x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1f-6)
+
+    msg = apply_edges(message, g, xi=(; h), xj=(; h), e=(; e, x_diff, sqnorm_xdiff))
+    h_aggr = aggregate_neighbors(g, +, msg.h)
+    x_aggr = aggregate_neighbors(g, mean, msg.x)
+
+    hnew = l.ϕh(vcat(h, h_aggr))
+    if l.residual
+        h = h .+ hnew
+    else
+        h = hnew
+    end
+    x = x .+ x_aggr
+
+    return h, x
+end
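
A quick sanity check of the new layer (not part of the commit): rotating the input coordinates should rotate the output coordinates in the same way and leave the invariant features unchanged, up to floating-point error. A sketch using only the constructor and forward pass defined above:

```julia
using GraphNeuralNetworks

g = rand_graph(10, 30)
h = randn(Float32, 5, g.num_nodes)   # invariant node features
x = randn(Float32, 3, g.num_nodes)   # equivariant 3D coordinates
l = EGNNConv(5 => 5, 16)

# Rotation about the z-axis
θ = Float32(π / 3)
Q = Float32[cos(θ) -sin(θ) 0; sin(θ) cos(θ) 0; 0 0 1]

h1, x1 = l(g, h, x)
h2, x2 = l(g, h, Q * x)

@assert h2 ≈ h1       # features are invariant to the rotation
@assert x2 ≈ Q * x1   # coordinates are equivariant to the rotation
```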

src/msgpass.jl

Lines changed: 16 additions & 4 deletions
@@ -1,5 +1,6 @@
 """
-    propagate(f, g, aggr; xi, xj, e) -> m̄
+    propagate(f, g, aggr; [xi, xj, e]) -> m̄
+    propagate(f, g, aggr, xi, xj, e=nothing)
 
 Performs message passing on graph `g`. Takes care of materializing the node features on each edge,
 applying the message function, and returning an aggregated message ``\\bar{\\mathbf{m}}``
@@ -68,7 +69,7 @@ function propagate end
 propagate(l, g::GNNGraph, aggr; xi=nothing, xj=nothing, e=nothing) =
     propagate(l, g, aggr, xi, xj, e)
 
-function propagate(l, g::GNNGraph, aggr, xi, xj, e)
+function propagate(l, g::GNNGraph, aggr, xi, xj, e=nothing)
     m = apply_edges(l, g, xi, xj, e)
     m̄ = aggregate_neighbors(g, aggr, m)
     return m̄
@@ -77,8 +78,8 @@ end
 ## APPLY EDGES
 
 """
-    apply_edges(f, g, xi, xj, e)
     apply_edges(f, g; [xi, xj, e])
+    apply_edges(f, g, xi, xj, e=nothing)
 
 Returns the message from node `j` to node `i`.
 In the message-passing scheme, the incoming messages
@@ -110,7 +111,7 @@ function apply_edges end
 apply_edges(l, g::GNNGraph; xi=nothing, xj=nothing, e=nothing) =
     apply_edges(l, g, xi, xj, e)
 
-function apply_edges(f, g::GNNGraph, xi, xj, e)
+function apply_edges(f, g::GNNGraph, xi, xj, e=nothing)
     check_num_nodes(g, xi)
     check_num_nodes(g, xj)
     check_num_edges(g, e)
@@ -158,6 +159,17 @@ copy_xi(xi, xj, e) = xi
 """
 xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)
 
+"""
+    xi_sub_xj(xi, xj, e) = xi .- xj
+"""
+xi_sub_xj(xi, xj, e) = xi .- xj
+
+"""
+    xj_sub_xi(xi, xj, e) = xj .- xi
+"""
+xj_sub_xi(xi, xj, e) = xj .- xi
+
+
 """
     e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj
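
With the new positional methods, a trailing `e` can be omitted when calling `propagate` or `apply_edges` directly, which is what `EGNNConv` relies on with `apply_edges(xi_sub_xj, g, x, x)`. A small sketch of the two equivalent call styles (not part of the commit):

```julia
using GraphNeuralNetworks

g = rand_graph(6, 10)
x = rand(Float32, 2, g.num_nodes)

# Keyword form, as before
m1 = propagate(copy_xj, g, +; xj=x)

# New positional form; `xi` is unused by copy_xj and `e` defaults to `nothing`
m2 = propagate(copy_xj, g, +, nothing, x)

@assert m1 ≈ m2
```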

test/layers/conv.jl

Lines changed: 13 additions & 0 deletions
@@ -288,4 +288,17 @@
 end
 end
 end
+
+    @testset "EGNNConv" begin
+        hin = 5
+        hout = 5
+        hidden = 5
+        l = EGNNConv(hin => hout, hidden)
+        g = rand_graph(10, 20, graph_type=GRAPH_T)
+        x = rand(T, in_channel, g.num_nodes)
+        h = randn(T, hin, g.num_nodes)
+        hnew, xnew = l(g, h, x)
+        @test size(hnew) == (hout, g.num_nodes)
+        @test size(xnew) == (in_channel, g.num_nodes)
+    end
 end
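
The new test covers only the feature-only constructor. A hedged sketch of the edge-feature path, using the `(in, ein) => out` constructor documented in `src/layers/conv.jl` above (sizes are arbitrary, and this is not part of the commit's test suite):

```julia
using GraphNeuralNetworks

g = rand_graph(10, 20)
h = randn(Float32, 5, g.num_nodes)   # invariant node features (in = 5)
x = randn(Float32, 3, g.num_nodes)   # equivariant coordinates
e = randn(Float32, 4, g.num_edges)   # invariant edge features (ein = 4)

l = EGNNConv((5, 4) => 5, 16)        # (in, ein) => out, hidden_size
hnew, xnew = l(g, h, x, e)

@assert size(hnew) == (5, g.num_nodes)
@assert size(xnew) == (3, g.num_nodes)
```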
