
Commit 6ad8d81

GNNLux docs start and general docs improvement (#513)
1 parent: b1d7936

File tree: 14 files changed (+251 / −31 lines)


GNNGraphs/docs/src/api/temporalgraph.md

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ Pages = ["temporalsnapshotsgnngraph.jl"]
 Private = false
 ```
 
-### TemporalSnapshotsGNNGraph random generators
+## TemporalSnapshotsGNNGraph random generators
 
 ```@docs
 rand_temporal_radius_graph
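
As context for the generator listed in the `@docs` block, a minimal usage sketch (signature assumed from the GNNGraphs.jl API reference: number of nodes, number of snapshots, node speed, connection radius):

```julia
using GNNGraphs

# 10 nodes moving at speed 0.1 over 5 snapshots; at each snapshot,
# nodes closer than the radius 0.5 are connected by an edge.
tg = rand_temporal_radius_graph(10, 5, 0.1, 0.5)
tg.num_snapshots  # 5
```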

GNNGraphs/docs/src/index.md

Lines changed: 14 additions & 2 deletions
@@ -1,6 +1,6 @@
 # GNNGraphs.jl
 
-GNNGraphs.jl is a package that provides graph data structures and helper functions specifically designed for working with graph neural networks. This package allows to store not only the graph structure, but also features associated with nodes, edges, and the graph itself. It is the core foundation for the GNNlib, GraphNeuralNetworks, and GNNLux packages.
+GNNGraphs.jl is a package that provides graph data structures and helper functions specifically designed for working with graph neural networks. This package allows one to store not only the graph structure, but also features associated with nodes, edges, and the graph itself. It is the core foundation for the GNNlib.jl, GraphNeuralNetworks.jl, and GNNLux.jl packages.
 
 It supports three types of graphs:
 
@@ -12,4 +12,16 @@ It supports three types of graphs:
 
 
 
-This package depends on the package [Graphs.jl] (https://github.com/JuliaGraphs/Graphs.jl).
+This package depends on the package [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl).
+
+
+
+## Installation
+
+The package can be installed with the Julia package manager.
+From the Julia REPL, type `]` to enter the Pkg REPL mode and run:
+
+```julia
+pkg> add GNNGraphs
+```
+
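
A minimal sketch of the usage this page describes, storing node features alongside the graph structure (illustrative, not part of the diff):

```julia
using GNNGraphs

# Directed graph with 3 nodes and 3 edges given as source/target vectors,
# plus a 2-dimensional feature vector attached to every node.
s, t = [1, 2, 3], [2, 3, 1]
g = GNNGraph(s, t, ndata = (; x = rand(Float32, 2, 3)))
g.ndata.x  # 2×3 node feature matrix
```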

GNNLux/docs/Project.toml

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 [deps]
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
 GNNLux = "e8545f4d-a905-48ac-a8c4-ca114b98986d"
 GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
 LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
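
This `docs/Project.toml` defines the environment used to build the GNNLux documentation. One common way to build the docs locally against it (a sketch, not prescribed by this commit):

```julia
# Run from the GNNLux/docs folder: install the doc dependencies
# declared above, then execute the build script.
using Pkg
Pkg.activate(".")
Pkg.instantiate()
include("make.jl")
```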

GNNLux/docs/make.jl

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,5 @@
 using Documenter
+using DocumenterInterLinks
 using GNNlib
 using GNNLux
 
@@ -8,11 +9,15 @@ assets=[]
 prettyurls = get(ENV, "CI", nothing) == "true"
 mathengine = MathJax3()
 
-
+interlinks = InterLinks(
+    "GNNGraphs" => ("https://carlolucibello.github.io/GraphNeuralNetworks.jl/GNNGraphs/", joinpath(dirname(dirname(@__DIR__)), "GNNGraphs", "docs", "build", "objects.inv")),
+    "GNNlib" => ("https://carlolucibello.github.io/GraphNeuralNetworks.jl/GNNlib/", joinpath(dirname(dirname(@__DIR__)), "GNNlib", "docs", "build", "objects.inv")))
+
 makedocs(;
          modules = [GNNLux],
          doctest = false,
          clean = true,
+         plugins = [interlinks],
          format = Documenter.HTML(; mathengine, prettyurls, assets = assets, size_threshold=nothing),
          sitename = "GNNLux.jl",
          pages = ["Home" => "index.md",
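
With the `interlinks` plugin registered above, DocumenterInterLinks can resolve `@extref` markdown links against the GNNGraphs and GNNlib inventories. A sketch of how such a cross-package reference might look in a GNNLux docstring (`myconv` is hypothetical; `@extref` usage assumed from DocumenterInterLinks conventions):

```julia
"""
    myconv(g)

Operates on a [`GNNGraph`](@extref), resolved through the inventories
registered in `interlinks` above.
"""
myconv(g) = g  # hypothetical placeholder, for illustration only
```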

GNNLux/docs/src/api/basic.md

Lines changed: 2 additions & 1 deletion
@@ -4,5 +4,6 @@ CurrentModule = GNNLux
 
 ## GNNLayer
 ```@docs
-GNNLux.GNNLayer
+GNNLayer
+GNNChain
 ```

GNNLux/docs/src/api/conv.md

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+```@meta
+CurrentModule = GNNLux
+```
+
+# Convolutional Layers
+
+Many different types of graph convolutional layers have been proposed in the literature. Choosing the right layer for your application can involve a lot of exploration.
+Multiple graph convolutional layers are typically stacked together to create a graph neural network model (see [`GNNChain`](@ref)).
+
+The table below lists all graph convolutional layers implemented in *GNNLux.jl*. It also highlights the presence of some additional capabilities with respect to basic message passing:
+- *Sparse Ops*: implements message passing as multiplication by a sparse adjacency matrix instead of the gather/scatter mechanism. This can lead to better CPU performance but is not supported on GPU yet.
+- *Edge Weight*: supports scalar weights (or equivalently scalar features) on edges.
+- *Edge Features*: supports feature vectors on edges.
+- *Heterograph*: supports heterogeneous graphs (see [`GNNHeteroGraph`](@ref)).
+- *TemporalSnapshotsGNNGraphs*: supports temporal graphs (see [`TemporalSnapshotsGNNGraph`](@ref)) by applying the convolution layers to each snapshot independently.
+
+| Layer | Sparse Ops | Edge Weight | Edge Features | Heterograph | TemporalSnapshotsGNNGraphs |
+| :-------- | :---: | :---: | :---: | :---: | :---: |
+| [`GCNConv`](@ref) | | | | | |
+
+## Docs
+
+```@autodocs
+Modules = [GNNLux]
+Pages = ["layers/conv.jl"]
+Private = false
+```
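
As a companion to the table, a minimal sketch of stacking the listed layer into a model with `GNNChain` (mirroring the `GNNChain` doctest added later in this commit):

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()
g = rand_graph(rng, 4, 8)                # 4 nodes, 8 edges
x = randn(rng, Float32, 3, g.num_nodes)  # 3 input features per node

# Two stacked graph convolutions with a nonlinearity in between.
model = GNNChain(GCNConv(3 => 8), x -> relu.(x), GCNConv(8 => 2))
ps, st = LuxCore.setup(rng, model)
y, st = model(g, x, ps, st)              # y has size 2 × num_nodes
```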

GNNLux/src/layers/basic.jl

Lines changed: 40 additions & 2 deletions
@@ -2,14 +2,52 @@
     abstract type GNNLayer <: AbstractLuxLayer end
 
 An abstract type from which graph neural network layers are derived.
-It is Derived from Lux's `AbstractLuxLayer` type.
+It is derived from Lux's `AbstractLuxLayer` type.
 
-See also `GNNChain`.
+See also [`GNNLux.GNNChain`](@ref).
 """
 abstract type GNNLayer <: AbstractLuxLayer end
 
 abstract type GNNContainerLayer{T} <: AbstractLuxContainerLayer{T} end
 
+"""
+    GNNChain(layers...)
+    GNNChain(name = layer, ...)
+
+Collects multiple layers / functions to be called in sequence
+on a given input graph and input node features.
+
+It allows composing layers in a sequential fashion as `Lux.Chain`
+does, propagating the output of each layer to the next one.
+In addition, `GNNChain` handles the input graph as well, providing it
+as a first argument only to layers subtyping the [`GNNLayer`](@ref) abstract type.
+
+`GNNChain` supports indexing and slicing, `m[2]` or `m[1:end-1]`,
+and if names are given, `m[:name] == m[1]` etc.
+
+# Examples
+```jldoctest
+julia> using Lux, GNNLux, Random
+
+julia> rng = Random.default_rng();
+
+julia> m = GNNChain(GCNConv(2=>5),
+                    x -> relu.(x),
+                    Dense(5=>4));
+
+julia> x = randn(rng, Float32, 2, 3);
+
+julia> g = rand_graph(rng, 3, 6)
+GNNGraph:
+  num_nodes: 3
+  num_edges: 6
+
+julia> ps, st = LuxCore.setup(rng, m);
+
+julia> m(g, x, ps, st) # First entry is the output, second entry is the state of the model
+(Float32[-0.15594329 -0.15594329 -0.15594329; 0.93431795 0.93431795 0.93431795; 0.27568763 0.27568763 0.27568763; 0.12568939 0.12568939 0.12568939], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
+```
+"""
 @concrete struct GNNChain <: GNNContainerLayer{(:layers,)}
     layers <: NamedTuple
 end
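
A short sketch of the named-layer form mentioned in the docstring, with `m[:name]` indexing (layer names here are illustrative):

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()

# Named layers can be retrieved by symbol as well as by position.
m = GNNChain(conv  = GCNConv(2 => 5),
             act   = x -> relu.(x),
             dense = Dense(5 => 4))
m[:conv] == m[1]  # true
```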

GNNLux/src/layers/conv.jl

Lines changed: 76 additions & 3 deletions
@@ -5,7 +5,80 @@ _getstate(s::StatefulLuxLayer{Static.True}) = s.st
 _getstate(s::StatefulLuxLayer{false}) = s.st_any
 _getstate(s::StatefulLuxLayer{Static.False}) = s.st_any
 
-
+@doc raw"""
+    GCNConv(in => out, σ=identity; [init_weight, init_bias, use_bias, add_self_loops, use_edge_weight])
+
+Graph convolutional layer from paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907).
+
+Performs the operation
+```math
+\mathbf{x}'_i = \sum_{j\in N(i)} a_{ij} W \mathbf{x}_j
+```
+where ``a_{ij} = 1 / \sqrt{|N(i)||N(j)|}`` is a normalization factor computed from the node degrees.
+
+If the input graph has weighted edges and `use_edge_weight=true`, then ``a_{ij}`` will be computed as
+```math
+a_{ij} = \frac{e_{j\to i}}{\sqrt{\sum_{j \in N(i)} e_{j\to i}} \sqrt{\sum_{i \in N(j)} e_{i\to j}}}
+```
+
+# Arguments
+
+- `in`: Number of input features.
+- `out`: Number of output features.
+- `σ`: Activation function. Default `identity`.
+- `init_weight`: Weights' initializer. Default `glorot_uniform`.
+- `init_bias`: Bias initializer. Default `zeros32`.
+- `use_bias`: Add learnable bias. Default `true`.
+- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
+- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
+                     If `add_self_loops=true` the new weights will be set to 1.
+                     This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
+                     Default `false`.
+
+# Forward
+
+    (::GCNConv)(g, x, [edge_weight], ps, st; norm_fn = d -> 1 ./ sqrt.(d), conv_weight=nothing)
+
+Takes as input a graph `g`, a node feature matrix `x` of size `[in, num_nodes]`, optionally an edge weight vector, and the parameters and state of the layer. Returns a node feature matrix of size
+`[out, num_nodes]`.
+
+The `norm_fn` parameter allows for custom normalization of the graph convolution operation by passing a function as argument.
+By default, it computes ``\frac{1}{\sqrt{d}}``, i.e. the inverse square root of the degree (`d`) of each node in the graph.
+If `conv_weight` is an `AbstractMatrix` of size `[out, in]`, then the convolution is performed using that weight matrix.
+
+# Examples
+
+```julia
+using GNNLux, Lux, Random

+# initialize random number generator
+rng = Random.default_rng()

+# create data
+s = [1,1,2,3]
+t = [2,3,1,1]
+g = GNNGraph(s, t)
+x = randn(rng, Float32, 3, g.num_nodes)
+
+# create layer
+l = GCNConv(3 => 5)
+
+# setup layer
+ps, st = LuxCore.setup(rng, l)
+
+# forward pass
+y = l(g, x, ps, st)       # size of the output first entry: 5 × num_nodes
+
+# convolution with edge weights and custom normalization function
+w = [1.1, 0.1, 2.3, 0.5]
+custom_norm_fn(d) = 1 ./ sqrt.(d .+ 1)  # custom normalization function
+y = l(g, x, w, ps, st; norm_fn = custom_norm_fn)
+
+# Edge weights can also be embedded in the graph.
+g = GNNGraph(s, t, w)
+l = GCNConv(3 => 5, use_edge_weight=true)
+ps, st = Lux.setup(rng, l)
+y = l(g, x, ps, st) # same as l(g, x, w)
+```
+"""
 @concrete struct GCNConv <: GNNLayer
     in_dims::Int
     out_dims::Int
@@ -18,7 +91,7 @@ _getstate(s::StatefulLuxLayer{Static.False}) = s.st_any
 end
 
 function GCNConv(ch::Pair{Int, Int}, σ = identity;
-                 init_weight = glorot_uniform, 
+                 init_weight = glorot_uniform,
                  init_bias = zeros32,
                  use_bias::Bool = true,
                  add_self_loops::Bool = true,
@@ -55,7 +128,7 @@ end
 
 function (l::GCNConv)(g, x, edge_weight, ps, st;
                       norm_fn = d -> 1 ./ sqrt.(d),
-                      conv_weight=nothing, )
+                      conv_weight=nothing)
 
     m = (; ps.weight, bias = _getbias(ps),
          l.add_self_loops, l.use_edge_weight, l.σ)
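
The `conv_weight` keyword described in the docstring is not exercised in its example; a minimal sketch of substituting an externally supplied `[out, in]` weight matrix in the forward pass (assuming the forward signature shown above):

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()
g = GNNGraph([1, 1, 2, 3], [2, 3, 1, 1])
x = randn(rng, Float32, 3, g.num_nodes)

l = GCNConv(3 => 5)
ps, st = LuxCore.setup(rng, l)

# Convolve with an external weight matrix instead of ps.weight.
W = randn(rng, Float32, 5, 3)
y, st = l(g, x, ps, st; conv_weight = W)
```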

GNNlib/docs/src/index.md

Lines changed: 11 additions & 2 deletions
@@ -1,6 +1,15 @@
 # GNNlib.jl
 
 GNNlib.jl is a package that provides the implementation of the basic message passing functions and
-functional implementation of graph convolutional layers, which are used to build graph neural networks in both the Flux.jl and Lux.jl machine learning frameworks, created in the GraphNeuralNetworks.jl and GNNLux.jl packages, respectively.
+functional implementation of graph convolutional layers, which are used to build graph neural networks in both the [Flux.jl](https://fluxml.ai/Flux.jl/stable/) and [Lux.jl](https://lux.csail.mit.edu/stable/) machine learning frameworks, created in the GraphNeuralNetworks.jl and GNNLux.jl packages, respectively.
 
-This package depends on GNNGraphs.jl and NNlib.jl, and is primarily intended for developers looking to create new GNN architectures. For most users, the higher-level GraphNeuralNetworks.jl and GNNLux.jl packages are recommended.
+This package depends on GNNGraphs.jl and NNlib.jl, and is primarily intended for developers looking to create new GNN architectures. For most users, the higher-level GraphNeuralNetworks.jl and GNNLux.jl packages are recommended.
+
+## Installation
+
+The package can be installed with the Julia package manager.
+From the Julia REPL, type `]` to enter the Pkg REPL mode and run:
+
+```julia
+pkg> add GNNlib
+```
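
To make "basic message passing functions" concrete, a minimal sketch using GNNlib's `propagate` with the built-in `copy_xj` message (GraphNeuralNetworks.jl and GNNLux.jl wrap this machinery for end users):

```julia
using GNNlib, GNNGraphs

g = GNNGraph([1, 2, 3], [2, 3, 1])
x = rand(Float32, 4, g.num_nodes)

# For each edge, send the source-node features unchanged (copy_xj),
# then sum the incoming messages at every target node.
y = propagate(copy_xj, g, +; xj = x)  # 4 × num_nodes
```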

GNNlib/docs/src/messagepassing.md

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T
 end
 ```
 
-See the `GATConv` implementation [here](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/blob/master/src/layers/conv.jl) for a more complex example.
+See the `GATConv` implementation [here](https://juliagraphs.org/GraphNeuralNetworks.jl/graphneuralnetworks/api/conv/) for a more complex example.
 
 
 ## Built-in message functions
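
The excerpt ends at the built-in message functions heading; as one concrete instance, a sketch using the built-in `w_mul_xj`, which scales source-node features by the edge weight before aggregation (assuming, as in the GCN example on that page, that `propagate` picks up the weights stored in the graph):

```julia
using GNNlib, GNNGraphs

s, t, w = [1, 2, 3], [2, 3, 1], [0.5f0, 2.0f0, 1.0f0]
g = GNNGraph(s, t, w)              # graph with scalar edge weights
x = rand(Float32, 4, g.num_nodes)

# w_mul_xj multiplies each source-node feature vector by its edge
# weight; the weighted messages are then summed at the target node.
y = propagate(w_mul_xj, g, +; xj = x)
```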
