Skip to content

Commit 63745a5

Browse files
Merge pull request #55 from CarloLucibello/cl/withgraph
add graph NeuralODE example
2 parents e1114a2 + cac775d commit 63745a5

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

examples/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
[deps]
22
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
4+
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
35
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
46
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
57
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"

examples/neural_ode.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Load the packages
2+
using GraphNeuralNetworks, JLD2, DiffEqFlux, DifferentialEquations
3+
using Flux: onehotbatch, onecold, throttle
4+
using Flux.Losses: logitcrossentropy
5+
using Statistics: mean
6+
using MLDatasets: Cora
7+
8+
device = cpu # `gpu` not working yet
9+
10+
# LOAD DATA
11+
data = Cora.dataset()
12+
g = GNNGraph(data.adjacency_list) |> device
13+
X = data.node_features |> device
14+
y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
15+
train_ids = data.train_indices |> device
16+
val_ids = data.val_indices |> device
17+
test_ids = data.test_indices |> device
18+
ytrain = y[:, train_ids]
19+
20+
21+
# Model and Data Configuration
22+
nin = size(X, 1)
23+
nhidden = 16
24+
nout = data.num_classes
25+
epochs = 40
26+
27+
# Define the Neural GDE
28+
diffeqsol_to_array(x) = reshape(device(x), size(x)[1:2])
29+
30+
# GCNConv(nhidden => nhidden, graph=g),
31+
32+
node_chain = GNNChain(GCNConv(nhidden => nhidden, relu),
33+
GCNConv(nhidden => nhidden, relu)) |> device
34+
35+
node = NeuralODE(WithGraph(node_chain, g),
36+
(0.f0, 1.f0), Tsit5(), save_everystep = false,
37+
reltol = 1e-3, abstol = 1e-3, save_start = false) |> device
38+
39+
model = GNNChain(GCNConv(nin => nhidden, relu),
40+
Dropout(0.5),
41+
node,
42+
diffeqarray_to_array,
43+
Dense(nhidden, nout)) |> device
44+
45+
# Loss
46+
loss(x, y) = logitcrossentropy(model(g, x), y)
47+
accuracy(x, y) = mean(onecold(model(g, x)) .== onecold(y))
48+
49+
# Training
50+
## Model Parameters
51+
ps = Flux.params(model, node.p);
52+
53+
## Optimizer
54+
opt = ADAM(0.01)
55+
56+
## Training Loop
57+
for epoch in 1:epochs
58+
gs = gradient(() -> loss(X, y), ps)
59+
Flux.Optimise.update!(opt, ps, gs)
60+
@show(accuracy(X, y))
61+
end

0 commit comments

Comments
 (0)