Skip to content

Commit bdf0ad1

Browse files
committed
Add basic docs
1 parent 5e519e9 commit bdf0ad1

File tree

1 file changed

+43
-2
lines changed

1 file changed

+43
-2
lines changed

GNNLux/src/layers/basic.jl

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,55 @@
22
abstract type GNNLayer <: AbstractLuxLayer end
33
44
An abstract type from which graph neural network layers are derived.
5-
It is Derived from Lux's `AbstractLuxLayer` type.
5+
It is derived from Lux's `AbstractLuxLayer` type.
66
7-
See also `GNNChain`.
7+
See also [`GNNLux.GNNChain`](@ref).
88
"""
99
abstract type GNNLayer <: AbstractLuxLayer end
1010

1111
abstract type GNNContainerLayer{T} <: AbstractLuxContainerLayer{T} end
1212

13+
"""
14+
GNNChain(layers...)
15+
GNNChain(name = layer, ...)
16+
17+
Collects multiple layers / functions to be called in sequence
18+
on given input graph and input node features.
19+
20+
It allows to compose layers in a sequential fashion as `Lux.Chain`
21+
does, propagating the output of each layer to the next one.
22+
In addition, `GNNChain` handles the input graph as well, providing it
23+
as a first argument only to layers subtyping the [`GNNLayer`](@ref) abstract type.
24+
25+
`GNNChain` supports indexing and slicing, `m[2]` or `m[1:end-1]`,
26+
and if names are given, `m[:name] == m[1]` etc.
27+
28+
# Examples
29+
```jldoctest
30+
31+
julia> using Lux, GNNLux, Random
32+
33+
julia> rng = Random.default_rng();
34+
35+
julia> Random.seed!(rng, 0);
36+
37+
julia> m = GNNChain(GCNConv(2=>5),
38+
x -> relu.(x),
39+
Dense(5=>4))
40+
41+
julia> x = randn(Float32, 2, 3);
42+
43+
julia> g = rand_graph(3, 6)
44+
GNNGraph:
45+
num_nodes: 3
46+
num_edges: 6
47+
48+
julia> ps, st = LuxCore.setup(rng,m);
49+
50+
julia> m(g,x,ps,st) # First entry is the output, second entry is the state of the model
51+
(Float32[-0.15594329 -0.15594329 -0.15594329; 0.93431795 0.93431795 0.93431795; 0.27568763 0.27568763 0.27568763; 0.12568939 0.12568939 0.12568939], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
52+
```
53+
"""
1354
@concrete struct GNNChain <: GNNContainerLayer{(:layers,)}
1455
layers <: NamedTuple
1556
end

0 commit comments

Comments
 (0)