Commit 9a5d726

Add traffic prediction tutorial (#324)
* Add `TGCN` and `Flux.Recur`
* Add docstring `TGCN`
* Add comment
* Specify type
* Add test
* First draft example
* Conforme example
* Add pluto notebook
* Add MLdatasets and Plots
* Add traffic tutorial and gif
* Fix identation
* Fix
* Fix other identations
* Add legend
1 parent 2e96319 commit 9a5d726

4 files changed: +472 -0 lines changed

docs/Project.toml

Lines changed: 2 additions & 0 deletions
@@ -5,8 +5,10 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MarkdownLiteral = "736d6165-7244-6769-4267-6b50796e6954"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
 PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

docs/pluto_output/traffic_prediction.md

Lines changed: 224 additions & 0 deletions
Large diffs are not rendered by default.
1.8 MB
Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
### A Pluto.jl notebook ###
# v0.19.26

#> [frontmatter]
#> author = "[Aurora Rossi](https://github.com/aurorarossi)"
#> title = "Traffic Prediction using recurrent Temporal Graph Convolutional Network"
#> date = "2023-08-21"
#> description = "Traffic Prediction using GraphNeuralNetworks.jl"
#> cover = "assets/traffic.gif"

using Markdown
using InteractiveUtils

# ╔═╡ 177a835d-e91d-4f5f-8484-afb2abad400c
# ╠═╡ show_logs = false
begin
    using Pkg
    Pkg.develop("GraphNeuralNetworks")
    Pkg.add("MLDatasets")
    Pkg.add("Plots")
end

# ╔═╡ 1f95ad97-a007-4724-84db-392b0026e1a4
begin
    using GraphNeuralNetworks
    using Flux
    using Flux.Losses: mae
    using MLDatasets: METRLA
    using Statistics
    using Plots
end

# ╔═╡ 5fdab668-4003-11ee-33f5-3953225b0c0f
md"
In this tutorial, we will learn how to use a recurrent Temporal Graph Convolutional Network (TGCN) to predict traffic in a spatio-temporal setting. Traffic forecasting is the problem of predicting future traffic trends on a road network given historical traffic data, such as, in our case, traffic speed and time of day.
"

# ╔═╡ 3dd0ce32-2339-4d5a-9a6f-1f662bc5500b
md"
## Import

We start by importing the necessary libraries. We use `GraphNeuralNetworks.jl`, `Flux.jl` and `MLDatasets.jl`, among others.
"

# ╔═╡ ec5caeb6-1f95-4cb9-8739-8cadba29a22d
md"
## Dataset: METR-LA

We use the `METR-LA` dataset from the paper [Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926.pdf), which contains traffic data from loop detectors on the highways of Los Angeles County. The dataset contains traffic speed data from March 1, 2012 to June 30, 2012. The data is collected every 5 minutes, resulting in 12 observations per hour, from 207 sensors. Each sensor is a node in the graph, and the edges represent the distances between the sensors.
"

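The sampling arithmetic in the dataset description can be spelled out as follows. This is a small sketch: the variable names are illustrative and are not part of `MLDatasets.jl`. It shows why one day of data corresponds to 288 samples, which is the size of the slices plotted later in the notebook.

```julia
# Sampling arithmetic for 5-minute loop-detector readings.
minutes_per_sample = 5
samples_per_hour = 60 ÷ minutes_per_sample   # 12 observations per hour
samples_per_day = 24 * samples_per_hour      # 288 observations per day

# Sample offsets (counted from 0) that fall on each 4-hour mark of one day.
four_hour_marks = 0:(4 * samples_per_hour):samples_per_day
```

Under these assumptions, a 4-hour interval spans 48 samples, so a day has seven 4-hour marks from hour 0 to hour 24.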
# ╔═╡ f531e39c-6842-494a-b4ac-8904321098c9
dataset_metrla = METRLA(; num_timesteps = 3)

# ╔═╡ d5ebf9aa-cec8-4417-baaf-f2e8e19f1cad
g = dataset_metrla[1]

# ╔═╡ dc2d5e98-2201-4754-bfc6-8ed2bbb82153
md"
`edge_data` contains the weights of the edges of the graph and
`node_data` contains a node feature vector and a target vector. Each of these vectors holds batches of length `num_timesteps`, that is, arrays with the node features and targets of `num_timesteps` consecutive time steps. Two consecutive batches are shifted by one time step.
The node features are the traffic speed of the sensors and the time of day, and the targets are the traffic speed of the sensors in the next time step.
Let's see some examples:
"

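The batching described above can be illustrated on synthetic data. The sketch below is one plausible reading of that description, not the `MLDatasets.jl` implementation: the array `series` and the exact alignment of targets (the speed channel shifted by one step) are assumptions made for illustration.

```julia
# Synthetic stand-in for the raw series: 2 features × 4 nodes × 10 time steps.
num_features, num_nodes, T = 2, 4, 10
series = reshape(collect(1.0:(num_features * num_nodes * T)), num_features, num_nodes, T)

num_timesteps = 3

# Window i covers steps i .. i+num_timesteps-1.
# Its target is the speed channel (feature 1) of the same window shifted one step ahead.
features = [series[:, :, i:(i + num_timesteps - 1)] for i in 1:(T - num_timesteps)]
targets = [series[1:1, :, (i + 1):(i + num_timesteps)] for i in 1:(T - num_timesteps)]

# Consecutive windows overlap: the second step of window 1 is the first step of window 2.
overlap_ok = features[2][:, :, 1] == features[1][:, :, 2]
```

With this layout each feature batch has size `(2, num_nodes, num_timesteps)` and each target batch has size `(1, num_nodes, num_timesteps)`, which matches the shapes inspected in the following cells.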
# ╔═╡ 0dde5fd3-72d0-4b15-afb3-9a5b102327c9
size(g.node_data.features[1])

# ╔═╡ f7a6d572-28cf-4d69-a9be-d49f367eca37
md"
The first dimension corresponds to the two features (the first row is the speed value and the second row is the time of day), the second to the nodes, and the third to the `num_timesteps` time steps.
"

# ╔═╡ 3d5503bc-bb97-422e-9465-becc7d3dbe07
size(g.node_data.targets[1])

# ╔═╡ 3569715d-08f5-4605-b946-9ef7ccd86ae5
md"
For the targets, the first dimension is 1 because they store only the speed value.
"

# ╔═╡ aa4eb172-2a42-4c01-a6ef-c6c95208d5b2
g.node_data.features[1][:,1,:]

# ╔═╡ 367ed417-4f53-44d4-8135-0c91c842a75f
g.node_data.features[2][:,1,:]

# ╔═╡ 7c084eaa-655c-4251-a342-6b6f4df76ddb
g.node_data.targets[1][:,1,:]

# ╔═╡ bf0d820d-32c0-4731-8053-53d5d499e009
function plot_data(data, sensor)
    p = plot(legend=false, xlabel="Time (h)", ylabel="Normalized speed")
    plotdata = []
    # Batches overlap, so take every third one to cover each time step once (num_timesteps = 3)
    for i in 1:3:length(data)
        push!(plotdata, data[i][1,sensor,:])
    end
    plotdata = reduce(vcat, plotdata)
    # 12 samples per hour, so a tick every 4 hours is a tick every 48 samples
    plot!(p, collect(1:length(data)), plotdata, color = :green, xticks = (0:48:288, ["$(i)" for i in 0:4:24]))
    return p
end

# ╔═╡ cb89d1a3-b4ff-421a-8717-a0b7f21dea1a
plot_data(g.node_data.features[1:288],1)

# ╔═╡ 3b49a612-3a04-4eb5-bfbc-360614f4581a
md"
Now let's construct the static graph, the temporal features and targets from the dataset.
"

# ╔═╡ 95d8bd24-a40d-409f-a1e7-4174428ef860
begin
    graph = GNNGraph(g.edge_index; edata = g.edge_data, g.num_nodes)
    features = g.node_data.features
    targets = g.node_data.targets
end;

# ╔═╡ fde2ac9e-b121-4105-8428-1820b9c17a43
md"
Now let's construct the `train_loader` and `test_loader`.
"


# ╔═╡ 111b7d5d-c7e3-44c0-9e5e-2ed1a86854d3
begin
    train_loader = zip(features[1:200], targets[1:200])
    test_loader = zip(features[2001:2288], targets[2001:2288])
end;

# ╔═╡ 572a6633-875b-4d7e-9afc-543b442948fb
md"
## Model: T-GCN

We use the T-GCN model from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf), which consists of a graph convolutional network (GCN) and a gated recurrent unit (GRU). The GCN is used to capture spatial features from the graph, and the GRU is used to capture temporal features from the feature time series.
"

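To make the GCN-plus-GRU combination concrete, here is a minimal dense-matrix sketch of one recurrent step in plain Julia. It is not the `TGCN` layer from `GraphNeuralNetworks.jl`: the helper names (`normalized_adjacency`, `tgcn_step`, `run_sequence`), the parameter tuple `p`, and all sizes are hypothetical, and the gate layout follows the standard GRU form applied to graph-convolved inputs.

```julia
using LinearAlgebra

sigmoid(x) = 1 / (1 + exp(-x))

# Symmetrically normalized adjacency with self-loops: D^(-1/2) (A + I) D^(-1/2).
function normalized_adjacency(A)
    A_tilde = A + I
    d = vec(sum(A_tilde, dims = 2))
    D_inv_sqrt = Diagonal(1 ./ sqrt.(d))
    return D_inv_sqrt * A_tilde * D_inv_sqrt
end

# One recurrent step. X is (features × nodes), h is (hidden × nodes).
function tgcn_step(p, A_hat, X, h)
    s = p.Wg * X * A_hat                        # spatial part: one graph convolution
    u = sigmoid.(p.Wu * vcat(s, h) .+ p.bu)     # update gate
    r = sigmoid.(p.Wr * vcat(s, h) .+ p.br)     # reset gate
    c = tanh.(p.Wc * vcat(s, r .* h) .+ p.bc)   # candidate state
    return u .* h .+ (1 .- u) .* c              # temporal part: GRU-style blend
end

# Apply the step to each input in turn, carrying the hidden state forward.
function run_sequence(p, A_hat, xs, h)
    for X in xs
        h = tgcn_step(p, A_hat, X, h)
    end
    return h
end

# Hypothetical sizes and random parameters, for illustration only.
n_feat, n_conv, n_hidden, n_nodes = 2, 4, 5, 3
p = (Wg = randn(n_conv, n_feat),
     Wu = randn(n_hidden, n_conv + n_hidden), bu = zeros(n_hidden),
     Wr = randn(n_hidden, n_conv + n_hidden), br = zeros(n_hidden),
     Wc = randn(n_hidden, n_conv + n_hidden), bc = zeros(n_hidden))

A = [0 1 0; 1 0 1; 0 1 0]                       # path graph on three nodes
A_hat = normalized_adjacency(A)
xs = [randn(n_feat, n_nodes) for _ in 1:3]      # three time steps of node features
h = run_sequence(p, A_hat, xs, zeros(n_hidden, n_nodes))
```

The graph convolution mixes each node's features with those of its neighbours before the gates see them, and the hidden state `h` carries information across time steps. Starting from a zero state, every entry of `h` stays within [-1, 1] because each step blends the previous state with a `tanh` output.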
# ╔═╡ 5502f4fa-3201-4980-b766-2ab88b175b11
model = GNNChain(TGCN(2 => 100), Dense(100, 1))

# ╔═╡ 4a1ec34a-1092-4b4a-b8a8-bd91939ffd9e
md"
![](https://www.researchgate.net/profile/Haifeng-Li-3/publication/335353434/figure/fig4/AS:851870352437249@1580113127759/The-architecture-of-the-Gated-Recurrent-Unit-model.jpg)
"

# ╔═╡ 755a88c2-c2e5-46d1-9582-af4b2c5a6bbd
md"
## Training

We train the model for 100 epochs, using the Adam optimizer with a learning rate of 0.001. We use the mean absolute error (MAE) as the loss function.
"

# ╔═╡ e83253b2-9f3a-44e2-a747-cce1661657c4
function train(graph, train_loader, model)

    opt = Flux.setup(Adam(0.001), model)

    for epoch in 1:100
        for (x, y) in train_loader
            x, y = (x, y)
            grads = Flux.gradient(model) do model
                ŷ = model(graph, x)
                Flux.mae(ŷ, y)
            end
            Flux.update!(opt, model, grads[1])
        end

        # Report the mean training loss every 10 epochs
        if epoch % 10 == 0
            loss = mean([Flux.mae(model(graph,x), y) for (x, y) in train_loader])
            @show epoch, loss
        end
    end
    return model
end

# ╔═╡ 85a923da-3027-4f71-8db6-96852c115c03
train(graph, train_loader, model)

# ╔═╡ 39c82234-97ea-48d6-98dd-915f072b7f85
179+
function plot_predicted_data(graph,features,targets, sensor)
180+
p = plot(xlabel="Time (h)", ylabel="Normalized speed")
181+
prediction = []
182+
grand_truth = []
183+
for i in 1:3:length(features)
184+
push!(grand_truth,targets[i][1,sensor,:])
185+
push!(prediction, model(graph, features[i])[1,sensor,:])
186+
end
187+
prediction = reduce(vcat,prediction)
188+
grand_truth = reduce(vcat, grand_truth)
189+
plot!(p, collect(1:length(features)), grand_truth, color = :blue, label = "Grand Truth", xticks =([i for i in 0:50:250], ["$(i)" for i in 0:4:24]))
190+
plot!(p, collect(1:length(features)), prediction, color = :red, label= "Prediction")
191+
return p
192+
end
193+
194+
# ╔═╡ 8c3a903b-2c8a-4d4f-8eef-74d5611f2ce4
195+
plot_predicted_data(graph,features[301:588],targets[301:588], 1)
196+
# ╔═╡ 2c5f6250-ee7a-41b1-9551-bcfeba83ca8b
accuracy(ŷ, y) = 1 - Statistics.norm(y-ŷ)/Statistics.norm(y)

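The accuracy metric above is one minus the relative error in the 2-norm (the Frobenius norm for arrays): it equals 1 for a perfect prediction and 0 for an all-zero prediction. The small worked example below uses made-up vectors and imports `norm` from `LinearAlgebra` directly, which is an assumption of this sketch rather than what the notebook cell does.

```julia
using LinearAlgebra: norm

# Same formula as the notebook's metric, with `norm` taken from LinearAlgebra.
accuracy(ŷ, y) = 1 - norm(y - ŷ) / norm(y)

y = [3.0, 4.0]                      # target with norm 5
perfect = accuracy(y, y)            # no error at all
trivial = accuracy(zeros(2), y)     # error norm equals the target norm
close = accuracy([3.0, 3.0], y)     # error norm 1 against target norm 5
```

Here `perfect` is 1, `trivial` is 0, and `close` is about 0.8. The value can become negative when the error norm exceeds the target norm, so it is not bounded below like a classification accuracy.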
# ╔═╡ 1008dad4-d784-4c38-a7cf-d9b64728e28d
mean([accuracy(model(graph,x), y) for (x, y) in test_loader])

# ╔═╡ 8d0e8b9f-226f-4bff-9deb-046e6a897b71
md"The accuracy is not very good, but it can be improved by training on more data. We used a small subset of the dataset for this tutorial because of the computational cost of training the model. From the plot of the predictions, we can see that the model captures the general trend of the traffic speed but not its peaks."

# ╔═╡ a7e4bb23-6687-476a-a0c2-1b2736873d9d
md"
## Conclusion

In this tutorial, we learned how to use a recurrent temporal graph convolutional network to predict traffic in a spatio-temporal setting. We used the TGCN model, which consists of a graph convolutional network (GCN) and a gated recurrent unit (GRU). We then trained the model for 100 epochs on a small subset of the METR-LA dataset. The accuracy of the model is not very good, but it can be improved by training on more data.
"

# ╔═╡ Cell order:
# ╟─5fdab668-4003-11ee-33f5-3953225b0c0f
# ╠═177a835d-e91d-4f5f-8484-afb2abad400c
# ╟─3dd0ce32-2339-4d5a-9a6f-1f662bc5500b
# ╠═1f95ad97-a007-4724-84db-392b0026e1a4
# ╟─ec5caeb6-1f95-4cb9-8739-8cadba29a22d
# ╠═f531e39c-6842-494a-b4ac-8904321098c9
# ╠═d5ebf9aa-cec8-4417-baaf-f2e8e19f1cad
# ╟─dc2d5e98-2201-4754-bfc6-8ed2bbb82153
# ╠═0dde5fd3-72d0-4b15-afb3-9a5b102327c9
# ╟─f7a6d572-28cf-4d69-a9be-d49f367eca37
# ╠═3d5503bc-bb97-422e-9465-becc7d3dbe07
# ╟─3569715d-08f5-4605-b946-9ef7ccd86ae5
# ╠═aa4eb172-2a42-4c01-a6ef-c6c95208d5b2
# ╠═367ed417-4f53-44d4-8135-0c91c842a75f
# ╠═7c084eaa-655c-4251-a342-6b6f4df76ddb
# ╠═bf0d820d-32c0-4731-8053-53d5d499e009
# ╠═cb89d1a3-b4ff-421a-8717-a0b7f21dea1a
# ╟─3b49a612-3a04-4eb5-bfbc-360614f4581a
# ╠═95d8bd24-a40d-409f-a1e7-4174428ef860
# ╟─fde2ac9e-b121-4105-8428-1820b9c17a43
# ╠═111b7d5d-c7e3-44c0-9e5e-2ed1a86854d3
# ╟─572a6633-875b-4d7e-9afc-543b442948fb
# ╠═5502f4fa-3201-4980-b766-2ab88b175b11
# ╟─4a1ec34a-1092-4b4a-b8a8-bd91939ffd9e
# ╟─755a88c2-c2e5-46d1-9582-af4b2c5a6bbd
# ╠═e83253b2-9f3a-44e2-a747-cce1661657c4
# ╠═85a923da-3027-4f71-8db6-96852c115c03
# ╠═39c82234-97ea-48d6-98dd-915f072b7f85
# ╠═8c3a903b-2c8a-4d4f-8eef-74d5611f2ce4
# ╠═2c5f6250-ee7a-41b1-9551-bcfeba83ca8b
# ╠═1008dad4-d784-4c38-a7cf-d9b64728e28d
# ╟─8d0e8b9f-226f-4bff-9deb-046e6a897b71
# ╟─a7e4bb23-6687-476a-a0c2-1b2736873d9d
