Skip to content

Commit e4d0c24

Browse files
finish graphs
1 parent f70fa97 commit e4d0c24

File tree

7 files changed

+74
-97
lines changed

7 files changed

+74
-97
lines changed

src/abstract_datasets.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ _summary(x) = x
4646
_summary(x::Symbol) = ":$x"
4747
_summary(x::Union{Dict, AbstractArray, DataFrame}) = summary(x)
4848
_summary(x::Union{Tuple, NamedTuple}) = map(_summary, x)
49-
_summary(x::BitVector) = summary(x) * " with $(count(x)) trues"
49+
_summary(x::BitVector) = "$(count(x))-trues BitVector"
5050

5151
"""
5252
SupervisedDataset <: AbstractDataset

src/datasets/graphs/ogbdataset.jl

Lines changed: 38 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -53,93 +53,61 @@ available for node prediction, edge prediction, or graph prediction tasks.
5353
julia> data = OGBDataset("ogbn-arxiv")
5454
dataset OGBDataset:
5555
name => ogbn-arxiv
56-
metadata => Dict{String, Any} with 16 entries
56+
metadata => Dict{String, Any} with 17 entries
5757
graphs => 1-element Vector{MLDatasets.Graph}
58-
targets => nothing
59-
split_idx => (train = "90941-element Vector{Int64}", val = "29799-element Vector{Int64}", test = "48603-element Vector{Int64}")
58+
graph_data => nothing
6059
6160
julia> data[:]
6261
Graph:
6362
num_nodes => 169343
6463
num_edges => 1166243
6564
edge_index => ("1166243-element Vector{Int64}", "1166243-element Vector{Int64}")
66-
node_data => (year = "169343 Matrix{Int64}", features = "128×169343 Matrix{Float32}", label = "169343 Matrix{Int64}")
65+
node_data => (val_mask = "29799-trues BitVector", test_mask = "48603-trues BitVector", year = "169343-element Vector{Int64}", features = "128×169343 Matrix{Float32}", label = "169343-element Vector{Int64}", train_mask = "90941-trues BitVector")
6766
edge_data => nothing
6867
6968
julia> data.metadata
70-
Dict{String, Any} with 16 entries:
69+
Dict{String, Any} with 17 entries:
7170
"download_name" => "arxiv"
7271
"num classes" => 40
7372
"num tasks" => 1
7473
"binary" => false
7574
"url" => "http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip"
7675
"additional node files" => "node_year"
7776
"is hetero" => false
78-
"path" => "/home/carlo/.julia/datadeps/OGBDataset/arxiv"
79-
"eval metric" => "acc"
80-
"task type" => "multiclass classification"
81-
"add_inverse_edge" => false
82-
"has_node_attr" => true
83-
"additional edge files" => nothing
84-
"version" => 1
85-
"has_edge_attr" => false
86-
"split" => "time"
77+
"task level" => "node"
78+
⋮ => ⋮
8779
```
8880
8981
## Edge prediction task
9082
9183
```julia-repl
9284
julia> data = OGBDataset("ogbl-collab")
93-
OGBDataset{Nothing}:
94-
name => ogbl-collab
95-
path => /home/carlo/.julia/datadeps/OGBDataset/collab
96-
metadata => Dict{String, Any} with 13 entries
97-
graphs => 1-element Vector{Dict}
98-
labels => nothing
99-
split => Dict{String, Any} with 3 entries
100-
101-
julia> graph = data[1] # no labels for this dataset
102-
Dict{String, Any} with 7 entries:
103-
"edge_index" => [150990 224882; 150990 224882; … ; 221742 135759; 207233 140615]
104-
"edge_feat" => nothing
105-
"node_feat" => Float32[-0.177486 -0.237488 … 0.004236 -0.035025; -0.10298 0.022193 … 0.031942 -0.118059; … ; 0.003879 0.062124 … 0.05208 -0.176961; -0.276317 -0.081464 … -0.201557 -0.258715]
106-
"num_nodes" => 235868
107-
"edge_year" => [2004 2002 … 2006 1984; 2004 2002 … 2006 1984]
108-
"edge_weight" => [2 1 … 1 1; 2 1 … 1 1]
109-
"num_edges" => 2358104
85+
dataset OGBDataset:
86+
name => ogbl-collab
87+
metadata => Dict{String, Any} with 15 entries
88+
graphs => 1-element Vector{MLDatasets.Graph}
89+
graph_data => nothing
90+
91+
julia> data[:]
92+
Graph:
93+
num_nodes => 235868
94+
num_edges => 2358104
95+
edge_index => ("2358104-element Vector{Int64}", "2358104-element Vector{Int64}")
96+
node_data => (features = "128×235868 Matrix{Float32}",)
97+
edge_data => (year = "2×1179052 Matrix{Int64}", weight = "2×1179052 Matrix{Int64}")
11098
```
11199
112100
## Graph prediction task
113101
114102
```julia-repl
115103
julia> data = OGBDataset("ogbg-molhiv")
116-
OGBDataset{Matrix{Int64}}:
117-
name => ogbg-molhiv
118-
path => /home/carlo/.julia/datadeps/OGBDataset/molhiv
119-
metadata => Dict{String, Any} with 15 entries
120-
graphs => 41127-element Vector{Dict}
121-
labels => 1×41127 Matrix{Int64}
122-
split => Dict{String, Any} with 3 entries
123-
124-
julia> length(data)
125-
41127
126-
127-
julia> graph, labels = data[10]
128-
(Dict{String, Any}("edge_index" => [-202 -201; -201 -200; … ; -198 -184; -201 -202], "node_feat" => Float32[7.0 6.0 … 7.0 7.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], "edge_feat" => Float32[0.0 0.0 … 0.0 1.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 1.0], "num_nodes" => 20, "num_edges" => 42), [0])
129-
130-
julia> graph, labels = data[10];
131-
132-
julia> graph
133-
Dict{String, Any} with 5 entries:
134-
"edge_index" => [1 2; 2 3; … ; 5 19; 2 1]
135-
"edge_feat" => Float32[0.0 0.0 … 0.0 1.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 1.0]
136-
"node_feat" => Float32[7.0 6.0 … 7.0 7.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
137-
"num_nodes" => 20
138-
"num_edges" => 42
139-
140-
julia> labels
141-
1-element Vector{Int64}:
142-
0
104+
dataset OGBDataset:
105+
name => ogbg-molhiv
106+
metadata => Dict{String, Any} with 17 entries
107+
graphs => 41127-element Vector{MLDatasets.Graph}
108+
graph_data => (labels = "41127-element Vector{Int64}", train_mask = "32901-trues BitVector", val_mask = "4113-trues BitVector", test_mask = "4113-trues BitVector")
109+
110+
julia> data[1]
143111
```
144112
"""
145113
struct OGBDataset{GD} <: AbstractDataset
@@ -309,10 +277,13 @@ function read_ogb_graph(path, metadata)
309277

