Commit 82a7450

fix NNConv docs (#488)
* fix nnconv docstring
* cleanup
1 parent: cb82352


src/layers/conv.jl (12 additions, 8 deletions)
````diff
@@ -659,8 +659,8 @@ For convenience, also functions returning a single `(out*in, num_edges)` matrix
 
 # Arguments
 
-- `in`: The dimension of input features.
-- `out`: The dimension of output features.
+- `in`: The dimension of input node features.
+- `out`: The dimension of output node features.
 - `f`: A (possibly learnable) function acting on edge features.
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 - `σ`: Activation function.
@@ -670,22 +670,26 @@ For convenience, also functions returning a single `(out*in, num_edges)` matrix
 # Examples:
 
 ```julia
+n_in = 3
+n_in_edge = 10
+n_out = 5
+
 # create data
 s = [1,1,2,3]
 t = [2,3,1,1]
-in_channel = 3
-out_channel = 5
-edim = 10
 g = GNNGraph(s, t)
 
 # create dense layer
-nn = Dense(edim => out_channel * in_channel)
+nn = Dense(n_in_edge => n_out * n_in)
 
 # create layer
-l = NNConv(in_channel => out_channel, nn, tanh, bias = true, aggr = +)
+l = NNConv(n_in => n_out, nn, tanh, bias = true, aggr = +)
+
+x = randn(Float32, n_in, g.num_nodes)
+e = randn(Float32, n_in_edge, g.num_edges)
 
 # forward pass
-y = l(g, x)
+y = l(g, x, e)
 ```
 """
 struct NNConv{W, B, NN, F, A} <: GNNLayer
````
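For reference, here is a self-contained sketch of the corrected docstring example as a runnable script. The `using` lines and the final shape check are additions not present in the docstring itself; they assume `NNConv` and `GNNGraph` come from GraphNeuralNetworks.jl and `Dense` from Flux, as in this repository. The substance of the fix is visible in the last lines: `NNConv` builds its weights from edge features, so the forward pass must receive `e` as a third argument.

```julia
using Flux
using GraphNeuralNetworks

n_in = 3        # dimension of input node features
n_in_edge = 10  # dimension of edge features
n_out = 5       # dimension of output node features

# create data: a small graph from source/target index vectors
s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
g = GNNGraph(s, t)

# dense layer mapping each edge feature vector to an (n_out * n_in)-sized weight block
nn = Dense(n_in_edge => n_out * n_in)

# edge-conditioned convolution layer
l = NNConv(n_in => n_out, nn, tanh, bias = true, aggr = +)

x = randn(Float32, n_in, g.num_nodes)       # node features
e = randn(Float32, n_in_edge, g.num_edges)  # edge features

# forward pass: edge features are passed alongside node features
y = l(g, x, e)

# output has one n_out-dimensional column per node (assumed shape check)
@assert size(y) == (n_out, g.num_nodes)
```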
