@@ -6,3 +6,67 @@ An abstract type from which graph neural network layers are derived.
6
6
See also [`GNNChain`](@ref).
7
7
"""
8
8
abstract type GNNLayer end
9
+
10
+ """
11
+ GNNChain(layers...)
12
+ GNNChain(name = layer, ...)
13
+
14
+ Collects multiple layers / functions to be called in sequence
15
+ on a given input. Supports indexing and slicing, `m[2]` or `m[1:end-1]`,
16
+ and if names are given, `m[:name] == m[1]` etc.
17
+
18
+ ## Examples
19
+
20
+ ```jldoctest
21
+ julia> m = GNNChain(x -> x^2, x -> x+1);
22
+
23
+ julia> m(5) == 26
24
+ true
25
+
26
+ julia> m = GNNChain(Dense(10, 5, tanh), Dense(5, 2));
27
+
28
+ julia> x = rand(10, 32);
29
+
30
+ julia> m(x) == m[2](m[1](x))
31
+ true
32
+
33
+ julia> m2 = GNNChain(enc = GNNChain(Flux.flatten, Dense(10, 5, tanh)),
34
+ dec = Dense(5, 2));
35
+
36
+ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
37
+ true
38
+ ```
39
+ """
40
+ struct GNNChain{T}
41
+ layers:: T
42
+
43
+ GNNChain (xs... ) = new {typeof(xs)} (xs)
44
+
45
+ function GNNChain (; kw... )
46
+ :layers in Base. keys (kw) && throw (ArgumentError (" a GNNChain cannot have a named layer called `layers`" ))
47
+ isempty (kw) && return new {Tuple{}} (())
48
+ new {typeof(values(kw))} (values (kw))
49
+ end
50
+ end
51
+
52
+ @forward GNNChain. layers Base. getindex, Base. length, Base. first, Base. last,
53
+ Base. iterate, Base. lastindex, Base. keys
54
+
55
+ functor (:: Type{<:GNNChain} , c) = c. layers, ls -> GNNChain (ls... )
56
+
57
+ applychain (:: Tuple{} , x) = x
58
+ applychain (fs:: Tuple , x) = applychain (tail (fs), first (fs)(x))
59
+
60
+ (c:: GNNChain )(x) = applychain (Tuple (c. layers), x)
61
+
62
+ Base. getindex (c:: GNNChain , i:: AbstractArray ) = GNNChain (c. layers[i]. .. )
63
+ Base. getindex (c:: GNNChain{<:NamedTuple} , i:: AbstractArray ) =
64
+ GNNChain (; NamedTuple {Base.keys(c)[i]} (Tuple (c. layers)[i])... )
65
+
66
+ function Base. show (io:: IO , c:: GNNChain )
67
+ print (io, " GNNChain(" )
68
+ _show_layers (io, c. layers)
69
+ print (io, " )" )
70
+ end
71
+ _show_layers (io, layers:: Tuple ) = join (io, layers, " , " )
72
+ _show_layers (io, layers:: NamedTuple ) = join (io, [" $k = $v " for (k, v) in pairs (layers)], " , " )
0 commit comments