@@ -12,29 +12,31 @@ abstract type GNNLayer end
12
12
GNNChain(name = layer, ...)
13
13
14
14
Collects multiple layers / functions to be called in sequence
15
- on a given input. Supports indexing and slicing, `m[2]` or `m[1:end-1]`,
16
- and if names are given, `m[:name] == m[1]` etc.
17
-
18
- ## Examples
15
+ on given input graph and input node features.
19
16
20
- ```
21
- julia> m = GNNChain(x -> x^2, x -> x+1);
17
+ It allows to compose layers in a sequential fashion as `Flux.Chain`
18
+ does, propagating the output of each layer to the next one.
19
+ In addition, `GNNChain` handles the input graph as well, providing it
20
+ as a first argument only to layers subtyping the [`GNNLayer`](@ref) abstract type.
22
21
23
- julia> m(5) == 26
24
- true
22
+ `GNNChain` supports indexing and slicing, `m[2]` or `m[1:end-1]`,
23
+ and if names are given, `m[:name] == m[1]` etc.
25
24
26
- julia> m = GNNChain(Dense(10, 5, tanh), Dense(5, 2));
25
+ # Examples
27
26
28
- julia> x = rand(10, 32);
27
+ ```juliarepl
28
+ julia> m = GNNChain(GCNConv(2=>5), BatchNorm(5), x -> relu.(x), Dense(5, 4));
29
29
30
- julia> m(x) == m[2](m[1](x))
31
- true
30
+ julia> x = randn(Float32, 2, 3);
32
31
33
- julia> m2 = GNNChain(enc = GNNChain(Flux.flatten, Dense(10, 5, tanh)),
34
- dec = Dense(5, 2));
32
+ julia> g = GNNGraph([1,1,2,3], [2,3,1,1]);
35
33
36
- julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
37
- true
34
+ julia> m(g, x)
35
+ 4×3 Matrix{Float32}:
36
+ 0.157941 0.15443 0.193471
37
+ 0.0819516 0.0503105 0.122523
38
+ 0.225933 0.267901 0.241878
39
+ -0.0134364 -0.0120716 -0.0172505
38
40
```
39
41
"""
40
42
struct GNNChain{T}
0 commit comments