# Load the packages
# GraphNeuralNetworks: graph type + GCN layers; DiffEqFlux/DifferentialEquations:
# NeuralODE and the Tsit5 solver. JLD2 is not used in the visible code — TODO confirm.
using GraphNeuralNetworks, JLD2, DiffEqFlux, DifferentialEquations
using Flux: onehotbatch, onecold, throttle   # throttle is unused below — TODO confirm
using Flux.Losses: logitcrossentropy
using Statistics: mean
using MLDatasets: Cora
# Everything stays on the CPU for now; `gpu` is not working yet for this example.
device = cpu

# LOAD DATA
# Build the graph, the feature matrix, and one-hot labels from the Cora
# dataset, then move them (and the split indices) to the target device.
data = Cora.dataset()
g = device(GNNGraph(data.adjacency_list))
X = device(data.node_features)
y = device(onehotbatch(data.node_labels, 1:data.num_classes))
train_ids = device(data.train_indices)
val_ids = device(data.val_indices)
test_ids = device(data.test_indices)
ytrain = y[:, train_ids]
| 19 | + |
| 20 | + |
# Model and Data Configuration
nin = size(X, 1)        # input feature dimension (rows of the feature matrix X)
nhidden = 16            # hidden width shared by all GCN layers below
nout = data.num_classes # one output logit per class
epochs = 40             # number of full-batch training iterations
| 26 | + |
# Define the Neural GDE
# Collapse an ODE solution object to a plain matrix: keep only the first two
# dimensions of its array form, dropping the trailing saved-timestep axis.
function diffeqsol_to_array(x)
    return reshape(device(x), size(x)[1:2])
end
| 31 | + |
# ODE dynamics: two graph convolutions tied (via WithGraph) to the fixed graph `g`.
node_chain = device(GNNChain(GCNConv(nhidden => nhidden, relu),
                             GCNConv(nhidden => nhidden, relu)))

# Neural ODE solved over t ∈ [0, 1] with Tsit5; only the final state is kept
# (both save_everystep and save_start are false).
node = device(NeuralODE(WithGraph(node_chain, g),
                        (0.0f0, 1.0f0), Tsit5(),
                        save_everystep = false,
                        reltol = 1e-3, abstol = 1e-3,
                        save_start = false))
| 38 | + |
# Full model: encode the raw features down to the hidden width, run the Neural
# ODE on the graph, collapse the ODE solution back to a matrix, then classify.
model = GNNChain(GCNConv(nin => nhidden, relu),
                 Dropout(0.5),
                 node,
                 diffeqsol_to_array,  # fixed: was `diffeqarray_to_array`, which is undefined
                 Dense(nhidden, nout)) |> device
| 44 | + |
# Loss and metric. Both evaluate the model on the whole graph; `train_ids` and
# `ytrain` are not applied here, so the caller decides which nodes to score.
# NOTE(review): confirm that unmasked full-graph training is intended.
function loss(x, y)
    ŷ = model(g, x)
    return logitcrossentropy(ŷ, y)
end

function accuracy(x, y)
    ŷ = model(g, x)
    return mean(onecold(ŷ) .== onecold(y))
end
| 48 | + |
# Training
## Model Parameters
# NOTE(review): `node` is already a layer inside `model`, so passing `node.p`
# here may register the ODE parameters a second time — confirm against the
# Flux/DiffEqFlux versions in use.
ps = Flux.params(model, node.p);

## Optimizer
opt = ADAM(0.01)

## Training Loop
# Full-batch gradient descent: each epoch takes the gradient of the loss over
# the entire graph w.r.t. `ps`, applies one optimizer step, and prints accuracy.
# NOTE(review): the loss uses all of X/y — the train/val/test splits computed
# earlier are never applied, so accuracy here includes training nodes; confirm
# this is intended.
for epoch in 1:epochs
    gs = gradient(() -> loss(X, y), ps)
    Flux.Optimise.update!(opt, ps, gs)
    @show(accuracy(X, y))
end