@@ -7,7 +7,7 @@ using DelimitedFiles: readdlm
7
7
function __init__tudataset ()
8
8
DEPNAME = " TUDataset"
9
9
LINK = " https://www.chrsmrrs.com/graphkerneldatasets"
10
- DOCS = " "
10
+ DOCS = " https://chrsmrrs.github.io/datasets/docs/home/ "
11
11
DATA = " PROTEINS.zip"
12
12
13
13
register (DataDep (
@@ -29,7 +29,7 @@ struct TUDataset
29
29
source:: Vector{Int}
30
30
target:: Vector{Int}
31
31
graph_indicator
32
- node_labels:: Vector{Int}
32
+ node_labels:: Union{Nothing, Vector{Int} }
33
33
edge_labels:: Union{Nothing, Vector{Int}}
34
34
graph_labels
35
35
node_attributes
@@ -38,22 +38,24 @@ struct TUDataset
38
38
end
39
39
40
40
"""
41
- TUDataset
41
+ TUDataset(name; dir=nothing)
42
42
43
- A variety of graph kernel benchmark datasets, *.e.g.* "IMDB-BINARY",
44
- "REDDIT-BINARY" or "PROTEINS", collected from the [TU Dortmund University](https://chrsmrrs.github.io/datasets).
43
+ A variety of graph benchmark datasets, *e.g.* "QM9", "IMDB-BINARY",
44
+ "REDDIT-BINARY" or "PROTEINS", collected from the [TU Dortmund University](https://chrsmrrs.github.io/datasets/).
45
+ Retrieve the dataset `name` from the TUDataset collection, where `name`
46
+ is any of the datasets available [here](https://chrsmrrs.github.io/datasets/docs/datasets/).
45
47
46
- dataset(name; dir=nothing)
48
+ A `TUDataset` object can be indexed to retrieve a specific graph or a subset of graphs.
47
49
48
- Retrieve the TUDataset dataset. The output is an object with fields
50
+ # Internal fields
49
51
50
52
```
51
- num_nodes
52
- num_edges
53
- num_graphs
53
+ num_nodes # total number of nodes (considering all graphs)
54
+ num_edges # total number of edges (considering all graphs)
55
+ num_graphs # total number of graphs
54
56
source # vector of edges' source vectors
55
57
target # vector of edges' target vectors
56
- graph_indicator # graph
58
+ graph_indicator # graph to which a node belongs
57
59
node_labels
58
60
edge_labels
59
61
graph_labels
@@ -62,50 +64,72 @@ edge_attributes
62
64
graph_attributes
63
65
```
64
66
65
- See [this link](https://chrsmrrs.github.io/datasets/docs/datasets/)
66
- for a list of the available datasets.
67
+ See [here](https://chrsmrrs.github.io/datasets/docs/format/) for an in-depth
68
+ description of the format.
69
+
70
+ # Usage Example
71
+
72
+ ```julia
73
+ using MLDatasets: TUDataset
74
+ using LightGraphs: SimpleGraph, add_edge!
75
+
76
+ data = TUDataset("PROTEINS")
77
+
78
+ # Access first graph
79
+ d1 = data[1]
80
+
81
+ # Create a LightGraphs' graph
82
+ g = SimpleGraph(d1.num_nodes)
83
+ for (s, t) in zip(d1.source, d1.target)
84
+ add_edge!(g, s, t)
85
+ end
86
+
87
+ # Node features
88
+ X = d1.node_attributes # (nfeatures x nnodes) matrix
89
+ ```
67
90
"""
68
91
function TUDataset (name; dir= nothing )
69
- d = datadir ( " TUDataset " , dir)
70
- # See here for the file format https://chrsmrrs.github.io/datasets/docs/format/
71
- st = readdlm ( joinpath (d, name, " $(name) _A.txt " ), ' , ' , Int)
72
-
92
+ d = datadir_tudataset (name , dir)
93
+ # See here for the file format: https://chrsmrrs.github.io/datasets/docs/format/
94
+
95
+ st = readdlm ( joinpath (d, " $(name) _A.txt " ), ' , ' , Int)
73
96
# Check that the first node is labeled 1.
74
97
# TODO this will fail if the first node is isolated
75
98
@assert minimum (st) == 1
99
+ source, target = st[:,1 ], st[:,2 ]
76
100
77
- graph_indicator = readdlm (joinpath (d, name, " $(name) _graph_indicator.txt" ), Int) |> vec
101
+ graph_indicator = readdlm (joinpath (d, " $(name) _graph_indicator.txt" ), Int) |> vec
78
102
@assert all (sort (unique (graph_indicator)) .== 1 : length (unique (graph_indicator)))
79
103
80
- node_labels = readdlm (joinpath (d, name, " $(name) _node_labels.txt" ), Int) |> vec
81
- graph_labels = readdlm (joinpath (d, name, " $(name) _graph_labels.txt" ), Int) |> vec
104
+ num_nodes = length (graph_indicator)
105
+ num_edges = length (source)
106
+ num_graphs = length (unique (graph_indicator))
82
107
83
108
# LOAD OPTIONAL FILES IF EXIST
84
109
85
- if isfile (joinpath (d, name, " $(name) _edge_labels.txt" ))
86
- edge_labels = readdlm (joinpath (d, name, " $(name) _edge_labels.txt" )) |> vec
87
- else
88
- edge_labels = nothing
89
- end
90
- if isfile (joinpath (d, name, " $(name) _node_attributes.txt" ))
91
- node_attributes = readdlm (joinpath (d, name, " $(name) _node_attributes.txt" ), Float32)' |> collect
92
- else
93
- node_attributes = nothing
94
- end
95
- if isfile (joinpath (d, name, " $(name) _edge_attributes.txt" ))
96
- edge_attributes = readdlm (joinpath (d, name, " $(name) _edge_attributes.txt" ), Float32)' |> collect
97
- else
98
- edge_attributes = nothing
99
- end
100
- if isfile (joinpath (d, name, " $(name) _graph_attributes.txt" ))
101
- graph_attributes = readdlm (joinpath (d, name, " $(name) _graph_attributes.txt" ), Float32)' |> collect
102
- else
103
- graph_attributes = nothing
104
- end
105
-
110
+ node_labels = isfile (joinpath (d, " $(name) _node_labels.txt" )) ?
111
+ readdlm (joinpath (d, " $(name) _node_labels.txt" ), Int) |> vec :
112
+ nothing
113
+ edge_labels = isfile (joinpath (d, " $(name) _edge_labels.txt" )) ?
114
+ readdlm (joinpath (d, " $(name) _edge_labels.txt" ), Int) |> vec :
115
+ nothing
116
+ graph_labels = isfile (joinpath (d, " $(name) _graph_labels.txt" )) ?
117
+ readdlm (joinpath (d, " $(name) _graph_labels.txt" ), Int) |> vec :
118
+ nothing
119
+
120
+ node_attributes = isfile (joinpath (d, " $(name) _node_attributes.txt" )) ?
121
+ readdlm (joinpath (d, " $(name) _node_attributes.txt" ), ' ,' , Float32)' |> collect :
122
+ nothing
123
+ edge_attributes = isfile (joinpath (d, " $(name) _edge_attributes.txt" )) ?
124
+ readdlm (joinpath (d, " $(name) _edge_attributes.txt" ), ' ,' , Float32)' |> collect :
125
+ nothing
126
+ graph_attributes = isfile (joinpath (d, " $(name) _graph_attributes.txt" )) ?
127
+ readdlm (joinpath (d, " $(name) _graph_attributes.txt" ), ' ,' , Float32)' |> collect :
128
+ nothing
129
+
106
130
107
- TUDataset ( length (node_labels), size (st, 1 ), length (graph_labels) ,
108
- st[:, 1 ], st[:, 2 ] ,
131
+ TUDataset ( num_nodes, num_edges, num_graphs ,
132
+ source, target ,
109
133
graph_indicator,
110
134
node_labels,
111
135
edge_labels,
@@ -115,31 +139,45 @@ function TUDataset(name; dir=nothing)
115
139
graph_attributes)
116
140
end
117
141
142
"""
    datadir_tudataset(name, dir = nothing)

Return the directory `joinpath(dir, name)` that holds the extracted files of
the TUDataset called `name`. When that directory does not exist yet, the
archive `<name>.zip` is fetched into `dir` and unpacked there first.

# Arguments
- `name`: dataset name; used as the archive name and as the sub-directory
  of `dir` that is returned.
- `dir`: parent directory. When `nothing`, the directory of the `"TUDataset"`
  data dependency is used.

# Throws
- `AssertionError` if `joinpath(dir, name)` is still not a directory after the
  fetch/unpack step.
"""
function datadir_tudataset(name, dir = nothing)
    dir = isnothing(dir) ? datadep"TUDataset" : dir
    LINK = "https://www.chrsmrrs.com/graphkerneldatasets/$name.zip"
    d = joinpath(dir, name)
    if !isdir(d)
        DataDeps.fetch_default(LINK, dir)
        # Resolve the archive path BEFORE changing directory: with a relative
        # `dir`, `joinpath(dir, ...)` evaluated after `cd(dir)` would point to
        # `dir/dir/<name>.zip`.
        archive = abspath(joinpath(dir, "$name.zip"))
        # `unpack` extracts into the working directory, hence the `cd`. The
        # `do` form restores the previous working directory even when
        # unpacking throws.
        cd(dir) do
            DataDeps.unpack(archive)
        end
    end
    @assert isdir(d)
    return d
end
118
156
119
157
function Base. getindex (data:: TUDataset , i)
120
158
node_mask = data. graph_indicator .∈ Ref (i)
121
159
graph_indicator = data. graph_indicator[node_mask]
122
160
123
161
nodes = (1 : data. num_nodes)[node_mask]
124
- node_labels = data. node_labels[node_mask]
162
+ node_labels = isnothing (data . node_labels) ? nothing : data. node_labels[node_mask]
125
163
nodemap = Dict (v => i for (i, v) in enumerate (nodes))
126
164
127
165
edge_mask = data. source .∈ Ref (nodes)
128
166
source = [nodemap[i] for i in data. source[edge_mask]]
129
167
target = [nodemap[i] for i in data. target[edge_mask]]
130
168
edge_labels = isnothing (data. edge_labels) ? nothing : data. edge_labels[edge_mask]
131
169
132
- graph_labels = data. graph_labels[i]
170
+ graph_labels = isnothing (data . graph_labels) ? nothing : data. graph_labels[i]
133
171
134
172
node_attributes = isnothing (data. node_attributes) ? nothing : data. node_attributes[:,node_mask]
135
173
edge_attributes = isnothing (data. edge_attributes) ? nothing : data. edge_attributes[:,edge_mask]
136
174
graph_attributes = isnothing (data. graph_attributes) ? nothing : data. graph_attributes[:,i]
137
175
176
+ num_nodes = length (graph_indicator)
177
+ num_edges = length (source)
178
+ num_graphs = length (i)
138
179
139
- @assert source isa Vector
140
- @assert target isa Vector
141
- @assert node_labels isa Vector
142
- TUDataset (length (nodes), length (source), length (graph_labels),
180
+ TUDataset (num_nodes, num_edges, num_graphs,
143
181
source, target,
144
182
graph_indicator,
145
183
node_labels,
0 commit comments