@@ -53,93 +53,61 @@ available for node prediction, edge prediction, or graph prediction tasks.
53
53
julia> data = OGBDataset("ogbn-arxiv")
54
54
dataset OGBDataset:
55
55
name => ogbn-arxiv
56
- metadata => Dict{String, Any} with 16 entries
56
+ metadata => Dict{String, Any} with 17 entries
57
57
graphs => 1-element Vector{MLDatasets.Graph}
58
- targets => nothing
59
- split_idx => (train = "90941-element Vector{Int64}", val = "29799-element Vector{Int64}", test = "48603-element Vector{Int64}")
58
+ graph_data => nothing
60
59
61
60
julia> data[:]
62
61
Graph:
63
62
num_nodes => 169343
64
63
num_edges => 1166243
65
64
edge_index => ("1166243-element Vector{Int64}", "1166243-element Vector{Int64}")
66
- node_data => (year = "1× 169343 Matrix {Int64}", features = "128×169343 Matrix{Float32}", label = "1× 169343 Matrix {Int64}")
65
+ node_data => (val_mask = "29799-trues BitVector", test_mask = "48603-trues BitVector", year = "169343-element Vector{Int64}", features = "128×169343 Matrix{Float32}", label = "169343-element Vector{Int64}", train_mask = "90941-trues BitVector")
67
66
edge_data => nothing
68
67
69
68
julia> data.metadata
70
- Dict{String, Any} with 16 entries:
69
+ Dict{String, Any} with 17 entries:
71
70
"download_name" => "arxiv"
72
71
"num classes" => 40
73
72
"num tasks" => 1
74
73
"binary" => false
75
74
"url" => "http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip"
76
75
"additional node files" => "node_year"
77
76
"is hetero" => false
78
- "path" => "/home/carlo/.julia/datadeps/OGBDataset/arxiv"
79
- "eval metric" => "acc"
80
- "task type" => "multiclass classification"
81
- "add_inverse_edge" => false
82
- "has_node_attr" => true
83
- "additional edge files" => nothing
84
- "version" => 1
85
- "has_edge_attr" => false
86
- "split" => "time"
77
+ "task level" => "node"
78
+ ⋮ => ⋮
87
79
```
88
80
89
81
## Edge prediction task
90
82
91
83
```julia-repl
92
84
julia> data = OGBDataset("ogbl-collab")
93
- OGBDataset{Nothing}:
94
- name => ogbl-collab
95
- path => /home/carlo/.julia/datadeps/OGBDataset/collab
96
- metadata => Dict{String, Any} with 13 entries
97
- graphs => 1-element Vector{Dict}
98
- labels => nothing
99
- split => Dict{String, Any} with 3 entries
100
-
101
- julia> graph = data[1] # no labels for this dataset
102
- Dict{String, Any} with 7 entries:
103
- "edge_index" => [150990 224882; 150990 224882; … ; 221742 135759; 207233 140615]
104
- "edge_feat" => nothing
105
- "node_feat" => Float32[-0.177486 -0.237488 … 0.004236 -0.035025; -0.10298 0.022193 … 0.031942 -0.118059; … ; 0.003879 0.062124 … 0.05208 -0.176961; -0.276317 -0.081464 … -0.201557 -0.258715]
106
- "num_nodes" => 235868
107
- "edge_year" => [2004 2002 … 2006 1984; 2004 2002 … 2006 1984]
108
- "edge_weight" => [2 1 … 1 1; 2 1 … 1 1]
109
- "num_edges" => 2358104
85
+ dataset OGBDataset:
86
+ name => ogbl-collab
87
+ metadata => Dict{String, Any} with 15 entries
88
+ graphs => 1-element Vector{MLDatasets.Graph}
89
+ graph_data => nothing
90
+
91
+ julia> data[:]
92
+ Graph:
93
+ num_nodes => 235868
94
+ num_edges => 2358104
95
+ edge_index => ("2358104-element Vector{Int64}", "2358104-element Vector{Int64}")
96
+ node_data => (features = "128×235868 Matrix{Float32}",)
97
+ edge_data => (year = "2×1179052 Matrix{Int64}", weight = "2×1179052 Matrix{Int64}")
110
98
```
111
99
112
100
## Graph prediction task
113
101
114
102
```julia-repl
115
103
julia> data = OGBDataset("ogbg-molhiv")
116
- OGBDataset{Matrix{Int64}}:
117
- name => ogbg-molhiv
118
- path => /home/carlo/.julia/datadeps/OGBDataset/molhiv
119
- metadata => Dict{String, Any} with 15 entries
120
- graphs => 41127-element Vector{Dict}
121
- labels => 1×41127 Matrix{Int64}
122
- split => Dict{String, Any} with 3 entries
123
-
124
- julia> length(data)
125
- 41127
126
-
127
- julia> graph, labels = data[10]
128
- (Dict{String, Any}("edge_index" => [-202 -201; -201 -200; … ; -198 -184; -201 -202], "node_feat" => Float32[7.0 6.0 … 7.0 7.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], "edge_feat" => Float32[0.0 0.0 … 0.0 1.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 1.0], "num_nodes" => 20, "num_edges" => 42), [0])
129
-
130
- julia> graph, labels = data[10];
131
-
132
- julia> graph
133
- Dict{String, Any} with 5 entries:
134
- "edge_index" => [1 2; 2 3; … ; 5 19; 2 1]
135
- "edge_feat" => Float32[0.0 0.0 … 0.0 1.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 1.0]
136
- "node_feat" => Float32[7.0 6.0 … 7.0 7.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
137
- "num_nodes" => 20
138
- "num_edges" => 42
139
-
140
- julia> labels
141
- 1-element Vector{Int64}:
142
- 0
104
+ dataset OGBDataset:
105
+ name => ogbg-molhiv
106
+ metadata => Dict{String, Any} with 17 entries
107
+ graphs => 41127-element Vector{MLDatasets.Graph}
108
+ graph_data => (labels = "41127-element Vector{Int64}", train_mask = "32901-trues BitVector", val_mask = "4113-trues BitVector", test_mask = "4113-trues BitVector")
109
+
110
+ julia> data[1]
143
111
```
144
112
"""
145
113
struct OGBDataset{GD} <: AbstractDataset
@@ -309,10 +277,13 @@ function read_ogb_graph(path, metadata)
309
277
310
278
splits = readdir (joinpath (path, " split" ))
311
279
@assert length (splits) == 1 # TODO check if datasets with multiple splits exist in OGB
280
+
312
281
# TODO sometimes splits are given in .pt format
282
+ # Use read_pytorch in src/io.jl to load them.
313
283
split_idx = (train = read_ogb_file (joinpath (path, " split" , splits[1 ], " train.csv" ), Int; tovec= true ),
314
284
val = read_ogb_file (joinpath (path, " split" , splits[1 ], " valid.csv" ), Int; tovec= true ),
315
285
test = read_ogb_file (joinpath (path, " split" , splits[1 ], " test.csv" ), Int; tovec= true ))
286
+
316
287
if split_idx. train != = nothing
317
288
split_idx. train .+ = 1
318
289
end
@@ -353,7 +324,11 @@ function read_ogb_graph(path, metadata)
353
324
end
354
325
end
355
326
if metadata[" task level" ] == " graph"
356
- graph_data = (; labels, split_idx)
327
+ train_mask = split_idx. train != = nothing ? indexes2mask (split_idx. train, num_graphs) : nothing
328
+ val_mask = split_idx. val != = nothing ? indexes2mask (split_idx. val, num_graphs) : nothing
329
+ test_mask = split_idx. test != = nothing ? indexes2mask (split_idx. test, num_graphs) : nothing
330
+
331
+ graph_data = clean_nt ((; labels= maybesqueeze (labels), train_mask, val_mask, test_mask))
357
332
end
358
333
return graphs, graph_data
359
334
end
@@ -377,15 +352,15 @@ end
377
352
function ogbdict2graph (d:: Dict )
378
353
edge_index = d[" edge_index" ][:,1 ], d[" edge_index" ][:,2 ]
379
354
num_nodes = d[" num_nodes" ]
380
- node_data = Dict (Symbol (k[6 : end ]) => v for (k,v) in d if startswith (k, " node_" ) && v != = nothing )
381
- edge_data = Dict (Symbol (k[6 : end ]) => v for (k,v) in d if startswith (k, " edge_" ) && k!= " edge_index" && v != = nothing )
355
+ node_data = Dict (Symbol (k[6 : end ]) => maybesqueeze (v) for (k,v) in d if startswith (k, " node_" ) && v != = nothing )
356
+ edge_data = Dict (Symbol (k[6 : end ]) => maybesqueeze (v) for (k,v) in d if startswith (k, " edge_" ) && k!= " edge_index" && v != = nothing )
382
357
node_data = isempty (node_data) ? nothing : (; node_data... )
383
358
edge_data = isempty (edge_data) ? nothing : (; edge_data... )
384
359
return Graph (; num_nodes, edge_index, node_data, edge_data)
385
360
end
386
361
387
362
Base. length (data:: OGBDataset ) = length (data. graphs)
388
363
Base. getindex (data:: OGBDataset{Nothing} , :: Colon ) = length (data. graphs) == 1 ? data. graphs[1 ] : data. graphs
389
- Base. getindex (data:: OGBDataset , :: Colon ) = (; data. graphs, targets = data. graph_data. labels)
364
+ Base. getindex (data:: OGBDataset , :: Colon ) = (; data. graphs, data. graph_data. labels)
390
365
Base. getindex (data:: OGBDataset{Nothing} , i) = getobs (data. graphs, i)
391
- Base. getindex (data:: OGBDataset , i) = getobs ((; data. graphs, targets = data. graph_data. labels), i)
366
+ Base. getindex (data:: OGBDataset , i) = getobs ((; data. graphs, data. graph_data. labels), i)
0 commit comments