@@ -5,7 +5,7 @@ using Flux.Losses: logitcrossentropy
5
5
using Statistics: mean
6
6
using MLDatasets: Cora
7
7
8
- device = cpu
8
+ device = cpu # `gpu` not working yet
9
9
10
10
# LOAD DATA
11
11
data = Cora. dataset ()
@@ -15,31 +15,32 @@ y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
15
15
train_ids = data. train_indices |> device
16
16
val_ids = data. val_indices |> device
17
17
test_ids = data. test_indices |> device
18
- ytrain = y[:,train_ids]
18
+ ytrain = y[:, train_ids]
19
19
20
20
21
21
# Model and Data Configuration
22
- nin = size (X,1 )
22
+ nin = size (X, 1 )
23
23
nhidden = 16
24
24
nout = data. num_classes
25
25
epochs = 40
26
26
27
27
# Define the Neural GDE
28
- diffeqarray_to_array (X ) = reshape (cpu (X ), size (X )[1 : 2 ])
28
+ diffeqsol_to_array (x ) = reshape (device (x ), size (x )[1 : 2 ])
29
29
30
30
# GCNConv(nhidden => nhidden, graph=g),
31
31
32
- node = NeuralODE (
33
- WithGraph (GCNConv (nhidden => nhidden), g),
34
- (0.f0 , 1.f0 ), Tsit5 (), save_everystep = false ,
35
- reltol = 1e-3 , abstol = 1e-3 , save_start = false
36
- )
32
+ node_chain = GNNChain (GCNConv (nhidden => nhidden, relu),
33
+ GCNConv (nhidden => nhidden, relu)) |> device
34
+
35
+ node = NeuralODE (WithGraph (node_chain, g),
36
+ (0.f0 , 1.f0 ), Tsit5 (), save_everystep = false ,
37
+ reltol = 1e-3 , abstol = 1e-3 , save_start = false ) |> device
37
38
38
39
model = GNNChain (GCNConv (nin => nhidden, relu),
39
40
Dropout (0.5 ),
40
41
node,
41
42
diffeqarray_to_array,
42
- GCNConv (nhidden => nout))
43
+ Dense (nhidden, nout)) |> device
43
44
44
45
# Loss
45
46
loss (x, y) = logitcrossentropy (model (g, x), y)
0 commit comments