Skip to content

Commit 19771ab

Browse files
authored
Fix traffic datasets (#216)
* Change datasets database and improvement * Add new tests and fix tasks generation * Fix
1 parent f33ebca commit 19771ab

File tree

4 files changed

+41
-49
lines changed

4 files changed

+41
-49
lines changed

src/datasets/graphs/metrla.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
function __init__metrla()
2-
DEPNAME = "METR-LA"
3-
LINK = "https://graphmining.ai/temporal_datasets/"
2+
DEPNAME = "METRLA"
3+
LINK = "http://www-sop.inria.fr/members/Aurora.Rossi/index.html"
44
register(ManualDataDep(DEPNAME,
55
"""
66
Dataset: $DEPNAME
@@ -9,7 +9,7 @@ function __init__metrla()
99
end
1010

1111
"""
12-
METRLA(; num_timesteps_in::Int = 12, num_timesteps_out::Int=12, dir=nothing)
12+
METRLA(; num_timesteps::Int = 12, dir=nothing, normalize = true)
1313
1414
The METR-LA dataset from the [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926) paper.
1515
@@ -19,22 +19,22 @@ The edge weights `w` are contained as a feature array in `edge_data` and represe
1919
2020
The node features are the traffic speed and the time of the measurements collected by the sensors, divided into windows of `num_timesteps` time steps.
2121
22-
The target values are the traffic speed and the time of the measurements collected by the sensors, divided into `num_timesteps_out` time steps.
22+
The target values are the traffic speeds measured by the sensors, divided into windows of `num_timesteps` time steps, each shifted one time step ahead of the corresponding feature window.
23+
24+
The `normalize` flag indicates whether the data are normalized using Z-score normalization.
2325
"""
2426
struct METRLA <: AbstractDataset
2527
graphs::Vector{Graph}
2628
end
2729

28-
function METRLA(;num_timesteps_in::Int = 12, num_timesteps_out::Int=12, dir = nothing)
29-
s, t, w, x, y = processed_traffic("METR-LA", num_timesteps_in, num_timesteps_out, dir)
30+
function METRLA(;num_timesteps::Int = 12, dir = nothing, normalize = true)
31+
s, t, w, x, y = processed_traffic("METRLA", num_timesteps, dir, normalize)
3032

3133
g = Graph(; num_nodes = 207,
3234
edge_index = (s, t),
3335
edge_data = w,
34-
node_data = (features = x, targets = y)
35-
36-
37-
)
36+
node_data = (features = x, targets = y))
37+
3838
return METRLA([g])
3939
end
4040

src/datasets/graphs/pemsbay.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
function __init__pemsbay()
2-
DEPNAME = "PEMS-BAY"
3-
LINK = "https://graphmining.ai/temporal_datasets/"
2+
DEPNAME = "PEMSBAY"
3+
LINK = "http://www-sop.inria.fr/members/Aurora.Rossi/index.html"
44
register(ManualDataDep(DEPNAME,
55
"""
66
Dataset: $DEPNAME
@@ -9,7 +9,7 @@ function __init__pemsbay()
99
end
1010

1111
"""
12-
PEMSBAY(; num_timesteps_in::Int = 12, num_timesteps_out::Int=12, dir=nothing)
12+
PEMSBAY(; num_timesteps::Int = 12, dir=nothing, normalize = true)
1313
1414
The PEMS-BAY dataset described in the [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926) paper.
1515
It is collected by California Transportation Agencies (Cal-
@@ -21,22 +21,22 @@ The edge weights `w` are contained as a feature array in `edge_data` and represe
2121
2222
The node features are the traffic speed and the time of the measurements collected by the sensors, divided into windows of `num_timesteps` time steps.
2323
24-
The target values are the traffic speed and the time of the measurements collected by the sensors, divided into `num_timesteps_out` time steps.
24+
The target values are the traffic speeds measured by the sensors, divided into windows of `num_timesteps` time steps, each shifted one time step ahead of the corresponding feature window.
25+
26+
The `normalize` flag indicates whether the data are normalized using Z-score normalization.
2527
"""
2628
struct PEMSBAY <: AbstractDataset
2729
graphs::Vector{Graph}
2830
end
2931

30-
function PEMSBAY(;num_timesteps_in::Int = 12, num_timesteps_out::Int=12, dir = nothing)
31-
s, t, w, x, y = processed_traffic("PEMS-BAY", num_timesteps_in, num_timesteps_out, dir)
32+
function PEMSBAY(;num_timesteps::Int = 12, dir = nothing, normalize = true)
33+
s, t, w, x, y = processed_traffic("PEMSBAY", num_timesteps, dir, normalize)
3234

3335
g = Graph(; num_nodes = 325,
3436
edge_index = (s, t),
3537
edge_data = w,
36-
node_data = (features = x, targets = y)
37-
38-
39-
)
38+
node_data = (features = x, targets = y))
39+
4040
return PEMSBAY([g])
4141
end
4242

src/datasets/graphs/traffic.jl

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
function traffic_datadir(dname ::String, dir = nothing)
2-
if dname == "PEMS-BAY"
3-
dir = isnothing(dir) ? datadep"PEMS-BAY" : dir
4-
elseif dname == "METR-LA"
5-
dir = isnothing(dir) ? datadep"METR-LA" : dir
2+
if dname == "PEMSBAY"
3+
dir = isnothing(dir) ? datadep"PEMSBAY" : dir
4+
elseif dname == "METRLA"
5+
dir = isnothing(dir) ? datadep"METRLA" : dir
66
end
7-
LINK = "https://graphmining.ai/temporal_datasets/$dname.zip"
7+
LINK = "http://www-sop.inria.fr/members/Aurora.Rossi/data/$dname.zip"
88
if length(readdir((dir))) == 0
99
DataDeps.fetch_default(LINK, dir)
1010
currdir = pwd()
@@ -18,39 +18,33 @@ function traffic_datadir(dname ::String, dir = nothing)
1818
end
1919

2020
function read_traffic(d::String, dname::String)
21-
if dname == "PEMS-BAY"
22-
s="pems_"
23-
elseif dname == "METR-LA"
24-
s=""
25-
end
26-
27-
adj_matrix = NPZ.npzread(joinpath(d, "$(s)adj_mat.npy"))
28-
node_features = NPZ.npzread(joinpath(d, "$(s)node_values.npy"))
29-
21+
adj_matrix = load(joinpath(d, "$(dname).jld2"), "adj_matrix")
22+
node_features = load(joinpath(d, "$(dname).jld2"), "node_features")
23+
node_features = permutedims(node_features,(2,3,1))
3024
return adj_matrix, node_features
3125
end
3226

33-
function traffic_generate_task(node_values::AbstractArray, num_timesteps_in::Int, num_timesteps_out::Int)
34-
indices = [(i, i + num_timesteps_in + num_timesteps_out) for i in 1:(size(node_values,1) - num_timesteps_in - num_timesteps_out)]
27+
function traffic_generate_task(node_values::AbstractArray, num_timesteps::Int)
3528
features = []
3629
targets = []
37-
for (i,j) in indices
38-
push!(features, node_values[i:i+num_timesteps_in-1,:,:])
39-
push!(targets, reshape(node_values[i+num_timesteps_in:j-1,1,:], (num_timesteps_out, 1, size(node_values, 3))))
30+
for i in 1:size(node_values,3)-num_timesteps
31+
push!(features, node_values[:,:,i:i+num_timesteps-1])
32+
push!(targets, reshape(node_values[1,:,i+1:i+num_timesteps], (1, size(node_values, 2),num_timesteps)))
4033
end
4134
return features, targets
4235
end
4336

44-
function processed_traffic(dname::String, num_timesteps_in::Int, num_timesteps_out::Int, dir = nothing)
37+
function processed_traffic(dname::String, num_timesteps::Int, dir = nothing, normalize = true)
4538
create_default_dir(dname)
4639
d = traffic_datadir(dname, dir)
4740
adj_matrix, node_values = read_traffic(d, dname)
4841

49-
node_values = permutedims(node_values,(1,3,2))
50-
node_values = (node_values .- Statistics.mean(node_values, dims=(3,1))) ./ Statistics.std(node_values, dims=(3,1)) #Z-score normalization
51-
42+
if normalize
43+
node_values = (node_values .- Statistics.mean(node_values, dims=(3,1))) ./ Statistics.std(node_values, dims=(3,1)) #Z-score normalization
44+
end
45+
5246
s, t, w = adjmatrix2edgeindex(adj_matrix; weighted = true)
53-
54-
x, y = traffic_generate_task(node_values, num_timesteps_in, num_timesteps_out)
47+
x, y = traffic_generate_task(node_values, num_timesteps )
48+
5549
return s, t, w, x, y
5650
end

test/datasets/graphs_no_ci.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,7 @@ end
349349

350350
@test g.num_nodes == 207
351351
@test g.num_edges == 1722
352-
@test length(g.node_data.features) == 34248
353-
@test length(g.node_data.targets) == 34248
352+
@test all(g.node_data.features[1][:,:,1][2:end,1] == g.node_data.targets[1][:,:,1][1:end-1])
354353
end
355354

356355
@testset "PEMS-BAY" begin
@@ -363,6 +362,5 @@ end
363362

364363
@test g.num_nodes == 325
365364
@test g.num_edges == 2694
366-
@test length(g.node_data.features) == 52081
367-
@test length(g.node_data.targets) == 52081
365+
@test all(g.node_data.features[1][:,:,1][2:end,1] == g.node_data.targets[1][:,:,1][1:end-1])
368366
end

0 commit comments

Comments
 (0)