310278
splits = readdir(joinpath(path, "split"))
311279
@assert length(splits) == 1 # TODO check if datasets with multiple splits exist in OGB
280+
312281
# TODO sometimes splits are given in .pt format
282+
# Use read_pytorch in src/io.jl to load them.
313283
split_idx = (train = read_ogb_file(joinpath(path, "split", splits[1], "train.csv"), Int; tovec=true),
314284
val = read_ogb_file(joinpath(path, "split", splits[1], "valid.csv"), Int; tovec=true),
315285
test = read_ogb_file(joinpath(path, "split", splits[1], "test.csv"), Int; tovec=true))
286+
316287
if split_idx.train !== nothing
317288
split_idx.train .+= 1
318289
end
@@ -353,7 +324,11 @@ function read_ogb_graph(path, metadata)
353324
end
354325
end
355326
if metadata["task level"] == "graph"
356-
graph_data = (; labels, split_idx)
327+
train_mask = split_idx.train !== nothing ? indexes2mask(split_idx.train, num_graphs) : nothing
328+
val_mask = split_idx.val !== nothing ? indexes2mask(split_idx.val, num_graphs) : nothing
329+
test_mask = split_idx.test !== nothing ? indexes2mask(split_idx.test, num_graphs) : nothing
330+
331+
graph_data = clean_nt((; labels=maybesqueeze(labels), train_mask, val_mask, test_mask))
357332
end
358333
return graphs, graph_data
359334
end
@@ -377,15 +352,15 @@ end
377352
function ogbdict2graph(d::Dict)
378353
edge_index = d["edge_index"][:,1], d["edge_index"][:,2]
379354
num_nodes = d["num_nodes"]
380-
node_data = Dict(Symbol(k[6:end]) => v for (k,v) in d if startswith(k, "node_") && v !== nothing)
381-
edge_data = Dict(Symbol(k[6:end]) => v for (k,v) in d if startswith(k, "edge_") && k!="edge_index" && v !== nothing)
355+
node_data = Dict(Symbol(k[6:end]) => maybesqueeze(v) for (k,v) in d if startswith(k, "node_") && v !== nothing)
356+
edge_data = Dict(Symbol(k[6:end]) => maybesqueeze(v) for (k,v) in d if startswith(k, "edge_") && k!="edge_index" && v !== nothing)
382357
node_data = isempty(node_data) ? nothing : (; node_data...)
383358
edge_data = isempty(edge_data) ? nothing : (; edge_data...)
384359
return Graph(; num_nodes, edge_index, node_data, edge_data)
385360
end
386361

