@@ -1083,6 +1083,7 @@ The input to the layer is a node feature array `x` of size `(num_features, num_n
1083
1083
edge pseudo-coordinate array `e` of size `(num_features, num_edges)`
1084
1084
The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size is the same
1085
1085
as the input size.
1086
+
1086
1087
# Arguments
1087
1088
1088
1089
- `in`: Number of input node features.
@@ -1298,3 +1299,131 @@ function Base.show(io::IO, l::SGConv)
1298
1299
l. k == 1 || print (io, " , " , l. k)
1299
1300
print (io, " )" )
1300
1301
end
1302
+
1303
+
1304
+
1305
@doc raw"""
    EGNNConv((in, ein) => out, hidden_size)
    EGNNConv(in => out, hidden_size=2*in)

Equivariant Graph Convolutional Layer from [E(n) Equivariant Graph
Neural Networks](https://arxiv.org/abs/2102.09844).

The layer performs the following operation:

```math
\mathbf{m}_{j\to i}=\phi_e(\mathbf{h}_i, \mathbf{h}_j, \lVert\mathbf{x}_i-\mathbf{x}_j\rVert^2, \mathbf{e}_{j\to i}),\\
\mathbf{x}_i' = \mathbf{x}_i + C_i\sum_{j\in\mathcal{N}(i)}(\mathbf{x}_i-\mathbf{x}_j)\phi_x(\mathbf{m}_{j\to i}),\\
\mathbf{m}_i = C_i\sum_{j\in\mathcal{N}(i)} \mathbf{m}_{j\to i},\\
\mathbf{h}_i' = \mathbf{h}_i + \phi_h(\mathbf{h}_i, \mathbf{m}_i)
```
where ``h_i``, ``x_i``, ``e_{ij}`` are invariant node features, equivariant node
features, and edge features respectively. ``\phi_e``, ``\phi_h``, and
``\phi_x`` are two-layer MLPs. ``C`` is a constant for normalization,
computed as ``1/|\mathcal{N}(i)|``.

# Constructor Arguments

- `in`: Number of input features for `h`.
- `out`: Number of output features for `h`.
- `ein`: Number of input edge features.
- `hidden_size`: Hidden representation size.
- `residual`: If `true`, add a residual connection. Only possible if `in == out`. Default `false`.

# Forward Pass

    l(g, h, x, e=nothing)

## Forward Pass Arguments:

- `g` : The graph.
- `h` : Matrix of invariant node features.
- `x` : Matrix of equivariant node coordinates.
- `e` : Matrix of invariant edge features. Default `nothing`.

Returns updated `h` and `x`.

# Examples

```julia
g = rand_graph(10, 10)
h = randn(Float32, 5, g.num_nodes)
x = randn(Float32, 3, g.num_nodes)
egnn = EGNNConv(5 => 6, 10)
hnew, xnew = egnn(g, h, x)
```
"""
struct EGNNConv <: GNNLayer
    ϕe::Chain               # edge MLP: builds the per-edge message m_{j→i}
    ϕx::Chain               # coordinate MLP: scalar weight for the coordinate update
    ϕh::Chain               # node MLP: updates the invariant features h
    num_features::NamedTuple  # (in = ..., edge = ..., out = ...) sizes, checked in the forward pass
    residual::Bool          # if true, output h is added to the input h (requires in == out)
end

@functor EGNNConv
1367
# Convenience constructor for the no-edge-feature case (ein = 0).
function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1])
    in_size, out_size = ch
    return EGNNConv((in_size, 0) => out_size, hidden_size)
end
1369
# Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
function EGNNConv(ch::Pair{NTuple{2, Int}, Int}, hidden_size::Int, residual = false)
    (in_size, edge_feat_size), out_size = ch
    # Validate before building any MLP: a residual connection requires
    # matching input/output sizes for h.
    if residual && in_size != out_size
        throw(ArgumentError("Residual connection only possible if in_size == out_size"))
    end
    act_fn = swish

    # ϕe consumes (h_i, h_j, ||x_i - x_j||^2, e_ij);
    # the +1 accounts for the radial feature ||x_i - x_j||^2.
    ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn),
               Dense(hidden_size => hidden_size, act_fn))

    # ϕh updates the invariant features from (h_i, aggregated message).
    ϕh = Chain(Dense(in_size + hidden_size => hidden_size, act_fn),
               Dense(hidden_size => out_size))

    # ϕx maps a message to a single scalar weight for the coordinate update.
    ϕx = Chain(Dense(hidden_size => hidden_size, act_fn),
               Dense(hidden_size => 1, bias = false))

    num_features = (in = in_size, edge = edge_feat_size, out = out_size)
    return EGNNConv(ϕe, ϕx, ϕh, num_features, residual)
end
1391
# Forward pass: returns the updated invariant features `h` and the updated
# equivariant coordinates `x` (in that order).
function (l::EGNNConv)(g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e = nothing)
    if l.num_features.edge > 0
        @assert e !== nothing "Edge features must be provided."
    end
    @assert size(h, 1) == l.num_features.in "Input features must match layer input size."

    # Per-edge message: msg_h feeds the invariant update, msg_x the coordinate update.
    function message(xi, xj, e)
        if l.num_features.edge > 0
            f = vcat(xi.h, xj.h, e.sqnorm_xdiff, e.e)
        else
            f = vcat(xi.h, xj.h, e.sqnorm_xdiff)
        end

        msg_h = l.ϕe(f)
        # Scale the (normalized) coordinate difference by a learned scalar weight.
        msg_x = l.ϕx(msg_h) .* e.x_diff
        return (; x = msg_x, h = msg_h)
    end

    # Coordinate differences x_i - x_j and their squared norms (the radial feature).
    x_diff = apply_edges(xi_sub_xj, g, x, x)
    sqnorm_xdiff = sum(x_diff .^ 2, dims = 1)
    # Normalize the differences for numerical stability, as in the reference implementation.
    x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1.0f-6)

    msg = apply_edges(message, g, xi = (; h), xj = (; h), e = (; e, x_diff, sqnorm_xdiff))
    h_aggr = aggregate_neighbors(g, +, msg.h)
    x_aggr = aggregate_neighbors(g, mean, msg.x)  # mean realizes the C_i = 1/|N(i)| factor

    hnew = l.ϕh(vcat(h, h_aggr))
    if l.residual
        h = h .+ hnew
    else
        h = hnew
    end
    x = x .+ x_aggr

    return h, x
end
0 commit comments