Commit 3705056

Merge pull request #30 from CarloLucibello/cl/dev
mode docs
2 parents 011c956 + 4872eb6 commit 3705056

File tree

10 files changed: +145 -35 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -21,4 +21,4 @@ Its most relevant features are:
 
 ## Usage
 
-Usage examples can be found in the `examples/` folder.
+Usage examples can be found in the [examples](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) folder.

docs/make.jl

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ makedocs(;
          doctest=false, clean=true,
          sitename = "GraphNeuralNetworks.jl",
          pages = ["Home" => "index.md",
-                  "GNNGraph" => "gnngraph.md",
+                  "Graphs" => "gnngraph.md",
                   "Message Passing" => "messagepassing.md",
                   "Model Building" => "models.md",
                   "API Reference" =>

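For context, the `makedocs` call this hunk edits takes roughly the following shape. This is a sketch, not the file's actual content: every option besides the `pages` entries visible in the diff is an assumption, and the "API Reference" target (truncated above) is a hypothetical placeholder.

```julia
using Documenter, GraphNeuralNetworks

# Sketch of the docs build after the rename; options other than the
# `pages` entries shown in the diff are assumptions.
makedocs(;
    sitename = "GraphNeuralNetworks.jl",
    doctest = false, clean = true,
    pages = ["Home" => "index.md",
             "Graphs" => "gnngraph.md",
             "Message Passing" => "messagepassing.md",
             "Model Building" => "models.md",
             "API Reference" => "api.md"],  # hypothetical target file
)
```
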
docs/src/index.md

Lines changed: 68 additions & 10 deletions
@@ -11,26 +11,84 @@ Its most relevant features are:
 * Makes it easy to define custom graph convolutional layers.
 
 
+## Package overview
 
+Let's give a brief overview of the package by solving a
+graph regression problem on fake data.
 
-## Package overview
+Usage examples on real datasets can be found in the [examples](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) folder.
 
 ### Data preparation
 
+First, we create our dataset, consisting of multiple random graphs and associated data features
+that we batch together into a single graph.
 
-```
-using LightGraphs
+```julia-repl
+julia> using GraphNeuralNetworks, LightGraphs, Flux, CUDA, Statistics
 
-lg = LightGraphs.Graph(5) # create a light's graph graph
-add_edge!(g, 1, 2)
-add_edge!(g, 1, 3)
-add_edge!(g, 2, 4)
-add_edge!(g, 2, 5)
-add_edge!(g, 3, 4)
+julia> all_graphs = GNNGraph[];
 
-g = GNNGraph(g)
+julia> for _ in 1:1000
+           g = GNNGraph(random_regular_graph(10, 4),
+                        ndata=(; x = randn(Float32, 16, 10)),  # input node features
+                        gdata=(; y = randn(Float32)))          # regression target
+           push!(all_graphs, g)
+       end
+
+julia> gbatch = Flux.batch(all_graphs)
+GNNGraph:
+    num_nodes = 10000
+    num_edges = 20000
+    num_graphs = 1000
+    ndata:
+        x => (16, 10000)
+    edata:
+    gdata:
+        y => (1000,)
 ```
+
+
 ### Model building
 
+We concisely define our model as a [`GNNChain`](@ref) containing 2 graph convolutional
+layers. If CUDA is available, our model will live on the gpu.
+
+```julia-repl
+julia> device = CUDA.functional() ? Flux.gpu : Flux.cpu;
+
+julia> model = GNNChain(GCNConv(16 => 64),
+                        BatchNorm(64),
+                        x -> relu.(x),
+                        GCNConv(64 => 64, relu),
+                        GlobalPool(mean),
+                        Dense(64, 1)) |> device;
+
+julia> ps = Flux.params(model);
+
+julia> opt = ADAM(1f-4);
+```
+
 ### Training
 
+```julia
+gtrain, _ = getgraph(gbatch, 1:800)
+gtest, _ = getgraph(gbatch, 801:gbatch.num_graphs)
+train_loader = Flux.Data.DataLoader(gtrain, batchsize=32, shuffle=true)
+test_loader = Flux.Data.DataLoader(gtest, batchsize=32, shuffle=false)
+
+function loss(g::GNNGraph)
+    mean((vec(model(g, g.ndata.x)) - g.gdata.y).^2)
+end
+
+loss(loader) = mean(loss(g |> device) for g in loader)
+
+for epoch in 1:100
+    for g in train_loader
+        g = g |> device
+        grad = gradient(() -> loss(g), ps)
+        Flux.Optimise.update!(opt, ps, grad)
+    end
+
+    @info (; epoch, train_loss=loss(train_loader), test_loss=loss(test_loader))
+end
+```
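
As an illustration of the pieces above (not part of the commit), a trained model could be sanity-checked on a held-out graph. A minimal sketch, assuming `model`, `device`, `gbatch`, and `getgraph` from the overview are in scope:

```julia
# Sketch: extract one graph from the batch and run a forward pass.
g, _ = getgraph(gbatch, 1:1)   # a GNNGraph holding a single graph
g = g |> device                # move the data to the model's device
ŷ = model(g, g.ndata.x)        # output has size (1, 1): one prediction
println("prediction = ", vec(ŷ)[1], ", target = ", g.gdata.y[1])
```
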

docs/src/messagepassing.md

Lines changed: 44 additions & 1 deletion
@@ -18,4 +18,47 @@ where ``\phi`` is expressed by the [`compute_message`](@ref) function,
 ``\gamma_x`` and ``\gamma_e`` by [`update_node`](@ref) and [`update_edge`](@ref)
 respectively.
 
-See [`GraphConv`](ref) and [`GATConv`](ref)'s implementations as usage examples.
+## An example: implementing the GCNConv
+
+Let's (re-)implement the [`GCNConv`](@ref) layer using the message passing framework.
+The convolution reads
+
+```math
+\mathbf{x}'_i = \sum_{j \in \{i\} \cup N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
+```
+where ``c_{ij} = \sqrt{(1+|N(i)|)(1+|N(j)|)}``. We will also add a bias and an activation function.
+
+```julia
+using Flux, LightGraphs, GraphNeuralNetworks
+import GraphNeuralNetworks: compute_message, update_node, propagate
+
+struct GCN{A<:AbstractMatrix, B, F} <: GNNLayer
+    weight::A
+    bias::B
+    σ::F
+end
+
+Flux.@functor GCN    # allow collecting params, gpu movement, etc...
+
+function GCN(ch::Pair{Int,Int}, σ=identity)
+    in, out = ch
+    W = Flux.glorot_uniform(out, in)
+    b = zeros(Float32, out)
+    GCN(W, b, σ)
+end
+
+compute_message(l::GCN, xi, xj, eij) = l.weight * xj
+update_node(l::GCN, m, x) = m
+
+function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T
+    g = add_self_loops(g)
+    c = 1 ./ sqrt.(degree(g, T, dir=:in))
+    x = x .* c'
+    x, _ = propagate(l, g, +, x)
+    x = x .* c'
+    return l.σ.(x .+ l.bias)
+end
+```
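
A quick smoke test of this hand-rolled layer (not part of the commit) could look like the following sketch, assuming the `GCN` definitions above are loaded; sizes are arbitrary:

```julia
using Flux, LightGraphs, GraphNeuralNetworks

# Sketch: apply the custom GCN above to a small random graph.
g = GNNGraph(random_regular_graph(10, 4))   # 10 nodes of degree 4
x = randn(Float32, 3, 10)                   # 3 input features per node
l = GCN(3 => 8, relu)                       # 3 => 8 feature dimensions
y = l(g, x)
@assert size(y) == (8, 10)                  # one 8-dim vector per node
```
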

src/gnngraph.jl

Lines changed: 11 additions & 12 deletions
@@ -80,13 +80,13 @@ g = GNNGraph(s, t)
 g = GNNGraph(erdos_renyi(100, 20))
 
 # Add 2 node feature arrays
-g = GNNGraph(g, ndata = (X = rand(100, g.num_nodes), y = rand(g.num_nodes)))
+g = GNNGraph(g, ndata = (x=rand(100, g.num_nodes), y=rand(g.num_nodes)))
 
-# Add node features and edge features with default names `X` and `E`
+# Add node features and edge features with default names `x` and `e`
 g = GNNGraph(g, ndata = rand(100, g.num_nodes), edata = rand(16, g.num_edges))
 
-g.ndata.X
-g.ndata.E
+g.ndata.x
+g.edata.e
 
 # Send to gpu
 g = g |> gpu

@@ -132,9 +132,9 @@ function GNNGraph(data;
 
     num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1
 
-    ndata = normalize_graphdata(ndata, :X)
-    edata = normalize_graphdata(edata, :E)
-    gdata = normalize_graphdata(gdata, :U)
+    ndata = normalize_graphdata(ndata, :x)
+    edata = normalize_graphdata(edata, :e)
+    gdata = normalize_graphdata(gdata, :u)
 
     GNNGraph(g,
             num_nodes, num_edges, num_graphs,

@@ -170,8 +170,7 @@ function Base.show(io::IO, g::GNNGraph)
     println(io, "GNNGraph:
                 num_nodes = $(g.num_nodes)
                 num_edges = $(g.num_edges)
-                num_graphs = $(g.num_graphs)
-                # feature name => array size")
+                num_graphs = $(g.num_graphs)")
     println(io, "    ndata:")
     for k in keys(g.ndata)
         println(io, "        $k => $(size(g.ndata[k]))")

@@ -498,7 +497,7 @@ function node_features(g::GNNGraph)
     if isempty(g.ndata)
         return nothing
     elseif length(g.ndata) > 1
-        @error "Multiple feature arrays, access directly with g.ndata.X"
+        @error "Multiple feature arrays, access directly through `g.ndata`"
     else
         return g.ndata[1]
     end

@@ -508,7 +507,7 @@ function edge_features(g::GNNGraph)
     if isempty(g.edata)
         return nothing
     elseif length(g.edata) > 1
-        @error "Multiple feature arrays, access directly with g.edata.E"
+        @error "Multiple feature arrays, access directly through `g.edata`"
     else
         return g.edata[1]
     end

@@ -518,7 +517,7 @@ function graph_features(g::GNNGraph)
     if isempty(g.gdata)
         return nothing
     elseif length(g.gdata) > 1
-        @error "Multiple feature arrays, access directly with g.gdata.U"
+        @error "Multiple feature arrays, access directly through `g.gdata`"
     else
         return g.gdata[1]
     end
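
With these renames, anonymous feature arrays get the default names `x`, `e`, and `u`. A small sketch of the convention (not part of the commit; array sizes arbitrary):

```julia
using GraphNeuralNetworks, LightGraphs

# Sketch: unnamed arrays passed to ndata/edata/gdata are stored under
# the default names x, e, and u respectively after this change.
g = GNNGraph(erdos_renyi(10, 30))
g = GNNGraph(g, ndata = rand(Float32, 4, g.num_nodes),
                edata = rand(Float32, 2, g.num_edges),
                gdata = rand(Float32, 3))

size(g.ndata.x)   # (4, 10)
size(g.edata.e)   # (2, g.num_edges)
size(g.gdata.u)   # (3,)
```
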

src/layers/conv.jl

Lines changed: 2 additions & 2 deletions
@@ -5,9 +5,9 @@ Graph convolutional layer from paper [Semi-supervised Classification with Graph
 
 Performs the operation
 ```math
-\mathbf{x}'_i = \sum_{j\in N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
+\mathbf{x}'_i = \sum_{j\in \{i\} \cup N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
 ```
-where ``c_{ij} = \sqrt{N(i)\,N(j)}``.
+where ``c_{ij} = \sqrt{(1+|N(i)|)(1+|N(j)|)}``.
 
 The input to the layer is a node feature array `X`
 of size `(num_features, num_nodes)`.
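
The corrected coefficient counts the implicit self-loop in each node's degree. A toy numeric check (degrees chosen arbitrarily):

```julia
# Sketch: normalization for an edge between a degree-2 node i and a
# degree-3 node j, with the self-loop correction from the fixed docstring.
deg_i, deg_j = 2, 3
c_ij = sqrt((1 + deg_i) * (1 + deg_j))   # sqrt(12) ≈ 3.46
w_ij = 1 / c_ij                          # scale applied to the message from j to i
```
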

src/layers/pool.jl

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ and performs the operation
 ```
 where ``V`` is the set of nodes of the input graph and
 the type of aggregation represented by ``\square`` is selected by the `aggr` argument.
-Commonly used aggregations are are `mean`, `max`, and `+`.
+Commonly used aggregations are `mean`, `max`, and `+`.
 
 ```julia
 using Flux, GraphNeuralNetworks, LightGraphs
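
For illustration (not part of the commit), a `GlobalPool` with `mean` aggregation might be exercised as below; this sketch assumes the layer follows the `layer(g, x)` calling convention used elsewhere in these docs:

```julia
using Flux, GraphNeuralNetworks, LightGraphs
using Statistics: mean

# Sketch: mean-pool node features into one graph-level vector.
g = GNNGraph(random_regular_graph(10, 4), ndata = rand(Float32, 8, 10))
pool = GlobalPool(mean)
u = pool(g, g.ndata.x)      # expected size (8, 1) for a single graph
```
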

src/utils.jl

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@ end
 
 cat_features(x1::Nothing, x2::Nothing) = nothing
 cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims=ndims(x1))
+cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector}) =
+    cat(x1, x2, dims=1)
+
 
 function cat_features(x1::NamedTuple, x2::NamedTuple)
     sort(collect(keys(x1))) == sort(collect(keys(x2))) ||
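
The new method makes scalars and vectors concatenate along the first dimension, which is what lets `Flux.batch` collect scalar `gdata` values into a vector (see the test below). A sketch of the underlying `Base.cat` behavior:

```julia
# Sketch: Base.cat on numbers and vectors with dims=1, as used by the
# new cat_features method above.
cat(0.3, 0.7, dims=1)          # 2-element Vector{Float64}: [0.3, 0.7]
cat([0.3, 0.7], 0.9, dims=1)   # [0.3, 0.7, 0.9]
```
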

test/gnngraph.jl

Lines changed: 7 additions & 0 deletions
@@ -118,6 +118,13 @@
     @test s == [edge_index(g1)[1]; 10 .+ edge_index(g2)[1]; 14 .+ edge_index(g3)[1]]
     @test t == [edge_index(g1)[2]; 10 .+ edge_index(g2)[2]; 14 .+ edge_index(g3)[2]]
     @test node_features(g123)[:,11:14] ≈ node_features(g2)
+
+    # scalar graph features
+    g1 = GNNGraph(random_regular_graph(10,2), gdata=rand())
+    g2 = GNNGraph(random_regular_graph(4,2), gdata=rand())
+    g3 = GNNGraph(random_regular_graph(4,2), gdata=rand())
+    g123 = Flux.batch([g1, g2, g3])
+    @test g123.gdata.u == [g1.gdata.u, g2.gdata.u, g3.gdata.u]
 end
 
 @testset "getgraph" begin

test/layers/conv.jl

Lines changed: 7 additions & 7 deletions
@@ -38,7 +38,7 @@
     @test size(node_features(gt_)) == (out_channel, N)
 
     gs = Zygote.gradient(x -> sum(node_features(l(x))), g)[1]
-    @test size(gs.ndata.X) == size(X)
+    @test size(gs.ndata.x) == size(X)
 
     gs = Zygote.gradient(model -> sum(node_features(model(g))), l)[1]
     @test size(gs.weight) == size(l.weight)

@@ -74,7 +74,7 @@
     @test size(node_features(gt_)) == (out_channel, N)
 
     gs = Zygote.gradient(x -> sum(node_features(l(x))), g)[1]
-    @test size(gs.ndata.X) == size(X)
+    @test size(gs.ndata.x) == size(X)
 
     gs = Zygote.gradient(model -> sum(node_features(model(g))), l)[1]
     @test size(gs.weight) == size(l.weight)

@@ -109,7 +109,7 @@
     @test size(node_features(gt_)) == (out_channel, N)
 
     gs = Zygote.gradient(g -> sum(node_features(l(g))), g)[1]
-    @test size(gs.ndata.X) == size(X)
+    @test size(gs.ndata.x) == size(X)
 
     gs = Zygote.gradient(model -> sum(node_features(model(g))), l)[1]
     @test size(gs.weight1) == size(l.weight1)

@@ -148,7 +148,7 @@
     @test size(node_features(gt_)) == (concat ? (out_channel*heads, N) : (out_channel, N))
 
     gs = Zygote.gradient(g -> sum(node_features(gat(g))), g_gat)[1]
-    @test size(gs.ndata.X) == size(X)
+    @test size(gs.ndata.x) == size(X)
 
     gs = Zygote.gradient(model -> sum(node_features(model(g_gat))), gat)[1]
     @test size(gs.weight) == size(gat.weight)

@@ -182,7 +182,7 @@
     @test size(node_features(gt_)) == (out_channel, N)
 
     gs = Zygote.gradient(x -> sum(node_features(ggc(x))), g)[1]
-    @test size(gs.ndata.X) == size(X)
+    @test size(gs.ndata.x) == size(X)
 
     gs = Zygote.gradient(model -> sum(node_features(model(g))), ggc)[1]
     @test size(gs.weight) == size(ggc.weight)

@@ -206,7 +206,7 @@
     @test size(node_features(gt_)) == (out_channel, N)
 
     gs = Zygote.gradient(x -> sum(node_features(ec(x))), g)[1]
-    @test size(gs.ndata.X) == size(X)
+    @test size(gs.ndata.x) == size(X)
 
     gs = Zygote.gradient(model -> sum(node_features(model(g))), ec)[1]
     @test size(gs.nn.weight) == size(ec.nn.weight)

@@ -226,7 +226,7 @@
     @test size(node_features(g_)) == (out_channel, N)
 
     gs = Zygote.gradient(g -> sum(node_features(l(g))), g)[1]
-    @test size(gs.ndata.X) == size(X)
+    @test size(gs.ndata.x) == size(X)
 
     gs = Zygote.gradient(model -> sum(node_features(model(g))), l)[1]
     @test size(gs.nn.weight) == size(l.nn.weight)
