|
| 1 | +# Load the packages |
| 2 | +using GraphNeuralNetworks, JLD2, DiffEqFlux, DifferentialEquations |
| 3 | +using Flux: onehotbatch, onecold, throttle |
| 4 | +using Flux.Losses: logitcrossentropy |
| 5 | +using Statistics: mean |
| 6 | +using MLDatasets: Cora |
| 7 | + |
# Select the execution device; `cpu` keeps everything on the CPU.
# Swap to `gpu` (requires a CUDA-enabled Flux setup) to run on GPU.
device = cpu

# LOAD DATA
# NOTE(review): uses the legacy MLDatasets API (`Cora.dataset()` with
# `adjacency_list` / `node_features` / `*_indices` fields); newer MLDatasets
# versions expose `Cora()` returning a GNNGraph-compatible object instead —
# confirm against the installed MLDatasets version.
data = Cora.dataset()
g = GNNGraph(data.adjacency_list) |> device   # citation-graph structure
X = data.node_features |> device              # node feature matrix (features × nodes)
y = onehotbatch(data.node_labels, 1:data.num_classes) |> device  # one-hot class labels
train_ids = data.train_indices |> device      # node indices of the train split
val_ids = data.val_indices |> device          # node indices of the validation split
test_ids = data.test_indices |> device        # node indices of the test split
ytrain = y[:,train_ids]                       # labels restricted to the train split
| 20 | + |
# Model and Data Configuration
nin = size(X,1)          # input feature dimension (rows of the feature matrix)
nhidden = 16             # hidden width shared by the GCN layers
nout = data.num_classes  # number of target classes (output logits)
epochs = 40              # number of full-graph training iterations

| 27 | +# Define the Neural GDE |
# Collapse a NeuralODE solution object back into a plain 2-d matrix on the
# CPU so the next GNN layer can consume it. With `save_everystep=false` and
# `save_start=false` the solution holds exactly one time slice, so the
# trailing (time) dimension is dropped by the reshape.
function diffeqarray_to_array(u)
    v = cpu(u)
    return reshape(v, size(v, 1), size(v, 2))
end
# Neural graph ODE: the ODE dynamics are a single graph convolution bound to
# the fixed graph `g` via `WithGraph`, integrated over t ∈ [0, 1] with Tsit5.
# Only the final state is kept (`save_everystep=false`, `save_start=false`),
# so downstream code sees a single time slice.
node = NeuralODE(
    WithGraph(GCNConv(nhidden => nhidden), g),
    (0.f0, 1.f0), Tsit5(), save_everystep = false,
    reltol = 1e-3, abstol = 1e-3, save_start = false
)
| 37 | + |
# Full model: encode → dropout → continuous-depth GNN block → decode.
model = GNNChain(GCNConv(nin => nhidden, relu),  # encoder: nin → nhidden with ReLU
                 Dropout(0.5),                   # regularization (active in train mode)
                 node,                           # NeuralODE over graph convolutions
                 diffeqarray_to_array,           # ODE solution object → plain matrix
                 GCNConv(nhidden => nout))       # decoder: per-node class logits
| 43 | + |
# Loss and evaluation metric. Both close over the globals `model` and `g`
# defined above; `x` is the node-feature matrix, `y` the one-hot labels.
function loss(x, y)
    logits = model(g, x)
    return logitcrossentropy(logits, y)
end

function accuracy(x, y)
    predicted = onecold(model(g, x))
    return mean(predicted .== onecold(y))
end
| 47 | + |
# Training
## Model Parameters
# Implicit parameter set: the GNNChain's trainable weights plus the
# NeuralODE's own parameter vector (DiffEqFlux keeps the ODE-function
# parameters in `node.p`) — presumably so both train jointly; verify
# `node.p` exists in the installed DiffEqFlux version.
ps = Flux.params(model, node.p);

## Optimizer
opt = ADAM(0.01)  # Flux's ADAM with learning rate 0.01
| 54 | + |
## Training Loop
for epoch in 1:epochs
    # Gradient of the loss w.r.t. the implicit parameter set `ps`
    # (model weights plus the NeuralODE parameters).
    gs = gradient(() -> loss(X, y), ps)
    # BUG FIX: an implicit-params gradient (`Flux.params` / `Zygote.Grads`)
    # must be applied with `Flux.update!` (the Flux.Optimise API), not
    # `Flux.Optimisers.update!` — Optimisers.jl's explicit-state API expects
    # a state tree from `Optimisers.setup` and a gradient tree, and would
    # throw a MethodError for `(ADAM, Params, Grads)`.
    Flux.update!(opt, ps, gs)
    # NOTE(review): trains and evaluates on ALL nodes; `train_ids`/`ytrain`
    # are defined above but never used — confirm whether training should be
    # restricted to the train split as in the standard Cora setup.
    @show(accuracy(X, y))
end
0 commit comments