Skip to content

Commit e260ff9

Browse files
improve GNNLux docs (#542)
1 parent 884c2fa commit e260ff9

File tree

17 files changed

+193
-58
lines changed

17 files changed

+193
-58
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ GraphNeuralNetworks/docs/build
1818
GraphNeuralNetworks/docs/src/GNNGraphs
1919
GraphNeuralNetworks/docs/src/GNNlib
2020
tutorials/docs/build
21+
prova.jl

GNNGraphs/docs/make.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,9 @@ makedocs(;
3030
"Home" => "index.md",
3131

3232
"Guides" => [
33-
"Graphs" => [
34-
"guides/gnngraph.md",
35-
"guides/heterograph.md",
36-
"guides/temporalgraph.md"
37-
],
33+
"Graphs" => "guides/gnngraph.md",
34+
"Heterogeneous Graphs" => "guides/heterograph.md",
35+
"Temporal Graphs" => "guides/temporalgraph.md",
3836
"Datasets" => "guides/datasets.md",
3937
],
4038

GNNLux/docs/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
4+
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
45
GNNLux = "e8545f4d-a905-48ac-a8c4-ca114b98986d"
56
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
6-
LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
7+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
8+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"

GNNLux/docs/make.jl

Lines changed: 74 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,85 @@
11
using Documenter
2-
using DocumenterInterLinks
3-
using GNNlib
42
using GNNLux
3+
using Lux, GNNGraphs, GNNlib, Graphs
4+
using DocumenterInterLinks
55

6+
DocMeta.setdocmeta!(GNNLux, :DocTestSetup, :(using GNNLux); recursive = true)
67

8+
mathengine = MathJax3(Dict(:loader => Dict("load" => ["[tex]/require", "[tex]/mathtools"]),
9+
:tex => Dict("inlineMath" => [["\$", "\$"], ["\\(", "\\)"]],
10+
"packages" => [
11+
"base",
12+
"ams",
13+
"autoload",
14+
"mathtools",
15+
"require"
16+
])))
717

8-
assets=[]
9-
prettyurls = get(ENV, "CI", nothing) == "true"
10-
mathengine = MathJax3()
1118

1219
interlinks = InterLinks(
13-
"GNNGraphs" => ("https://carlolucibello.github.io/GraphNeuralNetworks.jl/GNNGraphs/", joinpath(dirname(dirname(@__DIR__)), "GNNGraphs", "docs", "build", "objects.inv")),
14-
"GNNlib" => ("https://carlolucibello.github.io/GraphNeuralNetworks.jl/GNNlib/", joinpath(dirname(dirname(@__DIR__)), "GNNlib", "docs", "build", "objects.inv")))
15-
20+
"NNlib" => "https://fluxml.ai/NNlib.jl/stable/",
21+
# "GNNGraphs" => ("https://carlolucibello.github.io/GraphNeuralNetworks.jl/GNNGraphs/", joinpath(dirname(dirname(@__DIR__)), "GNNGraphs", "docs", "build", "objects.inv")),
22+
# "GNNlib" => ("https://carlolucibello.github.io/GraphNeuralNetworks.jl/GNNlib/", joinpath(dirname(dirname(@__DIR__)), "GNNlib", "docs", "build", "objects.inv"))
23+
)
24+
25+
# Copy the docs from GNNGraphs and GNNlib. Will be removed at the end of the script
26+
cp(joinpath(@__DIR__, "../../GNNGraphs/docs/src"),
27+
joinpath(@__DIR__, "src/GNNGraphs"), force=true)
28+
cp(joinpath(@__DIR__, "../../GNNlib/docs/src"),
29+
joinpath(@__DIR__, "src/GNNlib"), force=true)
30+
1631
makedocs(;
17-
modules = [GNNLux],
18-
doctest = false,
19-
clean = true,
32+
modules = [GNNLux, GNNGraphs, GNNlib],
33+
doctest = false, # TODO: enable doctest
2034
plugins = [interlinks],
21-
format = Documenter.HTML(; mathengine, prettyurls, assets = assets, size_threshold=nothing),
35+
format = Documenter.HTML(; mathengine,
36+
prettyurls = get(ENV, "CI", nothing) == "true",
37+
assets = [],
38+
size_threshold=nothing,
39+
size_threshold_warn=2000000),
2240
sitename = "GNNLux.jl",
23-
pages = ["Home" => "index.md",
24-
"API Reference" => [
25-
"Basic" => "api/basic.md",
26-
"Convolutional layers" => "api/conv.md",
27-
"Temporal Convolutional layers" => "api/temporalconv.md",],
28-
]
29-
)
30-
41+
pages = [
42+
43+
"Home" => "index.md",
44+
45+
"Guides" => [
46+
"Graphs" => "GNNGraphs/guides/gnngraph.md",
47+
"Message Passing" => "GNNlib/guides/messagepassing.md",
48+
"Models" => "guides/models.md",
49+
"Datasets" => "GNNGraphs/guides/datasets.md",
50+
"Heterogeneous Graphs" => "GNNGraphs/guides/heterograph.md",
51+
"Temporal Graphs" => "GNNGraphs/guides/temporalgraph.md",
52+
],
53+
54+
"API Reference" => [
55+
"Graphs (GNNGraphs.jl)" => [
56+
"GNNGraph" => "GNNGraphs/api/gnngraph.md",
57+
"GNNHeteroGraph" => "GNNGraphs/api/heterograph.md",
58+
"TemporalSnapshotsGNNGraph" => "GNNGraphs/api/temporalgraph.md",
59+
"Samplers" => "GNNGraphs/api/samplers.md",
60+
]
61+
62+
"Message Passing (GNNlib.jl)" => [
63+
"Message Passing" => "GNNlib/api/messagepassing.md",
64+
"Other Operators" => "GNNlib/api/utils.md",
65+
]
66+
67+
"Layers" => [
68+
"Basic layers" => "api/basic.md",
69+
"Convolutional layers" => "api/conv.md",
70+
# "Pooling layers" => "api/pool.md",
71+
"Temporal Convolutional layers" => "api/temporalconv.md",
72+
# "Hetero Convolutional layers" => "api/heteroconv.md",
73+
]
74+
],
75+
76+
# "Developer guide" => "dev.md",
77+
],
78+
)
79+
80+
rm(joinpath(@__DIR__, "src/GNNGraphs"), force=true, recursive=true)
81+
rm(joinpath(@__DIR__, "src/GNNlib"), force=true, recursive=true)
3182

