Skip to content

Commit ece027c

Browse files
working implementation
1 parent f59d191 commit ece027c

File tree

7 files changed

+132
-37
lines changed

7 files changed

+132
-37
lines changed

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,10 @@ Find below a list of available datasets and links to their documentation.
2727
- [MNIST](https://juliaml.github.io/MLDatasets.jl/dev/datasets/MNIST/)
2828
- [SVHN2](https://juliaml.github.io/MLDatasets.jl/dev/datasets/SVHN2/)
2929

30-
3130
#### Miscellaneous
3231
- [BostonHousing](https://juliaml.github.io/MLDatasets.jl/dev/datasets/BostonHousing/)
3332
- [Iris](https://juliaml.github.io/MLDatasets.jl/dev/datasets/Iris/)
3433

35-
3634
#### Text
3735
- [PTBLM](https://juliaml.github.io/MLDatasets.jl/dev/datasets/PTBLM/)
3836
- [UD_English](https://juliaml.github.io/MLDatasets.jl/dev/datasets/UD_English/)
@@ -41,7 +39,7 @@ Find below a list of available datasets and links to their documentation.
4139
- [CiteSeer](https://juliaml.github.io/MLDatasets.jl/dev/datasets/CiteSeer/)
4240
- [Cora](https://juliaml.github.io/MLDatasets.jl/dev/datasets/Cora/)
4341
- [PubMed](https://juliaml.github.io/MLDatasets.jl/dev/datasets/PubMed/)
44-
42+
- [TUDataset](https://juliaml.github.io/MLDatasets.jl/dev/datasets/TUDataset/)
4543

4644

4745
## Installation

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ makedocs(
4444
"CiteSeer" => "datasets/CiteSeer.md",
4545
"Cora" => "datasets/Cora.md",
4646
"PubMed" => "datasets/PubMed.md",
47+
"TUDataset" => "datasets/TUDataset.md",
4748
],
4849

4950
],

docs/src/datasets/TUDataset.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# TUDataset
2+
3+
```@docs
4+
TUDataset
5+
```
6+
7+
## API reference
8+
9+
```@docs
10+
TUDataset.dataset
11+
```

src/MLDatasets.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ function __init__()
7272
out = out.todense() if hasattr(out, 'todense') else out
7373
return out
7474
"""
75+
76+
__init__tudataset()
7577
end
7678

7779
end

src/TUDataset/TUDataset.jl

Lines changed: 87 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,84 @@
11
export TUDataset
22

3-
"""
4-
TUDataset
5-
6-
A variety of graph kernel benchmark datasets, *.e.g.* "IMDB-BINARY",
7-
"REDDIT-BINARY" or "PROTEINS", collected from the [TU Dortmund University](https://chrsmrrs.github.io/datasets).
8-
"""
9-
module TUDataset
10-
113
using DataDeps
12-
using ..MLDatasets: datafile, datadir
4+
# using ..MLDatasets: datafile, datadir
135
using DelimitedFiles: readdlm
146

15-
using PyCall
16-
17-
const DEPNAME = "TUDataset"
18-
# LINK = "https://github.com/shchur/gnn-benchmark/raw/master/data/npz"
19-
# LINK = "https://github.com/abojchevski/graph2gauss/raw/master/data/"
20-
const LINK = "https://www.chrsmrrs.com/graphkerneldatasets"
21-
const DOCS = "https://chrsmrrs.github.io/datasets"
22-
const DATA = "PROTEINS.zip"
7+
function __init__tudataset()
8+
DEPNAME = "TUDataset"
9+
LINK = "https://www.chrsmrrs.com/graphkerneldatasets"
10+
DOCS = ""
11+
DATA = "PROTEINS.zip"
2312

24-
function __init__()
2513
register(DataDep(
2614
DEPNAME,
2715
"""
2816
Dataset: The $DEPNAME dataset.
29-
Website: $DOCS
17+
Website: $LINK
3018
""",
3119
"$LINK/$DATA",
3220
# "81de017067dc045ebdb8ffd5c0e69a209973ffdb1fe2d5b434e94d3614f3f5c7", # if checksum omitted, will be generated by DataDeps
3321
post_fetch_method = unpack
3422
))
3523
end
3624

37-
struct TUData
25+
struct TUDataset
26+
num_nodes::Int
27+
num_edges::Int
28+
num_graphs::Int
3829
source::Vector{Int}
3930
target::Vector{Int}
40-
graph_indicator::Vector{Int}
31+
graph_indicator
4132
node_labels::Vector{Int}
4233
edge_labels::Union{Nothing, Vector{Int}}
43-
graph_labels::Vector{Int}
34+
graph_labels
4435
node_attributes
4536
edge_attributes
4637
graph_attributes
4738
end
4839

4940
"""
41+
TUDataset
42+
43+
A variety of graph kernel benchmark datasets, *e.g.* "IMDB-BINARY",
44+
"REDDIT-BINARY" or "PROTEINS", collected from the [TU Dortmund University](https://chrsmrrs.github.io/datasets).
45+
5046
dataset(name; dir=nothing)
5147
52-
Retrieve the TUDataset dataset. The output is a named tuple with fields
48+
Retrieve the TUDataset dataset. The output is an object with fields
49+
50+
```
51+
num_nodes
52+
num_edges
53+
num_graphs
54+
source # vector of the edges' source nodes
55+
target # vector of the edges' target nodes
56+
graph_indicator # index of the graph each node belongs to
57+
node_labels
58+
edge_labels
59+
graph_labels
60+
node_attributes
61+
edge_attributes
62+
graph_attributes
63+
```
5364
5465
See [this link](https://chrsmrrs.github.io/datasets/docs/datasets/)
5566
for a list of the available datasets.
5667
"""
57-
function dataset(name; dir=nothing)
58-
d = datadir(DEPNAME, dir)
68+
function TUDataset(name; dir=nothing)
69+
d = datadir("TUDataset", dir)
5970
# See here for the file format https://chrsmrrs.github.io/datasets/docs/format/
6071
st = readdlm(joinpath(d, name, "$(name)_A.txt"), ',', Int)
72+
73+
# Check that the first node is labeled 1.
74+
# TODO: this check fails if node 1 is isolated, since it then never appears in the edge list
75+
@assert minimum(st) == 1
76+
77+
graph_indicator = readdlm(joinpath(d, name, "$(name)_graph_indicator.txt"), Int) |> vec
78+
@assert all(sort(unique(graph_indicator)) .== 1:length(unique(graph_indicator)))
79+
80+
node_labels = readdlm(joinpath(d, name, "$(name)_node_labels.txt"), Int) |> vec
81+
graph_labels = readdlm(joinpath(d, name, "$(name)_graph_labels.txt"), Int) |> vec
6182

6283
# LOAD OPTIONAL FILES IF EXIST
6384

@@ -82,16 +103,49 @@ function dataset(name; dir=nothing)
82103
graph_attributes = nothing
83104
end
84105

85-
TUData(st[:,1], st[:,2],
86-
readdlm(joinpath(d, name, "$(name)_graph_indicator.txt"), Int) |> vec,
87-
readdlm(joinpath(d, name, "$(name)_node_labels.txt"), Int) |> vec,
106+
107+
TUDataset( length(node_labels), size(st, 1), length(graph_labels),
108+
st[:,1], st[:,2],
109+
graph_indicator,
110+
node_labels,
111+
edge_labels,
112+
graph_labels,
113+
node_attributes,
114+
edge_attributes,
115+
graph_attributes)
116+
end
117+
118+
119+
function Base.getindex(data::TUDataset, i)
120+
node_mask = data.graph_indicator .∈ Ref(i)
121+
graph_indicator = data.graph_indicator[node_mask]
122+
123+
nodes = (1:data.num_nodes)[node_mask]
124+
node_labels = data.node_labels[node_mask]
125+
nodemap = Dict(v => i for (i, v) in enumerate(nodes))
126+
127+
edge_mask = data.source .∈ Ref(nodes)
128+
source = [nodemap[i] for i in data.source[edge_mask]]
129+
target = [nodemap[i] for i in data.target[edge_mask]]
130+
edge_labels = isnothing(data.edge_labels) ? nothing : data.edge_labels[edge_mask]
131+
132+
graph_labels = data.graph_labels[i]
133+
134+
node_attributes = isnothing(data.node_attributes) ? nothing : data.node_attributes[:,node_mask]
135+
edge_attributes = isnothing(data.edge_attributes) ? nothing : data.edge_attributes[:,edge_mask]
136+
graph_attributes = isnothing(data.graph_attributes) ? nothing : data.graph_attributes[:,i]
137+
138+
139+
@assert source isa Vector
140+
@assert target isa Vector
141+
@assert node_labels isa Vector
142+
TUDataset(length(nodes), length(source), length(graph_labels),
143+
source, target,
144+
graph_indicator,
145+
node_labels,
88146
edge_labels,
89-
readdlm(joinpath(d, name, "$(name)_graph_labels.txt"), Int) |> vec,
147+
graph_labels,
90148
node_attributes,
91149
edge_attributes,
92150
graph_attributes)
93151
end
94-
95-
96-
end #module
97-

test/runtests.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,21 @@ using DataDeps
77
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
88

99
tests = [
10+
# misc
1011
"tst_iris.jl",
1112
"tst_boston_housing.jl",
13+
# vision
1214
"tst_cifar10.jl",
1315
"tst_cifar100.jl",
1416
"tst_mnist.jl",
1517
"tst_fashion_mnist.jl",
1618
"tst_svhn2.jl",
1719
"tst_emnist.jl",
18-
"tst_cora.jl",
20+
# graphs
1921
"tst_citeseer.jl",
22+
"tst_cora.jl",
2023
"tst_pubmed.jl",
24+
"tst_tudataset.jl",
2125
]
2226

2327
for t in tests

test/tst_tudataset.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
data_dir = withenv("DATADEPS_ALWAYS_ACCEPT"=>"true") do
2+
datadep"TUDataset"
3+
end
4+
5+
@testset "TUDataset - PROTEINS" begin
6+
data = TUDataset("PROTEINS")
7+
8+
@test data.num_nodes == 43471
9+
@test data.num_edges == 162088
10+
@test data.num_graphs === 1113
11+
12+
@test length(data.source) == data.num_edges
13+
@test length(data.target) == data.num_edges
14+
15+
@test size(data.node_attributes) == (1, data.num_nodes)
16+
@test data.edge_attributes === nothing
17+
@test data.graph_attributes === nothing
18+
19+
@test size(data.node_labels) == (data.num_nodes,)
20+
@test data.edge_labels === nothing
21+
@test size(data.graph_labels) == (data.num_graphs,)
22+
23+
@test length(data.graph_indicator) == data.num_nodes
24+
@test all(sort(unique(data.graph_indicator)) .== 1:data.num_graphs)
25+
end

0 commit comments

Comments
 (0)