Skip to content

Commit f70fa97

Browse files
tweak ogbdataset
1 parent 9face8b commit f70fa97

File tree

2 files changed

+51
-10
lines changed

2 files changed

+51
-10
lines changed

src/datasets/graphs/ogbdataset.jl

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ available for node prediction, edge prediction, or graph prediction tasks.
5050
## Node prediction tasks
5151
5252
```julia-repl
53-
julia> d = OGBDataset("ogbn-arxiv")
53+
julia> data = OGBDataset("ogbn-arxiv")
5454
dataset OGBDataset:
5555
name => ogbn-arxiv
5656
metadata => Dict{String, Any} with 16 entries
@@ -142,21 +142,20 @@ julia> labels
142142
0
143143
```
144144
"""
145-
struct OGBDataset{L} <: AbstractDataset
145+
struct OGBDataset{GD} <: AbstractDataset
146146
name::String
147147
metadata::Dict{String, Any}
148148
graphs::Vector{Graph}
149-
targets::L
150-
split_idx::NamedTuple
149+
graph_data::GD
151150
end
152151

153152
function OGBDataset(fullname; dir = nothing)
154153
metadata = read_ogb_metadata(fullname, dir)
155154
path = makedir_ogb(fullname, metadata["url"], dir)
156155
metadata["path"] = path
157-
graph_dicts, labels, split_idx = read_ogb_graph(path, metadata)
156+
graph_dicts, graph_data = read_ogb_graph(path, metadata)
158157
graphs = ogbdict2graph.(graph_dicts)
159-
return OGBDataset(fullname, metadata, graphs, labels, split_idx)
158+
return OGBDataset(fullname, metadata, graphs, graph_data)
160159
end
161160

162161
function read_ogb_metadata(fullname, dir = nothing)
@@ -175,6 +174,13 @@ function read_ogb_metadata(fullname, dir = nothing)
175174
df = read_csv(path_metadata)
176175
@assert fullname ∈ names(df)
177176
metadata = Dict{String, Any}(String(r[1]) => parse_pystring(r[2]) for r in eachrow(df[!,[names(df)[1], fullname]]))
177+
if prefix == "ogbn"
178+
metadata["task level"] = "node"
179+
elseif prefix == "ogbl"
180+
metadata["task level"] = "link"
181+
elseif prefix == "ogbg"
182+
metadata["task level"] = "graph"
183+
end
178184
return metadata
179185
end
180186

@@ -316,7 +322,40 @@ function read_ogb_graph(path, metadata)
316322
if split_idx.test !== nothing
317323
split_idx.test .+= 1
318324
end
319-
return graphs, labels, split_idx
325+
326+
327+
graph_data = nothing
328+
if metadata["task level"] == "node"
329+
@assert length(graphs) == 1
330+
g = graphs[1]
331+
if split_idx.train !== nothing
332+
g["node_train_mask"] = indexes2mask(split_idx.train, g["num_nodes"])
333+
end
334+
if split_idx.val !== nothing
335+
g["node_val_mask"] = indexes2mask(split_idx.val, g["num_nodes"])
336+
end
337+
if split_idx.test !== nothing
338+
g["node_test_mask"] = indexes2mask(split_idx.test, g["num_nodes"])
339+
end
340+
341+
end
342+
if metadata["task level"] == "link"
343+
@assert length(graphs) == 1
344+
g = graphs[1]
345+
if split_idx.train !== nothing
346+
g["edge_train_mask"] = indexes2mask(split_idx.train, g["num_edges"])
347+
end
348+
if split_idx.val !== nothing
349+
g["edge_val_mask"] = indexes2mask(split_idx.val, g["num_edges"])
350+
end
351+
if split_idx.test !== nothing
352+
g["edge_test_mask"] = indexes2mask(split_idx.test, g["num_edges"])
353+
end
354+
end
355+
if metadata["task level"] == "graph"
356+
graph_data = (; labels, split_idx)
357+
end
358+
return graphs, graph_data
320359
end
321360

322361
function read_ogb_file(p, T; tovec = false, transp = true)
@@ -346,7 +385,7 @@ function ogbdict2graph(d::Dict)
346385
end
347386

348387
Base.length(data::OGBDataset) = length(data.graphs)
349-
Base.getindex(data::OGBDataset{Nothing}, ::Colon) = data.graphs
350-
Base.getindex(data::OGBDataset, ::Colon) = (; data.graphs, data.targets)
388+
Base.getindex(data::OGBDataset{Nothing}, ::Colon) = length(data.graphs) == 1 ? data.graphs[1] : data.graphs
389+
Base.getindex(data::OGBDataset, ::Colon) = (; data.graphs, targets=data.graph_data.labels)
351390
Base.getindex(data::OGBDataset{Nothing}, i) = getobs(data.graphs, i)
352-
Base.getindex(data::OGBDataset, i) = getobs((; data.graphs, data.targets), i)
391+
Base.getindex(data::OGBDataset, i) = getobs((; data.graphs, targets=data.graph_data.labels), i)

test/datasets/graphs_no_ci.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,6 @@ end
104104
g = d[:]
105105
@test g.num_nodes == 169343
106106
@test g.num_edges == 1166243
107+
108+
@test sum(count.([g.node_data.train_mask, g.node_data.test_mask, g.node_data.val_mask])) == g.num_nodes
107109
end

0 commit comments

Comments
 (0)