|  | 
| 2 | 2 |     abstract type GNNLayer <: AbstractLuxLayer end | 
| 3 | 3 | 
 | 
| 4 | 4 | An abstract type from which graph neural network layers are derived. | 
| 5 |  | -It is Derived from Lux's `AbstractLuxLayer` type. | 
|  | 5 | +It is derived from Lux's `AbstractLuxLayer` type. | 
| 6 | 6 | 
 | 
| 7 |  | -See also `GNNChain`. | 
|  | 7 | +See also [`GNNLux.GNNChain`](@ref). | 
| 8 | 8 | """ | 
| 9 | 9 | abstract type GNNLayer <: AbstractLuxLayer end | 
| 10 | 10 | 
 | 
| 11 | 11 | abstract type GNNContainerLayer{T} <: AbstractLuxContainerLayer{T} end | 
| 12 | 12 | 
 | 
|  | 13 | +""" | 
|  | 14 | +    GNNChain(layers...) | 
|  | 15 | +    GNNChain(name = layer, ...) | 
|  | 16 | +
 | 
|  | 17 | +Collects multiple layers / functions to be called in sequence | 
|  | 18 | +on given input graph and input node features.  | 
|  | 19 | +
 | 
|  | 20 | +It allows to compose layers in a sequential fashion as `Lux.Chain` | 
|  | 21 | +does, propagating the output of each layer to the next one. | 
|  | 22 | +In addition, `GNNChain` handles the input graph as well, providing it  | 
|  | 23 | +as a first argument only to layers subtyping the [`GNNLayer`](@ref) abstract type.  | 
|  | 24 | +
 | 
|  | 25 | +`GNNChain` supports indexing and slicing, `m[2]` or `m[1:end-1]`, | 
|  | 26 | +and if names are given, `m[:name] == m[1]` etc. | 
|  | 27 | +
 | 
|  | 28 | +# Examples | 
|  | 29 | +```jldoctest | 
|  | 30 | +
 | 
|  | 31 | +julia> using Lux, GNNLux, Random | 
|  | 32 | +
 | 
|  | 33 | +julia> rng = Random.default_rng(); | 
|  | 34 | +
 | 
|  | 35 | +julia> Random.seed!(rng, 0); | 
|  | 36 | +
 | 
|  | 37 | +julia> m = GNNChain(GCNConv(2=>5),  | 
|  | 38 | +                    x -> relu.(x),  | 
|  | 39 | +                    Dense(5=>4)) | 
|  | 40 | +
 | 
|  | 41 | +julia> x = randn(Float32, 2, 3); | 
|  | 42 | +
 | 
|  | 43 | +julia> g = rand_graph(3, 6) | 
|  | 44 | +GNNGraph: | 
|  | 45 | +  num_nodes: 3 | 
|  | 46 | +  num_edges: 6 | 
|  | 47 | +
 | 
|  | 48 | +julia> ps, st = LuxCore.setup(rng,m); | 
|  | 49 | +
 | 
|  | 50 | +julia> m(g,x,ps,st)     # First entry is the output, second entry is the state of the model | 
|  | 51 | +(Float32[-0.15594329 -0.15594329 -0.15594329; 0.93431795 0.93431795 0.93431795; 0.27568763 0.27568763 0.27568763; 0.12568939 0.12568939 0.12568939], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple())) | 
|  | 52 | +``` | 
|  | 53 | +""" | 
| 13 | 54 | @concrete struct GNNChain <: GNNContainerLayer{(:layers,)} | 
| 14 | 55 |     layers <: NamedTuple | 
| 15 | 56 | end | 
|  | 
0 commit comments