Skip to content

Commit f40e050

Browse files
docstring
1 parent 43338e5 commit f40e050

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

src/layers/basic.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,31 @@ abstract type GNNLayer end
1212
GNNChain(name = layer, ...)
1313
1414
Collects multiple layers / functions to be called in sequence
15-
on a given input. Supports indexing and slicing, `m[2]` or `m[1:end-1]`,
16-
and if names are given, `m[:name] == m[1]` etc.
17-
18-
## Examples
15+
on given input graph and input node features.
1916
20-
```
21-
julia> m = GNNChain(x -> x^2, x -> x+1);
17+
It allows to compose layers in a sequential fashion as `Flux.Chain`
18+
does, propagating the output of each layer to the next one.
19+
In addition, `GNNChain` handles the input graph as well, providing it
20+
as a first argument only to layers subtyping the [`GNNLayer`](@ref) abstract type.
2221
23-
julia> m(5) == 26
24-
true
22+
`GNNChain` supports indexing and slicing, `m[2]` or `m[1:end-1]`,
23+
and if names are given, `m[:name] == m[1]` etc.
2524
26-
julia> m = GNNChain(Dense(10, 5, tanh), Dense(5, 2));
25+
# Examples
2726
28-
julia> x = rand(10, 32);
27+
```juliarepl
28+
julia> m = GNNChain(GCNConv(2=>5), BatchNorm(5), x -> relu.(x), Dense(5, 4));
2929
30-
julia> m(x) == m[2](m[1](x))
31-
true
30+
julia> x = randn(Float32, 2, 3);
3231
33-
julia> m2 = GNNChain(enc = GNNChain(Flux.flatten, Dense(10, 5, tanh)),
34-
dec = Dense(5, 2));
32+
julia> g = GNNGraph([1,1,2,3], [2,3,1,1]);
3533
36-
julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
37-
true
34+
julia> m(g, x)
35+
4×3 Matrix{Float32}:
36+
0.157941 0.15443 0.193471
37+
0.0819516 0.0503105 0.122523
38+
0.225933 0.267901 0.241878
39+
-0.0134364 -0.0120716 -0.0172505
3840
```
3941
"""
4042
struct GNNChain{T}

0 commit comments

Comments
 (0)