Commit c0b30e2

Merge pull request #9 from CarloLucibello/cl/cora
add Cora example
2 parents: 4510d98 + bd380f7

10 files changed: +195 −181 lines

Project.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -11,6 +11,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
```

README.md

Lines changed: 2 additions & 55 deletions
````diff
@@ -21,60 +21,7 @@ Some of its noticeable features are the following:
 ]add GraphNeuralNetworks
 ```
 
-## Featured Graphs
+## Usage
 
-GraphNeuralNetworks handles graph data (the graph topology + node/edge/global features)
-thanks to the type `FeaturedGraph`.
-
-A `FeaturedGraph` can be constructed out of
-adjacency matrices, adjacency lists, LightGraphs' types...
-
-```julia
-fg = FeaturedGraph(adj_list)
-```
-
-## Graph convolutional layers
-
-Construct a GCN layer:
-
-```julia
-GCNConv(input_dim => output_dim, relu)
-```
-
-## Usage Example
-
-```julia
-struct GNN
-    conv1
-    conv2
-    dense
-end
-
-@functor GNN
-
-function GNN()
-    GNN(GCNConv(1024=>512, relu),
-        GCNConv(512=>128, relu),
-        Dense(128, 10))
-end
-
-function (net::GNN)(g, x)
-    x = net.conv1(g, x)
-    x = dropout(x, 0.5)
-    x = net.conv2(g, x)
-    x = net.dense(x)
-    return x
-end
-
-model = GNN()
-
-loss(x, y) = logitcrossentropy(model(fg, x), y)
-accuracy(x, y) = mean(onecold(model(fg, x)) .== onecold(y))
-
-ps = Flux.params(model)
-train_data = [(train_X, train_y)]
-opt = ADAM(0.01)
-evalcb() = @show(accuracy(train_X, train_y))
-
-Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))
+Usage examples can be found in the `examples/` folder.
 ```
````

examples/cora.jl

Lines changed: 103 additions & 0 deletions
New file, full contents:

```julia
# An example of semi-supervised node classification

using Flux
using Flux: @functor, dropout, onecold, onehotbatch
using Flux.Losses: logitcrossentropy
using GraphNeuralNetworks
using MLDatasets: Cora
using Statistics, Random
using CUDA
CUDA.allowscalar(false)

struct GNN
    conv1
    conv2
    dense
end

@functor GNN

function GNN(; nin, nhidden, nout)
    GNN(GCNConv(nin => nhidden, relu),
        GCNConv(nhidden => nhidden, relu),
        Dense(nhidden, nout))
end

function (net::GNN)(fg, x)
    x = net.conv1(fg, x)
    x = dropout(x, 0.5)
    x = net.conv2(fg, x)
    x = net.dense(x)
    return x
end

function eval_loss_accuracy(X, y, ids, model, fg)
    ŷ = model(fg, X)
    l = logitcrossentropy(ŷ[:,ids], y[:,ids])
    acc = mean(onecold(ŷ[:,ids] |> cpu) .== onecold(y[:,ids] |> cpu))
    return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
end

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3             # learning rate
    epochs = 100         # number of epochs
    seed = 17            # set seed > 0 for reproducibility
    use_cuda = true      # if true use cuda (if available)
    nhidden = 128        # dimension of hidden features
    infotime = 10        # report every `infotime` epochs
end

function train(; kws...)
    args = Args(; kws...)
    if args.seed > 0
        Random.seed!(args.seed)
        CUDA.seed!(args.seed)
    end

    if args.use_cuda && CUDA.functional()
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    data = Cora.dataset()
    fg = FeaturedGraph(data.adjacency_list) |> device
    X = data.node_features |> device
    y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
    train_ids = data.train_indices |> device
    val_ids = data.val_indices |> device
    test_ids = data.test_indices |> device
    ytrain = y[:,train_ids]

    model = GNN(nin=size(X,1),
                nhidden=args.nhidden,
                nout=data.num_classes) |> device
    ps = Flux.params(model)
    opt = ADAM(args.η)

    @info "NUM NODES: $(fg.num_nodes)  NUM EDGES: $(fg.num_edges)"

    function report(epoch)
        train = eval_loss_accuracy(X, y, train_ids, model, fg)
        test = eval_loss_accuracy(X, y, test_ids, model, fg)
        println("Epoch: $epoch   Train: $(train)   Test: $(test)")
    end

    ## TRAINING
    report(0)
    for epoch in 1:args.epochs
        gs = Flux.gradient(ps) do
            ŷ = model(fg, X)
            logitcrossentropy(ŷ[:,train_ids], ytrain)
        end

        Flux.Optimise.update!(opt, ps, gs)

        epoch % args.infotime == 0 && report(epoch)
    end
end

train()
```

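Since `train` forwards its keyword arguments to `Args`, any of the defaults above can be overridden at the call site. A minimal sketch (the particular values are illustrative):

```julia
# Any `Args` field can be passed as a keyword to `train`.
train(epochs = 200,      # train for more epochs
      nhidden = 64,      # smaller hidden dimension
      use_cuda = false)  # force the CPU path
```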
src/featuredgraph.jl

Lines changed: 25 additions & 22 deletions
```diff
@@ -234,10 +234,10 @@ function LightGraphs.adjacency_matrix(fg::FeaturedGraph{<:ADJMAT_T}, T::DataType
     return dir == :out ? A : A'
 end
 
-function LightGraphs.degree(fg::FeaturedGraph{<:COO_T}; dir=:out)
+function LightGraphs.degree(fg::FeaturedGraph{<:COO_T}, T=Int; dir=:out)
     s, t = edge_index(fg)
-    degs = fill!(similar(s, eltype(s), fg.num_nodes), 0)
-    o = fill!(similar(s, eltype(s), fg.num_edges), 1)
+    degs = fill!(similar(s, T, fg.num_nodes), 0)
+    o = fill!(similar(s, Int, fg.num_edges), 1)
     if dir ∈ [:out, :both]
         NNlib.scatter!(+, degs, o, s)
     end
```
```diff
@@ -247,9 +247,9 @@ function LightGraphs.degree(fg::FeaturedGraph{<:COO_T}; dir=:out)
     return degs
 end
 
-function LightGraphs.degree(fg::FeaturedGraph{<:ADJMAT_T}; dir=:out)
+function LightGraphs.degree(fg::FeaturedGraph{<:ADJMAT_T}, T=Int; dir=:out)
     @assert dir ∈ (:in, :out)
-    A = graph(fg)
+    A = adjacency_matrix(fg, T)
     return dir == :out ? vec(sum(A, dims=2)) : vec(sum(A, dims=1))
 end
```

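The new positional argument `T` lets callers choose the element type of the returned degree vector, which is what the GPU path of `GCNConv` (further below) relies on to avoid an extra conversion. A small sketch, assuming `fg` is an existing `FeaturedGraph`:

```julia
degree(fg)                        # out-degrees as a Vector{Int} (default T=Int)
d = degree(fg, Float32, dir=:in)  # in-degrees, already Float32
c = 1 ./ sqrt.(d)                 # normalization coefficients, no conversion needed
```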
```diff
@@ -298,29 +298,32 @@ function LightGraphs.laplacian_matrix(fg::FeaturedGraph, T::DataType=Int; dir::S
 end
 
 """
-    normalized_laplacian(fg, T=Float32; selfloop=false, dir=:out)
+    normalized_laplacian(fg, T=Float32; add_self_loops=false, dir=:out)
 
 Normalized Laplacian matrix of graph `g`.
 
 # Arguments
 
 - `fg`: A `FeaturedGraph`.
 - `T`: result element type.
-- `selfloop`: adding self loop while calculating the matrix.
+- `add_self_loops`: add self-loops while calculating the matrix.
 - `dir`: the edge directionality considered (:out, :in, :both).
 """
-function normalized_laplacian(fg::FeaturedGraph, T::DataType=Float32; selfloop::Bool=false, dir::Symbol=:out)
+function normalized_laplacian(fg::FeaturedGraph, T::DataType=Float32;
+                              add_self_loops::Bool=false, dir::Symbol=:out)
+    Ã = normalized_adjacency(fg, T; dir, add_self_loops)
+    return I - Ã
+end
+
+function normalized_adjacency(fg::FeaturedGraph, T::DataType=Float32;
+                              add_self_loops::Bool=false, dir::Symbol=:out)
     A = adjacency_matrix(fg, T; dir=dir)
-    sz = size(A)
-    @assert sz[1] == sz[2]
-    if selfloop
-        A += I - Diagonal(A)
-    else
-        A -= Diagonal(A)
+    if add_self_loops
+        A += I
     end
     degs = vec(sum(A; dims=2))
     inv_sqrtD = Diagonal(inv.(sqrt.(degs)))
-    return I - inv_sqrtD * A * inv_sqrtD
+    return inv_sqrtD * A * inv_sqrtD
 end
 
 @doc raw"""
```
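With this refactor, `normalized_laplacian` becomes a thin wrapper: `normalized_adjacency` computes Ã = D^(-1/2) A D^(-1/2), and the Laplacian is I - Ã. A quick consistency check on a toy graph (the three-node adjacency list is illustrative):

```julia
using LinearAlgebra
using GraphNeuralNetworks

fg = FeaturedGraph([[2,3], [1,3], [1,2]])  # triangle graph from an adjacency list

Ã = normalized_adjacency(fg, Float32)
L̃ = normalized_laplacian(fg, Float32)
L̃ ≈ I - Ã  # true by construction
```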
```diff
@@ -354,14 +357,14 @@ _eigmax(A) = KrylovKit.eigsolve(Symmetric(A), 1, :LR)[1][1] # also eigs(A, x0, n
 
 Return a featured graph with the same features as `fg`
 but also adding edges connecting the nodes to themselves.
+
+Nodes with already existing
+self-loops will obtain a second self-loop.
 """
 function add_self_loops(fg::FeaturedGraph{<:COO_T})
     s, t = edge_index(fg)
     @assert edge_feature(fg) === nothing
     @assert edge_weight(fg) === nothing
-    mask_old_loops = s .!= t
-    s = s[mask_old_loops]
-    t = t[mask_old_loops]
     n = fg.num_nodes
     nodes = convert(typeof(s), [1:n;])
     s = [s; nodes]
```
371374
node_feature(fg), edge_feature(fg), global_feature(fg))
372375
end
373376

374-
function add_self_loops(fg::FeaturedGraph{<:ADJMAT_T})
377+
function add_self_loops(fg::FeaturedGraph{<:ADJMAT_T}; add_to_existing=true)
375378
A = graph(fg)
376379
@assert edge_feature(fg) === nothing
377-
nold = sum(Diagonal(A)) |> Int
378-
A = A - Diagonal(A) + I
379-
num_edges = fg.num_edges - nold + fg.num_nodes
380+
A += I
381+
num_edges = fg.num_edges + fg.num_nodes
380382
FeaturedGraph(A, fg.num_nodes, num_edges,
381383
node_feature(fg), edge_feature(fg), global_feature(fg))
382384
end
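On the default path shown here, both methods add one self-loop per node unconditionally (the COO method appends `num_nodes` edges, the adjacency-matrix method adds `I`), so the edge count grows by exactly `num_nodes`. A sketch of the invariant, assuming `fg` is an existing `FeaturedGraph` without edge features:

```julia
fg2 = add_self_loops(fg)
fg2.num_edges == fg.num_edges + fg.num_nodes  # one new loop per node
```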
```diff
@@ -396,6 +398,7 @@ function remove_self_loops(fg::FeaturedGraph{<:COO_T})
 end
 
 @non_differentiable normalized_laplacian(x...)
+@non_differentiable normalized_adjacency(x...)
 @non_differentiable scaled_laplacian(x...)
 @non_differentiable adjacency_matrix(x...)
 @non_differentiable adjacency_list(x...)
```

src/layers/conv.jl

Lines changed: 7 additions & 9 deletions
```diff
@@ -34,22 +34,20 @@ end
 ## but cannot compute the normalized laplacian of sparse cuda matrices yet,
 ## therefore fallback to message passing framework on gpu for the time being
 
-function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
-    L̃ = normalized_laplacian(fg, eltype(x); selfloop=true)
-    l.σ.(l.weight * x * L̃ .+ l.bias)
+function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix{T}) where T
+    Ã = normalized_adjacency(fg, T; dir=:out, add_self_loops=true)
+    l.σ.(l.weight * x * Ã .+ l.bias)
 end
 
 message(l::GCNConv, xi, xj) = xj
 update(l::GCNConv, m, x) = m
 
-function (l::GCNConv)(fg::FeaturedGraph, x::CuMatrix)
+function (l::GCNConv)(fg::FeaturedGraph, x::CuMatrix{T}) where T
     fg = add_self_loops(fg)
-    T = eltype(l.weight)
-    # cout = sqrt.(degree(fg, dir=:out))
-    cin = 1 ./ reshape(sqrt.(T.(degree(fg, dir=:in))), 1, :)
-    x = cin .* x
+    c = 1 ./ sqrt.(degree(fg, T, dir=:in))
+    x = x .* c'
     _, x = propagate(l, fg, nothing, x, nothing, +)
-    x = cin .* x
+    x = x .* c'
     return l.σ.(l.weight * x .+ l.bias)
 end
 
```

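Both methods implement the same GCN rule σ.(W * X * Ã .+ b): the `AbstractMatrix` path multiplies by the dense normalized adjacency, while the `CuMatrix` path reaches the same result through degree-normalized message passing. A minimal end-to-end sketch (graph and feature sizes are illustrative):

```julia
using Flux, GraphNeuralNetworks

fg = FeaturedGraph([[2,3], [1,3], [1,2]])  # toy triangle graph
X = rand(Float32, 8, 3)                    # 8 features for each of the 3 nodes

layer = GCNConv(8 => 4, relu)              # 8 input features -> 4 output features
Y = layer(fg, X)                           # 4×3 matrix of transformed node features
```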