Skip to content

Commit 3f3355c

Browse files
support all datasets
1 parent acafe63 commit 3f3355c

File tree

2 files changed

+109
-49
lines changed

2 files changed

+109
-49
lines changed

src/TUDataset/TUDataset.jl

Lines changed: 87 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using DelimitedFiles: readdlm
77
function __init__tudataset()
88
DEPNAME = "TUDataset"
99
LINK = "https://www.chrsmrrs.com/graphkerneldatasets"
10-
DOCS = ""
10+
DOCS = "https://chrsmrrs.github.io/datasets/docs/home/"
1111
DATA = "PROTEINS.zip"
1212

1313
register(DataDep(
@@ -29,7 +29,7 @@ struct TUDataset
2929
source::Vector{Int}
3030
target::Vector{Int}
3131
graph_indicator
32-
node_labels::Vector{Int}
32+
node_labels::Union{Nothing, Vector{Int}}
3333
edge_labels::Union{Nothing, Vector{Int}}
3434
graph_labels
3535
node_attributes
@@ -38,22 +38,24 @@ struct TUDataset
3838
end
3939

4040
"""
41-
TUDataset
41+
TUDataset(name; dir=nothing)
4242
43-
A variety of graph kernel benchmark datasets, *.e.g.* "IMDB-BINARY",
44-
"REDDIT-BINARY" or "PROTEINS", collected from the [TU Dortmund University](https://chrsmrrs.github.io/datasets).
43+
A variety of graph benchmark datasets, *e.g.* "QM9", "IMDB-BINARY",
44+
"REDDIT-BINARY" or "PROTEINS", collected from the [TU Dortmund University](https://chrsmrrs.github.io/datasets/).
45+
Retrieve the dataset `name` from the TUDataset collection, where `name`
46+
is any of the datasets available [here](https://chrsmrrs.github.io/datasets/docs/datasets/).
4547
46-
dataset(name; dir=nothing)
48+
A `TUDataset` object can be indexed to retrieve a specific graph or a subset of graphs.
4749
48-
Retrieve the TUDataset dataset. The output is an object with fields
50+
# Internal fields
4951
5052
```
51-
num_nodes
52-
num_edges
53-
num_graphs
53+
num_nodes # total number of nodes (considering all graphs)
54+
num_edges # total number of edges (considering all graphs)
55+
num_graphs # total number of graphs
5456
source # source node of each edge
5557
target # target node of each edge
56-
graph_indicator # graph
58+
graph_indicator # graph to which each node belongs
5759
node_labels
5860
edge_labels
5961
graph_labels
@@ -62,50 +64,72 @@ edge_attributes
6264
graph_attributes
6365
```
6466
65-
See [this link](https://chrsmrrs.github.io/datasets/docs/datasets/)
66-
for a list of the available datasets.
67+
See [here](https://chrsmrrs.github.io/datasets/docs/format/) for an in-depth
68+
description of the format.
69+
70+
# Usage Example
71+
72+
```julia
73+
using MLDatasets: TUDataset
74+
using LightGraphs: SimpleGraph, add_edge!
75+
76+
data = TUDataset("PROTEINS")
77+
78+
# Access first graph
79+
d1 = data[1]
80+
81+
# Create a LightGraphs graph
82+
g = SimpleGraph(d1.num_nodes)
83+
for (s, t) in zip(d1.source, d1.target)
84+
add_edge!(g, s, t)
85+
end
86+
87+
# Node features
88+
X = d1.node_attributes # (nfeatures x nnodes) matrix
89+
```
6790
"""
6891
function TUDataset(name; dir=nothing)
69-
d = datadir("TUDataset", dir)
70-
# See here for the file format https://chrsmrrs.github.io/datasets/docs/format/
71-
st = readdlm(joinpath(d, name, "$(name)_A.txt"), ',', Int)
72-
92+
d = datadir_tudataset(name, dir)
93+
# See here for the file format: https://chrsmrrs.github.io/datasets/docs/format/
94+
95+
st = readdlm(joinpath(d, "$(name)_A.txt"), ',', Int)
7396
# Check that the first node is labeled 1.
7497
# TODO this will fail if the first node is isolated
7598
@assert minimum(st) == 1
99+
source, target = st[:,1], st[:,2]
76100

77-
graph_indicator = readdlm(joinpath(d, name, "$(name)_graph_indicator.txt"), Int) |> vec
101+
graph_indicator = readdlm(joinpath(d, "$(name)_graph_indicator.txt"), Int) |> vec
78102
@assert all(sort(unique(graph_indicator)) .== 1:length(unique(graph_indicator)))
79103

80-
node_labels = readdlm(joinpath(d, name, "$(name)_node_labels.txt"), Int) |> vec
81-
graph_labels = readdlm(joinpath(d, name, "$(name)_graph_labels.txt"), Int) |> vec
104+
num_nodes = length(graph_indicator)
105+
num_edges = length(source)
106+
num_graphs = length(unique(graph_indicator))
82107

83108
# LOAD OPTIONAL FILES IF EXIST
84109

85-
if isfile(joinpath(d, name, "$(name)_edge_labels.txt"))
86-
edge_labels = readdlm(joinpath(d, name, "$(name)_edge_labels.txt")) |> vec
87-
else
88-
edge_labels = nothing
89-
end
90-
if isfile(joinpath(d, name, "$(name)_node_attributes.txt"))
91-
node_attributes = readdlm(joinpath(d, name, "$(name)_node_attributes.txt"), Float32)' |> collect
92-
else
93-
node_attributes = nothing
94-
end
95-
if isfile(joinpath(d, name, "$(name)_edge_attributes.txt"))
96-
edge_attributes = readdlm(joinpath(d, name, "$(name)_edge_attributes.txt"), Float32)' |> collect
97-
else
98-
edge_attributes = nothing
99-
end
100-
if isfile(joinpath(d, name, "$(name)_graph_attributes.txt"))
101-
graph_attributes = readdlm(joinpath(d, name, "$(name)_graph_attributes.txt"), Float32)' |> collect
102-
else
103-
graph_attributes = nothing
104-
end
105-
110+
node_labels = isfile(joinpath(d, "$(name)_node_labels.txt")) ?
111+
readdlm(joinpath(d, "$(name)_node_labels.txt"), Int) |> vec :
112+
nothing
113+
edge_labels = isfile(joinpath(d, "$(name)_edge_labels.txt")) ?
114+
readdlm(joinpath(d, "$(name)_edge_labels.txt"), Int) |> vec :
115+
nothing
116+
graph_labels = isfile(joinpath(d, "$(name)_graph_labels.txt")) ?
117+
readdlm(joinpath(d, "$(name)_graph_labels.txt"), Int) |> vec :
118+
nothing
119+
120+
node_attributes = isfile(joinpath(d, "$(name)_node_attributes.txt")) ?
121+
readdlm(joinpath(d, "$(name)_node_attributes.txt"), ',', Float32)' |> collect :
122+
nothing
123+
edge_attributes = isfile(joinpath(d, "$(name)_edge_attributes.txt")) ?
124+
readdlm(joinpath(d, "$(name)_edge_attributes.txt"), ',', Float32)' |> collect :
125+
nothing
126+
graph_attributes = isfile(joinpath(d, "$(name)_graph_attributes.txt")) ?
127+
readdlm(joinpath(d, "$(name)_graph_attributes.txt"), ',', Float32)' |> collect :
128+
nothing
129+
106130

107-
TUDataset( length(node_labels), size(st, 1), length(graph_labels),
108-
st[:,1], st[:,2],
131+
TUDataset( num_nodes, num_edges, num_graphs,
132+
source, target,
109133
graph_indicator,
110134
node_labels,
111135
edge_labels,
@@ -115,31 +139,45 @@ function TUDataset(name; dir=nothing)
115139
graph_attributes)
116140
end
117141

142+
function datadir_tudataset(name, dir = nothing)
143+
dir = isnothing(dir) ? datadep"TUDataset" : dir
144+
LINK = "https://www.chrsmrrs.com/graphkerneldatasets/$name.zip"
145+
d = joinpath(dir, name)
146+
if !isdir(d)
147+
DataDeps.fetch_default(LINK, dir)
148+
currdir = pwd()
149+
cd(dir) # Needed since `unpack` extracts in working dir
150+
DataDeps.unpack(joinpath(dir, "$name.zip"))
151+
cd(currdir)
152+
end
153+
@assert isdir(d)
154+
return d
155+
end
118156

119157
function Base.getindex(data::TUDataset, i)
120158
node_mask = data.graph_indicator .∈ Ref(i)
121159
graph_indicator = data.graph_indicator[node_mask]
122160

123161
nodes = (1:data.num_nodes)[node_mask]
124-
node_labels = data.node_labels[node_mask]
162+
node_labels = isnothing(data.node_labels) ? nothing : data.node_labels[node_mask]
125163
nodemap = Dict(v => i for (i, v) in enumerate(nodes))
126164

127165
edge_mask = data.source .∈ Ref(nodes)
128166
source = [nodemap[i] for i in data.source[edge_mask]]
129167
target = [nodemap[i] for i in data.target[edge_mask]]
130168
edge_labels = isnothing(data.edge_labels) ? nothing : data.edge_labels[edge_mask]
131169

132-
graph_labels = data.graph_labels[i]
170+
graph_labels = isnothing(data.graph_labels) ? nothing : data.graph_labels[i]
133171

134172
node_attributes = isnothing(data.node_attributes) ? nothing : data.node_attributes[:,node_mask]
135173
edge_attributes = isnothing(data.edge_attributes) ? nothing : data.edge_attributes[:,edge_mask]
136174
graph_attributes = isnothing(data.graph_attributes) ? nothing : data.graph_attributes[:,i]
137175

176+
num_nodes = length(graph_indicator)
177+
num_edges = length(source)
178+
num_graphs = length(i)
138179

139-
@assert source isa Vector
140-
@assert target isa Vector
141-
@assert node_labels isa Vector
142-
TUDataset(length(nodes), length(source), length(graph_labels),
180+
TUDataset(num_nodes, num_edges, num_graphs,
143181
source, target,
144182
graph_indicator,
145183
node_labels,

test/tst_tudataset.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,25 @@ end
2323
@test length(data.graph_indicator) == data.num_nodes
2424
@test all(sort(unique(data.graph_indicator)) .== 1:data.num_graphs)
2525
end
26+
27+
@testset "TUDataset - QM9" begin
28+
data = TUDataset("QM9")
29+
30+
@test data.num_nodes == 2333625
31+
@test data.num_edges == 4823498
32+
@test data.num_graphs === 129433
33+
34+
@test length(data.source) == data.num_edges
35+
@test length(data.target) == data.num_edges
36+
37+
@test size(data.node_attributes) == (16, data.num_nodes)
38+
@test size(data.edge_attributes) == (4, data.num_edges)
39+
@test size(data.graph_attributes) == (19, data.num_graphs)
40+
41+
@test data.node_labels === nothing
42+
@test data.edge_labels === nothing
43+
@test data.graph_labels === nothing
44+
45+
@test length(data.graph_indicator) == data.num_nodes
46+
@test all(sort(unique(data.graph_indicator)) .== 1:data.num_graphs)
47+
end

0 commit comments

Comments
 (0)