Skip to content

Commit f07b028

Browse files
add neuralode example
1 parent e1114a2 commit f07b028

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

examples/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
[deps]
22
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
4+
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
35
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
46
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
57
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
68
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
9+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
710
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
811
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"

examples/neural_ode.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Load the packages
2+
using GraphNeuralNetworks, JLD2, DiffEqFlux, DifferentialEquations
3+
using Flux: onehotbatch, onecold, throttle
4+
using Flux.Losses: logitcrossentropy
5+
using Statistics: mean
6+
using MLDatasets: Cora
7+
8+
device = cpu
9+
10+
# LOAD DATA
11+
data = Cora.dataset()
12+
g = GNNGraph(data.adjacency_list) |> device
13+
X = data.node_features |> device
14+
y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
15+
train_ids = data.train_indices |> device
16+
val_ids = data.val_indices |> device
17+
test_ids = data.test_indices |> device
18+
ytrain = y[:,train_ids]
19+
20+
21+
# Model and Data Configuration
22+
nin = size(X,1)
23+
nhidden = 16
24+
nout = data.num_classes
25+
epochs = 40
26+
27+
# Define the Neural GDE
28+
diffeqarray_to_array(X) = reshape(cpu(X), size(X)[1:2])
29+
30+
# GCNConv(nhidden => nhidden, graph=g),
31+
32+
node = NeuralODE(
33+
WithGraph(GCNConv(nhidden => nhidden), g),
34+
(0.f0, 1.f0), Tsit5(), save_everystep = false,
35+
reltol = 1e-3, abstol = 1e-3, save_start = false
36+
)
37+
38+
model = GNNChain(GCNConv(nin => nhidden, relu),
39+
Dropout(0.5),
40+
node,
41+
diffeqarray_to_array,
42+
GCNConv(nhidden => nout))
43+
44+
# Loss
45+
loss(x, y) = logitcrossentropy(model(g, x), y)
46+
accuracy(x, y) = mean(onecold(model(g, x)) .== onecold(y))
47+
48+
# Training
49+
## Model Parameters
50+
ps = Flux.params(model, node.p);
51+
52+
## Optimizer
53+
opt = ADAM(0.01)
54+
55+
## Training Loop
56+
for epoch in 1:epochs
57+
gs = gradient(() -> loss(X, y), ps)
58+
Flux.Optimisers.update!(opt, ps, gs)
59+
@show(accuracy(X, y))
60+
end

0 commit comments

Comments
 (0)