@@ -356,20 +356,20 @@ end
@doc raw"""
-     EdgeConv(f; aggr=max)
+     EdgeConv(nn; aggr=max)

Edge convolutional layer from paper [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829).

Performs the operation

```math
- \mathbf{x}_i' = \square_{j \in N(i)} f(\mathbf{x}_i || \mathbf{x}_j - \mathbf{x}_i)
+ \mathbf{x}_i' = \square_{j \in N(i)} nn(\mathbf{x}_i || \mathbf{x}_j - \mathbf{x}_i)
```
- where `f` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.
+ where `nn` generally denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.

# Arguments

- - `f`: A (possibly learnable) function acting on edge features.
+ - `nn`: A (possibly learnable) function acting on edge features.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
"""
struct EdgeConv <: GNNLayer
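
A usage sketch for the renamed argument (hypothetical dimensions; assumes Flux's `Dense`/`Chain` for the learnable `nn`, and the `layer(g, x)` call convention these layers use):

```julia
using Flux

in_dim = 4                                # hypothetical node feature dimension
nn = Chain(Dense(2 * in_dim => 8, relu),  # nn acts on x_i || x_j - x_i, hence 2 * in_dim inputs
           Dense(8 => 8))
l = EdgeConv(nn; aggr = max)              # then x′ = l(g, x)
```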
@@ -405,13 +405,13 @@ Graph Isomorphism convolutional layer from paper [How Powerful are Graph Neural
```math
- \mathbf{x}_i' = f\left((1 + \epsilon) \mathbf{x}_i + \sum_{j \in N(i)} \mathbf{x}_j \right)
+ \mathbf{x}_i' = f_\Theta\left((1 + \epsilon) \mathbf{x}_i + \sum_{j \in N(i)} \mathbf{x}_j \right)
```
- where `f` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.
+ where ``f_\Theta`` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.

# Arguments

- - `f`: A (possibly learnable) function acting on node features.
+ - ``f``: A (possibly learnable) function acting on node features.
- `eps`: Weighting factor.
"""
struct GINConv{R<:Real} <: GNNLayer
@@ -434,3 +434,69 @@ function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
    X, _ = propagate(l, g, +, X)
    X
end
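
To make the GIN formula concrete, here is the node update computed by hand for a single node; a sketch of the math only, not of the layer internals (the layer itself goes through `propagate`):

```julia
# Hand-computed GIN update for one node i (illustration of the formula above).
eps = 0.001f0
x_i = Float32[1, 2]                          # features of node i
neighbors = [Float32[0, 1], Float32[3, 0]]   # x_j for each j ∈ N(i)
h = (1 + eps) .* x_i .+ sum(neighbors)       # (1 + ε) x_i + Σ_j x_j == [4.001, 3.002]
# f_Θ (e.g. an MLP) is then applied to h to obtain x_i′.
```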
+
+
+ @doc raw"""
+     NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)
+
+ The continuous kernel-based convolutional operator from the
+ [Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) paper.
+ This convolution is also known as the edge-conditioned convolution from the
+ [Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) paper.
+
+ Performs the operation
+
+ ```math
+ \mathbf{x}_i' = W \mathbf{x}_i + \square_{j \in N(i)} f_\Theta(\mathbf{e}_{j\to i}) \, \mathbf{x}_j
+ ```
+
+ where ``f_\Theta`` denotes a learnable function (e.g. a linear layer or a multi-layer perceptron).
+ Given an input of batched edge features `e` of size `(num_edge_features, num_edges)`,
+ the function `f` should return a batched array of matrices of size `(out, in, num_edges)`.
+ For convenience, functions returning a single `(out*in, num_edges)` matrix are also allowed.
+
+ # Arguments
+
+ - `in`: The dimension of input features.
+ - `out`: The dimension of output features.
+ - `f`: A (possibly learnable) function acting on edge features, as described above.
+ - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
+ - `σ`: Activation function.
+ - `bias`: Add learnable bias.
+ - `init`: Weights' initializer.
+ """
+ struct NNConv <: GNNLayer
+     weight
+     bias
+     nn
+     σ
+     aggr
+ end
+
+ @functor NNConv
+
+ function NNConv(ch::Pair{Int,Int}, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)
+     in, out = ch
+     W = init(out, in)
+     b = Flux.create_bias(W, bias, out)
+     return NNConv(W, b, f, σ, aggr)   # `f` fills the `nn` field; σ is stored for the forward pass
+ end
+
+ function compute_message(l::NNConv, x_i, x_j, e_ij)
+     nin, nedges = size(x_i)
+     W = reshape(l.nn(e_ij), (:, nin, nedges))  # per-edge weight matrices of size (out, nin, nedges)
+     x_j = reshape(x_j, (nin, 1, nedges))       # batched_mul needs 3D arrays on both sides
+     m = NNlib.batched_mul(W, x_j)              # (out, 1, nedges)
+     return reshape(m, :, nedges)               # messages of size (out, nedges)
+ end
+
+ update_node(l::NNConv, m, x) = l.weight * x + m
+
+ function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e)
+     check_num_nodes(g, x)              # the argument is `x`, not `X`
+     x, _ = propagate(l, g, l.aggr, x, e)
+     return l.σ.(x .+ l.bias)           # broadcast the bias vector across nodes
+ end
+
+ function Base.show(io::IO, l::NNConv)
+     out, in = size(l.weight)
+     print(io, "NNConv($in => $out")
+     print(io, ", aggr=", l.aggr)
+     print(io, ")")
+ end
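
Finally, a construction sketch for the new layer, using the flattened `(out*in, num_edges)` convenience form the docstring allows for the edge network (names and sizes are illustrative):

```julia
using Flux

ein, nin, nout = 3, 4, 5
f = Dense(ein => nout * nin)           # maps edge features to out*in weights per edge
l = NNConv(nin => nout, f, relu; aggr = +)
# Applied as x′ = l(g, x, e), with x of size (nin, num_nodes) and e of size (ein, num_edges).
```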