Skip to content

Commit 5308002

Browse files
authored
Add METR-LA dataset (#212)
* Add `adjmatrix2edgeindex` * Add some functions for METR-LA * Fix * Improve with weights * Add function * Return weights * Add `generate_task` * Add pkgs and export `metrla` * Update * Add `metrla` tests * Add `METRLA` in docs * Change import * Normal import Statistics * Change name `generate_task` * Fix * Add spaces * Change type `adj` * Fix UD_English
1 parent 98a3855 commit 5308002

File tree

7 files changed

+135
-2
lines changed

7 files changed

+135
-2
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
2424
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2525
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2626
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
27+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2728
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2829

2930
[compat]

docs/src/datasets/graphs.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@ PolBlogs
3030
PubMed
3131
Reddit
3232
TUDataset
33+
METRLA
3334
```

src/MLDatasets.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using DelimitedFiles: readdlm
1212
using FileIO
1313
import CSV
1414
using LazyModules: @lazy
15+
using Statistics
1516

1617
include("require.jl") # export @require
1718

@@ -25,8 +26,9 @@ include("require.jl") # export @require
2526
@require import DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
2627
@require import ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
2728
@require import Chemfiles = "46823bd8-5fb3-5f92-9aa0-96921f3dd015"
29+
@require import NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
2830

29-
# @lazy import NPZ # lazy imported by FileIO
31+
# lazy imported by FileIO
3032
@lazy import Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
3133
@lazy import MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
3234
@lazy import HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
@@ -128,6 +130,8 @@ include("datasets/graphs/reddit.jl")
128130
export Reddit
129131
include("datasets/graphs/tudataset.jl")
130132
export TUDataset
133+
include("datasets/graphs/metrla.jl")
134+
export METRLA
131135

132136
# Meshes
133137

@@ -147,6 +151,7 @@ function __init__()
147151
__init__pubmed()
148152
__init__reddit()
149153
__init__tudataset()
154+
__init__metrla()
150155

151156
# misc
152157
__init__iris()

src/datasets/graphs/metrla.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
function __init__metrla()
2+
DEPNAME = "METR-LA"
3+
LINK = "https://graphmining.ai/temporal_datasets/"
4+
register(ManualDataDep(DEPNAME,
5+
"""
6+
Dataset: $DEPNAME
7+
Website : $LINK
8+
"""))
9+
end
10+
11+
"""
12+
METRLA(; num_timesteps_in::Int = 12, num_timesteps_out::Int=12, dir=nothing)
13+
14+
The METR-LA dataset from the [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926) paper.
15+
16+
`METRLA` is a graph with 207 nodes representing traffic sensors in Los Angeles.
17+
18+
The edge weights `w` are contained as a feature array in `edge_data` and represent the distance between the sensors.
19+
20+
The node features are the traffic speed and the time of the measurements collected by the sensors, divided into `num_timesteps_in` time steps.
21+
22+
The target values are the traffic speed and the time of the measurements collected by the sensors, divided into `num_timesteps_out` time steps.
23+
"""
24+
struct METRLA <: AbstractDataset
25+
graphs::Vector{Graph}
26+
end
27+
28+
function METRLA(;num_timesteps_in::Int = 12, num_timesteps_out::Int=12, dir = nothing)
29+
create_default_dir("METR-LA")
30+
d = metrla_datadir(dir)
31+
adj_matrix, node_values = read_metrla(d)
32+
33+
node_values = permutedims(node_values,(1,3,2))
34+
node_values = (node_values .- Statistics.mean(node_values, dims=(3,1))) ./ Statistics.std(node_values, dims=(3,1)) #Z-score normalization
35+
36+
s, t, w = adjmatrix2edgeindex(adj_matrix; weighted = true)
37+
38+
x, y = metrla_generate_task(node_values, num_timesteps_in, num_timesteps_out)
39+
40+
g = Graph(; num_nodes = 207,
41+
edge_index = (s, t),
42+
edge_data = w,
43+
node_data = (features = x, targets = y)
44+
45+
46+
)
47+
return METRLA([g])
48+
end
49+
50+
function metrla_datadir(dir = nothing)
51+
dir = isnothing(dir) ? datadep"METR-LA" : dir
52+
dname = "METR-LA"
53+
LINK = "https://graphmining.ai/temporal_datasets/$dname.zip"
54+
if length(readdir((dir))) == 0
55+
DataDeps.fetch_default(LINK, dir)
56+
currdir = pwd()
57+
cd(dir) # Needed since `unpack` extracts in working dir
58+
DataDeps.unpack(joinpath(dir, "$dname.zip"))
59+
# conditions when unzipped folder is our required data dir
60+
cd(currdir)
61+
end
62+
@assert isdir(dir)
63+
return dir
64+
end
65+
66+
function read_metrla(d::String)
67+
adj_matrix = NPZ.npzread(joinpath(d, "adj_mat.npy"))
68+
node_features = NPZ.npzread(joinpath(d, "node_values.npy"))
69+
return adj_matrix, node_features
70+
end
71+
72+
function metrla_generate_task(node_values::AbstractArray, num_timesteps_in::Int, num_timesteps_out::Int)
73+
indices = [(i, i + num_timesteps_in + num_timesteps_out) for i in 1:(size(node_values,1) - num_timesteps_in - num_timesteps_out)]
74+
features = []
75+
targets = []
76+
for (i,j) in indices
77+
push!(features, node_values[i:i+num_timesteps_in-1,:,:])
78+
push!(targets, reshape(node_values[i+num_timesteps_in:j-1,1,:], (num_timesteps_out, 1, size(node_values, 3))))
79+
end
80+
return features, targets
81+
end
82+
83+
Base.length(d::METRLA) = length(d.graphs)
84+
Base.getindex(d::METRLA, ::Colon) = d.graphs[1]
85+
Base.getindex(d::METRLA, i) = getindex(d.graphs, i)

src/datasets/text/udenglish.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function __init__udenglish()
2222
""",
2323
"https://raw.githubusercontent.com/UniversalDependencies/UD_English-EWT/master/" .*
2424
[TRAINFILE, DEVFILE, TESTFILE],
25-
"e26845c3c78140e15d82a425388bcc58016d511616e5c2669a2e580e8ae586c0"))
25+
"2e94de3333b3b17bd06769ff8c4a81896fabec3ae51a6366f7af82625679f561"))
2626
end
2727

2828
"""

src/graph.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,30 @@ function edgeindex2adjlist(s, t, num_nodes; inneigs = false)
226226
end
227227
return adj
228228
end
229+
230+
function adjmatrix2edgeindex(adj::AbstractMatrix{T}; weighted = true, inneigs = false) where T
231+
s, t = Int[], Int[]
232+
if weighted
233+
w = T[]
234+
end
235+
for i in 1:size(adj,1)
236+
for j in 1:size(adj,2)
237+
if adj[i,j] != 0
238+
push!(s, i)
239+
push!(t, j)
240+
if weighted
241+
push!(w, adj[i,j])
242+
end
243+
end
244+
end
245+
end
246+
247+
if inneigs
248+
s, t = t, s
249+
end
250+
if weighted
251+
return s, t, w
252+
else
253+
return s, t
254+
end
255+
end

test/datasets/graphs_no_ci.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,17 @@ end
338338
data = OrganicMaterialsDB(split = :test)
339339
@test length(data) == 2500
340340
end
341+
342+
@testset "METR-LA" begin
343+
data = METRLA()
344+
@test data isa AbstractDataset
345+
@test length(data) == 1
346+
g = data[1]
347+
@test g === data[:]
348+
@test g isa MLDatasets.Graph
349+
350+
@test g.num_nodes == 207
351+
@test g.num_edges == 1722
352+
@test length(g.node_data.features) == 34248
353+
@test length(g.node_data.targets) == 34248
354+
end

0 commit comments

Comments
 (0)