Skip to content

Commit cac775d

Browse files
only cpu working
1 parent 6a5671e commit cac775d

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

examples/neural_ode.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Flux.Losses: logitcrossentropy
55
using Statistics: mean
66
using MLDatasets: Cora
77

8-
device = cpu
8+
device = cpu # `gpu` not working yet
99

1010
# LOAD DATA
1111
data = Cora.dataset()
@@ -15,31 +15,32 @@ y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
1515
train_ids = data.train_indices |> device
1616
val_ids = data.val_indices |> device
1717
test_ids = data.test_indices |> device
18-
ytrain = y[:,train_ids]
18+
ytrain = y[:, train_ids]
1919

2020

2121
# Model and Data Configuration
22-
nin = size(X,1)
22+
nin = size(X, 1)
2323
nhidden = 16
2424
nout = data.num_classes
2525
epochs = 40
2626

2727
# Define the Neural GDE
28-
diffeqarray_to_array(X) = reshape(cpu(X), size(X)[1:2])
28+
diffeqsol_to_array(x) = reshape(device(x), size(x)[1:2])
2929

3030
# GCNConv(nhidden => nhidden, graph=g),
3131

32-
node = NeuralODE(
33-
WithGraph(GCNConv(nhidden => nhidden), g),
34-
(0.f0, 1.f0), Tsit5(), save_everystep = false,
35-
reltol = 1e-3, abstol = 1e-3, save_start = false
36-
)
32+
node_chain = GNNChain(GCNConv(nhidden => nhidden, relu),
33+
GCNConv(nhidden => nhidden, relu)) |> device
34+
35+
node = NeuralODE(WithGraph(node_chain, g),
36+
(0.f0, 1.f0), Tsit5(), save_everystep = false,
37+
reltol = 1e-3, abstol = 1e-3, save_start = false) |> device
3738

3839
model = GNNChain(GCNConv(nin => nhidden, relu),
3940
Dropout(0.5),
4041
node,
4142
diffeqarray_to_array,
42-
GCNConv(nhidden => nout))
43+
Dense(nhidden, nout)) |> device
4344

4445
# Loss
4546
loss(x, y) = logitcrossentropy(model(g, x), y)

0 commit comments

Comments
 (0)