32-
deploydocs(;repo = "github.com/JuliaGraphs/GraphNeuralNetworks.jl.git", devbranch = "master", dirname = "GNNLux")
83+
deploydocs(repo = "github.com/JuliaGraphs/GraphNeuralNetworks.jl.git",
84+
devbranch = "master",
85+
dirname = "GNNLux")

GNNLux/docs/src/api/basic.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ CurrentModule = GNNLux
33
CollapsedDocStrings = true
44
```
55

6-
## Basic Layers
6+
# Basic Layers
77

88
```@docs
99
GNNLayer

GNNLux/docs/src/api/conv.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ The table below lists all graph convolutional layers implemented in the *GNNLux.
3535
| [`SAGEConv`](@ref) || | |||
3636
| [`SGConv`](@ref) || | | ||
3737

38-
## Docs
3938

4039
```@autodocs
4140
Modules = [GNNLux]

GNNLux/docs/src/api/temporalconv.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ CollapsedDocStrings = true
77

88
Convolutions for time-varying graphs (temporal graphs) such as the [`TemporalSnapshotsGNNGraph`](@ref).
99

10-
## Docs
11-
1210
```@autodocs
1311
Modules = [GNNLux]
1412
Pages = ["layers/temporalconv.jl"]

GNNLux/docs/src/guides/models.md

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Models
2+
3+
GNNLux.jl provides common graph convolutional layers by which you can assemble arbitrarily deep or complex models. GNN layers are compatible with
4+
Lux.jl ones, therefore expert Lux users are promptly able to define and train
5+
their models.
6+
7+
In what follows, we discuss two different styles for model creation:
8+
the *explicit modeling* style, more verbose but more flexible,
9+
and the *implicit modeling* style based on [`GNNLux.GNNChain`](@ref), more concise but less flexible.
10+
11+
## Explicit modeling
12+
13+
In the explicit modeling style, the model is created according to the following steps:
14+
15+
1. Define a new type for your model (`GNN` in the example below). Refer to the
16+
[Lux Manual](https://lux.csail.mit.edu/dev/manual/interface#lux-interface) for the
17+
definition of the type.
18+
2. Define a convenience constructor for your model.
19+
4. Define the forward pass by implementing the call method for your type.
20+
5. Instantiate the model.
21+
22+
Here is an example of this construction:
23+
```julia
24+
using Lux, GNNLux
25+
using Zygote
26+
using Random, Statistics
27+
28+
struct GNN <: AbstractLuxContainerLayer{(:conv1, :bn, :conv2, :dropout, :dense)} # step 1
29+
conv1
30+
bn
31+
conv2
32+
dropout
33+
dense
34+
end
35+
36+
function GNN(din::Int, d::Int, dout::Int) # step 2
37+
GNN(GraphConv(din => d),
38+
BatchNorm(d),
39+
GraphConv(d => d, relu),
40+
Dropout(0.5),
41+
Dense(d, dout))
42+
end
43+
44+
function (model::GNN)(g::GNNGraph, x, ps, st) # step 3
45+
x, st_conv1 = model.conv1(g, x, ps.conv1, st.conv1)
46+
x, st_bn = model.bn(x, ps.bn, st.bn)
47+
x = relu.(x)
48+
x, st_conv2 = model.conv2(g, x, ps.conv2, st.conv2)
49+
x, st_drop = model.dropout(x, ps.dropout, st.dropout)
50+
x, st_dense = model.dense(x, ps.dense, st.dense)
51+
return x, (conv1=st_conv1, bn=st_bn, conv2=st_conv2, dropout=st_drop, dense=st_dense)
52+
end
53+
54+
din, d, dout = 3, 4, 2
55+
model = GNN(din, d, dout) # step 4
56+
rng = Random.default_rng()
57+
ps, st = Lux.setup(rng, model)
58+
g = rand_graph(rng, 10, 30)
59+
X = randn(Float32, din, 10)
60+
61+
st = Lux.testmode(st)
62+
y, st = model(g, X, ps, st)
63+
st = Lux.trainmode(st)
64+
grad = Zygote.gradient(ps -> mean(model(g, X, ps, st)[1]), ps)[1]
65+
```
66+
67+
## Implicit modeling with GNNChains
68+
69+
While very flexible, the way in which we defined `GNN` model definition in last section is a bit verbose.
70+
In order to simplify things, we provide the [`GNNLux.GNNChain`](@ref) type. It is very similar
71+
to Lux's well known `Chain`. It allows to compose layers in a sequential fashion as Chain
72+
does, propagating the output of each layer to the next one. In addition, `GNNChain`
73+
propagates the input graph as well, providing it as a first argument
74+
to layers subtyping the [`GNNLux.GNNLayer`](@ref) abstract type.
75+
76+
Using `GNNChain`, the model definition becomes more concise:
77+
78+
```julia
79+
model = GNNChain(GraphConv(din => d),
80+
BatchNorm(d),
81+
x -> relu.(x),
82+
GraphConv(d => d, relu),
83+
Dropout(0.5),
84+
Dense(d, dout))
85+
```
86+
87+
The `GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when also edge features are updated, have to be handled using the explicit definition of the forward pass.

GNNLux/docs/src/index.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
# GNNLux.jl
22

3-
GNNLux.jl is a work-in-progress package that implements stateless graph convolutional layers, fully compatible with the [Lux.jl](https://lux.csail.mit.edu/stable/) machine learning framework. It is built on top of the GNNGraphs.jl, GNNlib.jl, and Lux.jl packages.
3+
GNNLux.jl is a package that implements graph convolutional layers fully compatible with the [Lux.jl](https://lux.csail.mit.edu/stable/) deep learning framework. It is built on top of the GNNGraphs.jl, GNNlib.jl, and Lux.jl packages.
4+
5+
See [GraphNeuralNetworks.jl](https://juliagraphs.org/GraphNeuralNetworks.jl/graphneuralnetworks/) instead for a
6+
[Flux.jl](https://fluxml.ai/Flux.jl/stable/)-based implementation of graph neural networks.
7+
8+
## Installation
9+
10+
GNNLux.jl is a registered Julia package. You can easily install it through the package manager :
11+
12+
```julia
13+
pkg> add GNNLux
14+
```
415

516
## Package overview
617

718
Let's give a brief overview of the package by solving a graph regression problem with synthetic data.
819

20+
921
### Data preparation
1022

1123
We generate a dataset of multiple random graphs with associated data features, then split it into training and testing sets.

GraphNeuralNetworks/docs/make.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ makedocs(;
4343
"Home" => "index.md",
4444

4545
"Guides" => [
46-
"Graphs" => ["GNNGraphs/guides/gnngraph.md",
47-
"GNNGraphs/guides/heterograph.md",
48-
"GNNGraphs/guides/temporalgraph.md"],
46+
"Graphs" => "GNNGraphs/guides/gnngraph.md",
4947
"Message Passing" => "GNNlib/guides/messagepassing.md",
5048
"Models" => "guides/models.md",
5149
"Datasets" => "GNNGraphs/guides/datasets.md",
50+
"Heterogeneous Graphs" => "GNNGraphs/guides/heterograph.md",
51+
"Temporal Graphs" => "GNNGraphs/guides/temporalgraph.md",
5252
],
5353

5454
"Tutorials" => [

0 commit comments

Comments
 (0)