387362
Base.length(data::OGBDataset) = length(data.graphs)
388363
Base.getindex(data::OGBDataset{Nothing}, ::Colon) = length(data.graphs) == 1 ? data.graphs[1] : data.graphs
389-
Base.getindex(data::OGBDataset, ::Colon) = (; data.graphs, targets=data.graph_data.labels)
364+
Base.getindex(data::OGBDataset, ::Colon) = (; data.graphs, data.graph_data.labels)
390365
Base.getindex(data::OGBDataset{Nothing}, i) = getobs(data.graphs, i)
391-
Base.getindex(data::OGBDataset, i) = getobs((; data.graphs, targets=data.graph_data.labels), i)
366+
Base.getindex(data::OGBDataset, i) = getobs((; data.graphs, data.graph_data.labels), i)

src/datasets/graphs/reddit.jl

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function Reddit(; full=true, dir=nothing)
7272
nodes = graph["nodes"]
7373
num_edges = directed ? length(links) : length(links) * 2
7474
num_nodes = length(nodes)
75-
num_graphs = length(graph["graph"]) # should be zero
75+
@assert length(graph["graph"]) == 0 # should be zero
7676

7777
# edges
7878
s = get.(links, "source", nothing) .+ 1
@@ -101,22 +101,11 @@ function Reddit(; full=true, dir=nothing)
101101
@assert sum(val_mask .& test_mask) == 0
102102
train_mask = nor.(test_mask, val_mask)
103103

104-
train_idx = node_idx[train_mask]
105-
test_idx = node_idx[test_mask]
106-
val_idx = node_idx[val_mask]
107-
108-
split = Dict(
109-
"train" => train_idx,
110-
"test" => test_idx,
111-
"val" => val_idx
112-
)
113-
114104
metadata = Dict{String, Any}("directed" => directed, "multigraph" => multigraph,
115-
"num_graphs" => num_graphs, "num_edges" => num_edges, "num_nodes" => num_nodes,
116-
"split" => split)
105+
"num_edges" => num_edges, "num_nodes" => num_nodes)
117106
g = Graph(; num_nodes,
118107
edge_index=(s, t),
119-
node_data= (; labels, features)
108+
node_data= (; labels, features, train_mask, val_mask, test_mask)
120109
)
121110
return Reddit(metadata, [g])
122111
end

