@@ -659,8 +659,8 @@ For convenience, also functions returning a single `(out*in, num_edges)` matrix

# Arguments

- - `in`: The dimension of input features.
- - `out`: The dimension of output features.
+ - `in`: The dimension of input node features.
+ - `out`: The dimension of output node features.
- `f`: A (possibly learnable) function acting on edge features.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `σ`: Activation function.
@@ -670,22 +670,26 @@ For convenience, also functions returning a single `(out*in, num_edges)` matrix
# Examples:

```julia
+ n_in = 3
+ n_in_edge = 10
+ n_out = 5
+
# create data
s = [1,1,2,3]
t = [2,3,1,1]
- in_channel = 3
- out_channel = 5
- edim = 10
g = GNNGraph(s, t)

# create dense layer
- nn = Dense(edim => out_channel * in_channel )
+ nn = Dense(n_in_edge => n_out * n_in)

# create layer
- l = NNConv(in_channel => out_channel, nn, tanh, bias = true, aggr = +)
+ l = NNConv(n_in => n_out, nn, tanh, bias = true, aggr = +)
+
+ x = randn(Float32, n_in, g.num_nodes)
+ e = randn(Float32, n_in_edge, g.num_edges)

# forward pass
- y = l(g, x)
+ y = l(g, x, e)
```
"""
struct NNConv{W, B, NN, F, A} <: GNNLayer
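Context for the change: `nn`'s output dimension must equal `n_out * n_in` because the layer consumes each column of `nn(e)` as one edge's `n_out × n_in` weight matrix, applied to the source node's features. Below is a minimal plain-Julia sketch of that per-edge mechanism, assuming sum aggregation and omitting the layer's own learnable weights, bias, and activation; the `W_f` projection stands in for the learnable `nn` and is illustrative only, not the package's implementation.

```julia
# Sketch of the per-edge message computation described in the docstring
# (assumptions noted above; not the GraphNeuralNetworks.jl implementation).
n_in, n_in_edge, n_out = 3, 10, 5

s = [1, 1, 2, 3]                          # edge sources
t = [2, 3, 1, 1]                          # edge targets
num_nodes = 3

x = randn(Float32, n_in, num_nodes)       # node features
e = randn(Float32, n_in_edge, length(s))  # edge features

# Stand-in for the learnable `nn`: a fixed random projection (illustrative only)
W_f = randn(Float32, n_out * n_in, n_in_edge)
f(E) = W_f * E                            # returns an (n_out*n_in, num_edges) matrix

msgs = f(e)
y = zeros(Float32, n_out, num_nodes)
for (k, (src, dst)) in enumerate(zip(s, t))
    Θ = reshape(msgs[:, k], n_out, n_in)  # per-edge weight matrix
    y[:, dst] .+= Θ * x[:, src]           # message to the target node, aggr = +
end
```

This also explains the new forward signature `y = l(g, x, e)`: the edge features `e` enter the forward pass because they parameterize the per-edge weight matrices.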