@@ -50,7 +50,7 @@ available for node prediction, edge prediction, or graph prediction tasks.
50
50
## Node prediction tasks
51
51
52
52
```julia-repl
53
- julia> d = OGBDataset("ogbn-arxiv")
53
+ julia> data = OGBDataset("ogbn-arxiv")
54
54
dataset OGBDataset:
55
55
name => ogbn-arxiv
56
56
metadata => Dict{String, Any} with 16 entries
@@ -142,21 +142,20 @@ julia> labels
142
142
0
143
143
```
144
144
"""
145
- struct OGBDataset{L } <: AbstractDataset
145
+ struct OGBDataset{GD } <: AbstractDataset
146
146
name:: String
147
147
metadata:: Dict{String, Any}
148
148
graphs:: Vector{Graph}
149
- targets:: L
150
- split_idx:: NamedTuple
149
+ graph_data:: GD
151
150
end
152
151
153
152
function OGBDataset (fullname; dir = nothing )
154
153
metadata = read_ogb_metadata (fullname, dir)
155
154
path = makedir_ogb (fullname, metadata[" url" ], dir)
156
155
metadata[" path" ] = path
157
- graph_dicts, labels, split_idx = read_ogb_graph (path, metadata)
156
+ graph_dicts, graph_data = read_ogb_graph (path, metadata)
158
157
graphs = ogbdict2graph .(graph_dicts)
159
- return OGBDataset (fullname, metadata, graphs, labels, split_idx )
158
+ return OGBDataset (fullname, metadata, graphs, graph_data )
160
159
end
161
160
162
161
function read_ogb_metadata (fullname, dir = nothing )
@@ -175,6 +174,13 @@ function read_ogb_metadata(fullname, dir = nothing)
175
174
df = read_csv (path_metadata)
176
175
@assert fullname ∈ names (df)
177
176
metadata = Dict {String, Any} (String (r[1 ]) => parse_pystring (r[2 ]) for r in eachrow (df[! ,[names (df)[1 ], fullname]]))
177
+ if prefix == " ogbn"
178
+ metadata[" task level" ] = " node"
179
+ elseif prefix == " ogbl"
180
+ metadata[" task level" ] = " link"
181
+ elseif prefix == " ogbg"
182
+ metadata[" task level" ] = " graph"
183
+ end
178
184
return metadata
179
185
end
180
186
@@ -316,7 +322,40 @@ function read_ogb_graph(path, metadata)
316
322
if split_idx. test != = nothing
317
323
split_idx. test .+ = 1
318
324
end
319
- return graphs, labels, split_idx
325
+
326
+
327
+ graph_data = nothing
328
+ if metadata[" task level" ] == " node"
329
+ @assert length (graphs) == 1
330
+ g = graphs[1 ]
331
+ if split_idx. train != = nothing
332
+ g[" node_train_mask" ] = indexes2mask (split_idx. train, g[" num_nodes" ])
333
+ end
334
+ if split_idx. val != = nothing
335
+ g[" node_val_mask" ] = indexes2mask (split_idx. val, g[" num_nodes" ])
336
+ end
337
+ if split_idx. test != = nothing
338
+ g[" node_test_mask" ] = indexes2mask (split_idx. test, g[" num_nodes" ])
339
+ end
340
+
341
+ end
342
+ if metadata[" task level" ] == " link"
343
+ @assert length (graphs) == 1
344
+ g = graphs[1 ]
345
+ if split_idx. train != = nothing
346
+ g[" edge_train_mask" ] = indexes2mask (split_idx. train, g[" num_edges" ])
347
+ end
348
+ if split_idx. val != = nothing
349
+ g[" edge_val_mask" ] = indexes2mask (split_idx. val, g[" num_edges" ])
350
+ end
351
+ if split_idx. test != = nothing
352
+ g[" edge_test_mask" ] = indexes2mask (split_idx. test, g[" num_edges" ])
353
+ end
354
+ end
355
+ if metadata[" task level" ] == " graph"
356
+ graph_data = (; labels, split_idx)
357
+ end
358
+ return graphs, graph_data
320
359
end
321
360
322
361
function read_ogb_file (p, T; tovec = false , transp = true )
@@ -346,7 +385,7 @@ function ogbdict2graph(d::Dict)
346
385
end
347
386
348
387
Base. length (data:: OGBDataset ) = length (data. graphs)
349
- Base. getindex (data:: OGBDataset{Nothing} , :: Colon ) = data. graphs
350
- Base. getindex (data:: OGBDataset , :: Colon ) = (; data. graphs, data. targets )
388
+ Base. getindex (data:: OGBDataset{Nothing} , :: Colon ) = length (data . graphs) == 1 ? data . graphs[ 1 ] : data. graphs
389
+ Base. getindex (data:: OGBDataset , :: Colon ) = (; data. graphs, targets = data. graph_data . labels )
351
390
Base. getindex (data:: OGBDataset{Nothing} , i) = getobs (data. graphs, i)
352
- Base. getindex (data:: OGBDataset , i) = getobs ((; data. graphs, data. targets ), i)
391
+ Base. getindex (data:: OGBDataset , i) = getobs ((; data. graphs, targets = data. graph_data . labels ), i)
0 commit comments