
Commit 8441b94

doc fixes
1 parent: 4872eb6

2 files changed: +13 −11 lines changed

docs/src/index.md

Lines changed: 9 additions & 7 deletions
@@ -23,7 +23,7 @@ Usage examples on real datasets can be found in the [examples](https://github.co
 First, we create our dataset consisting of multiple random graphs and associated data features
 that we batch together into a unique graph.
 
-```juliarepl
+```julia
 julia> using GraphNeuralNetworks, LightGraphs, Flux, CUDA, Statistics
 
 julia> all_graphs = GNNGraph[];
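
For context, a minimal sketch of how such a dataset might be built and batched into a single graph; the graph size, node degree, and feature dimensions here are illustrative assumptions, not taken from this commit:

```julia
using GraphNeuralNetworks, LightGraphs, Flux

all_graphs = GNNGraph[]
for _ in 1:1000
    g = GNNGraph(random_regular_graph(10, 4),            # 10 nodes of degree 4 (assumed)
                 ndata = (; x = randn(Float32, 16, 10)), # 16 features per node (assumed)
                 gdata = (; y = randn(Float32)))         # scalar graph-level target
    push!(all_graphs, g)
end

gbatch = Flux.batch(all_graphs)   # one GNNGraph containing all 1000 graphs
```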
@@ -51,9 +51,9 @@ GNNGraph:
 ### Model building
 
 We concisely define our model using a [`GNNChain`](@ref) containing 2 graph convolutional
-layers. If CUDA is available, our model will leave on the gpu.
+layers. If CUDA is available, our model will live on the gpu.
 
-```juliarepl
+```julia
 julia> device = CUDA.functional() ? Flux.gpu : Flux.cpu;
 
 julia> model = GNNChain(GCNConv(16 => 64),
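
The diff truncates the model definition at the first layer. A plausible completion, with the pooling and readout layers assumed from the graph-level target `y` rather than shown in this commit:

```julia
using GraphNeuralNetworks, Flux, CUDA, Statistics

device = CUDA.functional() ? Flux.gpu : Flux.cpu

model = GNNChain(GCNConv(16 => 64),
                 BatchNorm(64),
                 x -> relu.(x),
                 GCNConv(64 => 64, relu),
                 GlobalPool(mean),     # pool node features into one vector per graph
                 Dense(64, 1)) |> device
```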
@@ -70,15 +70,17 @@ julia> opt = ADAM(1f-4);
 
 ### Training
 
-```juliarepl
+Finally, we use a standard Flux training pipeline to fit our dataset.
+Flux's DataLoader iterates over mini-batches of graphs
+(batched together into a `GNNGraph` object).
+
+```julia
 gtrain, _ = getgraph(gbatch, 1:800)
 gtest, _ = getgraph(gbatch, 801:gbatch.num_graphs)
 train_loader = Flux.Data.DataLoader(gtrain, batchsize=32, shuffle=true)
 test_loader = Flux.Data.DataLoader(gtest, batchsize=32, shuffle=false)
 
-function loss(g::GNNGraph)
-    mean((vec(model(g, g.ndata.x)) - g.gdata.y).^2)
-end
+loss(g::GNNGraph) = mean((vec(model(g, g.ndata.x)) - g.gdata.y).^2)
 
 loss(loader) = mean(loss(g |> device) for g in loader)
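
A sketch of the training loop this section leads up to, using the `loss`, `opt`, and loaders defined above; the epoch count and logging are assumptions:

```julia
ps = Flux.params(model)

for epoch in 1:100
    for g in train_loader
        g = g |> device
        gs = gradient(() -> loss(g), ps)     # implicit-params gradient via Zygote
        Flux.Optimise.update!(opt, ps, gs)
    end
    @info (; epoch, train_loss = loss(train_loader), test_loss = loss(test_loader))
end
```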

docs/src/messagepassing.md

Lines changed: 4 additions & 4 deletions
@@ -22,12 +22,11 @@ respectively.
 
 Let's (re-)implement the [`GCNConv`](@ref) layer using the message passing framework.
 The convolution reads
-```math
 
 ```math
-\mathbf{x}'_i = \sum_{j \in \{i\} \cup N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
+\mathbf{x}'_i = \sum_{j \in N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
 ```
-where ``c_{ij} = \sqrt{(1+|N(i)|)(1+|N(j)|)}``. We will also add a bias and an activation function.
+where ``c_{ij} = \sqrt{|N(i)||N(j)|}``. We will also add a bias and an activation function.
 
 ```julia
 using Flux, LightGraphs, GraphNeuralNetworks
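
For reference, the updated per-node rule is equivalent to the familiar matrix form of graph convolution without self-loops (matching the removal of `add_self_loops` below), assuming an undirected graph with symmetric adjacency matrix ``A``, diagonal degree matrix ``D``, and node features stored column-wise in ``X`` as in the Julia code:

```math
X' = W X D^{-1/2} A D^{-1/2}
```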
@@ -52,7 +51,6 @@ compute_message(l::GCN, xi, xj, eij) = l.weight * xj
 update_node(l::GCN, m, x) = m
 
 function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T
-    g = add_self_loops(g)
     c = 1 ./ sqrt.(degree(g, T, dir=:in))
     x = x .* c'
     x, _ = propagate(l, g, +, x)
@@ -61,4 +59,6 @@ function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T
 end
 ```
 
+See the [`GATConv`](@ref) implementation [here](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/src/layers/conv.jl) for a more complex example.
+
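
A quick usage sketch for the custom layer; the `GCN(16 => 64, relu)` constructor signature is assumed from the part of the page not shown in this diff:

```julia
using Flux, LightGraphs, GraphNeuralNetworks

g = GNNGraph(random_regular_graph(10, 4))  # toy graph: 10 nodes, degree 4
l = GCN(16 => 64, relu)                    # assumed constructor, defined earlier in the page
x = randn(Float32, 16, 10)                 # 16 input features for each of the 10 nodes
y = l(g, x)                                # 64 × 10 matrix of updated node features
```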
6464
