Skip to content

Commit e8f39ac

Browse files
authored
Add traffic prediction example (#321)
* Add `TGCN` and `Flux.Recur` * Add docstring `TGCN` * Add comment * Specify type * Add test * First draft example * Conforme example * Update * Fix * Fix description and getdataset * Update examples/traffic_prediction.jl
1 parent 23c429e commit e8f39ac

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

examples/traffic_prediction.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Example of using TGCN, a recurrent temporal graph convolutional network of the paper https://arxiv.org/pdf/1811.05320.pdf, for traffic prediction by training it on the METRLA dataset
2+
3+
# Load packages
4+
using Flux
5+
using Flux.Losses: mae
6+
using GraphNeuralNetworks
7+
using MLDatasets: METRLA
8+
using CUDA
9+
using Statistics, Random
10+
CUDA.allowscalar(false)
11+
12+
# Import dataset function
13+
function getdataset()
14+
metrla = METRLA(; num_timesteps = 3)
15+
g = metrla[1]
16+
graph = GNNGraph(g.edge_index; edata = g.edge_data, g.num_nodes)
17+
features = g.node_data.features
18+
targets = g.node_data.targets
19+
train_loader = zip(features[1:2000], targets[1:2000])
20+
test_loader = zip(features[2001:2288], targets[2001:2288])
21+
return graph, train_loader, test_loader
22+
end
23+
24+
# Loss and accuracy functions
25+
lossfunction(ŷ, y) = Flux.mae(ŷ, y)
26+
accuracy(ŷ, y) = 1 - Statistics.norm(y-ŷ)/Statistics.norm(y)
27+
28+
function eval_loss_accuracy(model, graph, data_loader)
29+
error = mean([lossfunction(model(graph,x), y) for (x, y) in data_loader])
30+
acc = mean([accuracy(model(graph,x), y) for (x, y) in data_loader])
31+
return (loss = round(error, digits = 4), acc = round(acc , digits = 4))
32+
end
33+
34+
# Arguments for the train function
35+
Base.@kwdef mutable struct Args
36+
η = 1.0f-3 # learning rate
37+
epochs = 100 # number of epochs
38+
seed = 17 # set seed > 0 for reproducibility
39+
usecuda = true # if true use cuda (if available)
40+
nhidden = 100 # dimension of hidden features
41+
infotime = 20 # report every `infotime` epochs
42+
end
43+
44+
# Train function
45+
function train(; kws...)
46+
args = Args(; kws...)
47+
args.seed > 0 && Random.seed!(args.seed)
48+
49+
if args.usecuda && CUDA.functional()
50+
device = gpu
51+
args.seed > 0 && CUDA.seed!(args.seed)
52+
@info "Training on GPU"
53+
else
54+
device = cpu
55+
@info "Training on CPU"
56+
end
57+
58+
# Define model
59+
model = GNNChain(TGCN(2 => args.nhidden), Dense(args.nhidden, 1)) |> device
60+
61+
opt = Flux.setup(Adam(args.η), model)
62+
63+
graph, train_loader, test_loader = getdataset()
64+
graph = graph |> device
65+
train_loader = train_loader |> device
66+
test_loader = test_loader |> device
67+
68+
function report(epoch)
69+
train_loss, train_acc = eval_loss_accuracy(model, graph, train_loader)
70+
test_loss, test_acc = eval_loss_accuracy(model, graph, test_loader)
71+
println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))")
72+
end
73+
74+
report(0)
75+
for epoch in 1:(args.epochs)
76+
for (x, y) in train_loader
77+
x, y = (x, y)
78+
grads = Flux.gradient(model) do model
79+
= model(graph, x)
80+
lossfunction(y,ŷ)
81+
end
82+
Flux.update!(opt, model, grads[1])
83+
end
84+
85+
args.infotime > 0 && epoch % args.infotime == 0 && report(epoch)
86+
87+
end
88+
return model
89+
end
90+
91+
train()
92+

0 commit comments

Comments
 (0)