src/datasets/graphs/tudataset.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ end
2121
2222
A variety of graph benchmark datasets, *e.g.* "QM9", "IMDB-BINARY",
2323
"REDDIT-BINARY" or "PROTEINS", collected from the [TU Dortmund University](https://chrsmrrs.github.io/datasets/).
24-
Retrieve from TUDataset collection the dataset `name`, where `name`
24+
Retrieve from the TUDataset collection the dataset `name`, where `name`
2525
is any of the datasets available [here](https://chrsmrrs.github.io/datasets/docs/datasets/).
2626
2727
A `TUDataset` object can be indexed to retrieve a specific graph or a subset of graphs.
@@ -31,16 +31,19 @@ description of the format.
3131
3232
# Usage Example
3333
34-
```julia
35-
using MLDatasets: TUDataset
36-
37-
data = TUDataset("PROTEINS")
38-
39-
# Access first graph
40-
d1 = data[1]
41-
42-
# Node features
43-
X = d1.node_attributes # (nfeatures x nnodes) matrix
34+
```julia-repl
35+
julia> data = TUDataset("PROTEINS")
36+
dataset TUDataset:
37+
name => PROTEINS
38+
metadata => Dict{String, Any} with 1 entry
39+
graphs => 1113-element Vector{MLDatasets.Graph}
40+
graph_data => (targets = "1113-element Vector{Int64}",)
41+
num_nodes => 43471
42+
num_edges => 162088
43+
num_graphs => 1113
44+
45+
julia> data[1]
46+
(graphs = Graph(42, 162), targets = 1)
4447
```
4548
"""
4649
struct TUDataset <: AbstractDataset

src/io.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ end
2020
function read_npz(path)
2121
return NPZ.npzread(path)
2222
end
23+
24+
function read_pytorch(path)
25+
return Pickle.Torch.THload(path)
26+
end

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ function mask2indexes(mask::BitVector)
6060
return (1:n)[mask]
6161
end
6262

63+
maybesqueeze(x) = x
64+
maybesqueeze(x::AbstractMatrix) = size(x, 1) == 1 ? vec(x) : x
6365

6466
"""
6567
convert2image(d, i)

test/datasets/graphs_no_ci.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
77
@test g.num_edges == 114615892
88
@test size(g.node_data.features) == (602, g.num_nodes)
99
@test size(g.node_data.labels) == (g.num_nodes,)
10-
@test size(data.metadata["split"]["train"]) == (153431,)
11-
@test size(data.metadata["split"]["val"]) == (23831,)
12-
@test size(data.metadata["split"]["test"]) == (55703,)
10+
@test count(g.node_data.train_mask) == 153431
11+
@test count(g.node_data.val_mask) == 23831
12+
@test count(g.node_data.test_mask) == 55703
1313
s, t = g.edge_index
1414
@test length(s) == length(t) == g.num_edges
1515
@test minimum(s) == minimum(t) == 1
1616
@test maximum(s) == maximum(t) == g.num_nodes
17-
@test sum(length.(values(data.metadata["split"]))) == g.num_nodes
1817
end
1918

2019
@testset "Reddit_subset" begin
@@ -25,14 +24,13 @@ end
2524
@test g.num_edges == 23213838
2625
@test size(g.node_data.features) == (602, g.num_nodes)
2726
@test size(g.node_data.labels) == (g.num_nodes,)
28-
@test size(data.metadata["split"]["train"]) == (152410,)
29-
@test size(data.metadata["split"]["val"]) == (23699,)
30-
@test size(data.metadata["split"]["test"]) == (55334,)
27+
@test count(g.node_data.train_mask) == 152410
28+
@test count(g.node_data.val_mask) == 23699
29+
@test count(g.node_data.test_mask) == 55334
3130
s, t = g.edge_index
3231
@test length(s) == length(t) == g.num_edges
3332
@test minimum(s) == minimum(t) == 1
3433
@test maximum(s) == maximum(t) == g.num_nodes
35-
@test sum(length.(values(data.metadata["split"]))) == g.num_nodes
3634
end
3735

3836

@@ -107,3 +105,9 @@ end
107105

108106
@test sum(count.([g.node_data.train_mask, g.node_data.test_mask, g.node_data.val_mask])) == g.num_nodes
109107
end
108+
109+
@testset "OGBDataset - ogbg-molhiv" begin
110+
d = OGBDataset("ogbg-molhiv")
111+
112+
@test sum(count.([d.graph_data.train_mask, d.graph_data.test_mask, d.graph_data.val_mask])) == length(d)
113+
end

0 commit comments

Comments
 (0)