Skip to content

Commit 9face8b

Browse files
[] -> [:]
1 parent 5bcae79 commit 9face8b

26 files changed

+106
-117
lines changed

docs/src/datasets/graphs.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
A collection of datasets with an underlying graph structure.
44
Some of these datasets contain a single graph, that can be accessed
5-
with `dataset[]` or `dataset[1]`. Others contain many graphs,
5+
with `dataset[:]` or `dataset[1]`. Others contain many graphs,
66
accessed through `dataset[i]`. Graphs are represented by the [`MLDatasets.Graph`](@ref) type.
77

88
## Index

docs/src/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Where possible, those types share a common interface (fields and methods).
3838

3939
Once a dataset has been instantiated, e.g. by `dataset = MNIST()`,
4040
an observation `i` can be retrieved using the indexing syntax `dataset[i]`.
41-
By indexing with no arguments, `dataset[]`, the whole set of observations is collected.
41+
By indexing with no arguments, `dataset[:]`, the whole set of observations is collected.
4242
The total number of observations is given by `length(dataset)`.
4343

4444
For example you can load the training set of the [`MNIST`](@ref)
@@ -60,7 +60,7 @@ julia> trainset[1] # return first observation as a NamedTuple
6060
(features = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0],
6161
targets = 5)
6262
63-
julia> X_train, y_train = trainset[] # return all observations
63+
julia> X_train, y_train = trainset[:] # return all observations
6464
(features = [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; … ;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0],
6565
targets = [5, 0, 4, 1, 9, 2, 1, 3, 1, 4 … 9, 2, 9, 5, 1, 8, 3, 5, 6, 8])
6666

src/abstract_datasets.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
Super-type from which all datasets in MLDatasets.jl inherit.
55
66
Implements the following functionality:
7-
- `getobs(d)` and `getobs(d, i)` falling back to `d[]` and `d[i]`
7+
- `getobs(d)` and `getobs(d, i)` falling back to `d[:]` and `d[i]`
88
- Pretty printing.
99
"""
1010
abstract type AbstractDataset <: AbstractDataContainer end
1111

1212

13-
MLUtils.getobs(d::AbstractDataset) = d[]
13+
MLUtils.getobs(d::AbstractDataset) = d[:]
1414
MLUtils.getobs(d::AbstractDataset, i) = d[i]
1515

1616
function Base.show(io::IO, d::D) where D <: AbstractDataset
@@ -58,11 +58,10 @@ a `features` and a `targets` fields.
5858
abstract type SupervisedDataset <: AbstractDataset end
5959

6060

61-
6261
Base.length(d::SupervisedDataset) = numobs((d.features, d.targets))
6362

6463
# We return named tuples
65-
Base.getindex(d::SupervisedDataset) = getobs((; d.features, d.targets))
64+
Base.getindex(d::SupervisedDataset, ::Colon) = getobs((; d.features, d.targets))
6665
Base.getindex(d::SupervisedDataset, i) = getobs((; d.features, d.targets), i)
6766

6867
"""
@@ -76,7 +75,7 @@ abstract type UnsupervisedDataset <: AbstractDataset end
7675

7776
Base.length(d::UnsupervisedDataset) = numobs(d.features)
7877

79-
Base.getindex(d::UnsupervisedDataset) = getobs(d.features)
78+
Base.getindex(d::UnsupervisedDataset, ::Colon) = getobs(d.features)
8079
Base.getindex(d::UnsupervisedDataset, i) = getobs(d.features, i)
8180

8281

@@ -98,7 +97,7 @@ const FIELDS_SUPERVISED_TABLE = """
9897

9998
const METHODS_SUPERVISED_TABLE = """
10099
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets.
101-
- `dataset[]`: Return all observations as a named tuple of features and targets.
100+
- `dataset[:]`: Return all observations as a named tuple of features and targets.
102101
- `length(dataset)`: Number of observations.
103102
"""
104103

@@ -117,6 +116,6 @@ const FIELDS_SUPERVISED_ARRAY = """
117116

118117
const METHODS_SUPERVISED_ARRAY = """
119118
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets.
120-
- `dataset[]`: Return all observations as a named tuple of features and targets.
119+
- `dataset[:]`: Return all observations as a named tuple of features and targets.
121120
- `length(dataset)`: Number of observations.
122121
"""

src/datasets/graphs/citeseer.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,8 @@ function CiteSeer(; dir=nothing, reverse_edges=true)
4242
end
4343

4444
Base.length(d::CiteSeer) = length(d.graphs)
45-
Base.getindex(d::CiteSeer) = d.graphs[1]
46-
Base.getindex(d::CiteSeer, i) = getindex(d.graphs, i)
47-
45+
Base.getindex(d::CiteSeer, ::Colon) = d.graphs[1]
46+
Base.getindex(d::CiteSeer, i) = d.graphs[i]
4847

4948

5049
# DEPRECATED in v0.6.0

src/datasets/graphs/cora.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ function Cora(; dir=nothing, reverse_edges=true)
6060
end
6161

6262
Base.length(d::Cora) = length(d.graphs)
63-
Base.getindex(d::Cora) = d.graphs[1]
63+
Base.getindex(d::Cora, ::Colon) = d.graphs[1]
6464
Base.getindex(d::Cora, i) = getindex(d.graphs, i)
6565

6666

src/datasets/graphs/karateclub.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
export KarateClub
22

33
"""
4-
Zachary's Karate Club
4+
KarateClub()
55
6-
The Karate Club dataset originally appeared in Ref [1].
6+
The Zachary's karate club dataset originally appeared in Ref [1].
77
88
The network contains 34 nodes (members of the karate club).
99
The nodes are connected by 78 undirected and unweighted edges.
@@ -69,5 +69,5 @@ function KarateClub()
6969
end
7070

7171
Base.length(d::KarateClub) = length(d.graphs)
72-
Base.getindex(d::KarateClub) = d.graphs[1]
73-
Base.getindex(d::KarateClub, i) = getindex(d.graphs, i)
72+
Base.getindex(d::KarateClub, ::Colon) = d.graphs[1]
73+
Base.getindex(d::KarateClub, i) = d.graphs[i]

src/datasets/graphs/ogbdataset.jl

Lines changed: 28 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -42,67 +42,48 @@ end
4242
The collection of datasets from the [Open Graph Benchmark: Datasets for Machine Learning on Graphs](https://arxiv.org/abs/2005.00687)
4343
paper.
4444
45-
`name` is the name of one of the dasets (listed [here](https://ogb.stanford.edu/docs/dataset_overview/))
45+
`name` is the name of one of the datasets (listed [here](https://ogb.stanford.edu/docs/dataset_overview/))
4646
available for node prediction, edge prediction, or graph prediction tasks.
4747
48-
The `OGBDataset` type stores the graphs internally as dictionary objects.
49-
The key "edge_index" contains `2 x num_edges`, where the first and second
50-
column contain the source and target nodes of each edge respectively.
51-
5248
# Examples
5349
5450
## Node prediction tasks
5551
5652
```julia-repl
57-
julia> data = OGBDataset("ogbn-arxiv")
58-
OGBDataset{Vector{Any}}:
59-
name => ogbn-arxiv
60-
path => /home/carlo/.julia/datadeps/OGBDataset/arxiv
61-
metadata => Dict{String, Any} with 15 entries
62-
graphs => 1-element Vector{Dict}
63-
labels => 1-element Vector{Any}
64-
split => Dict{String, Any} with 3 entries
65-
53+
julia> d = OGBDataset("ogbn-arxiv")
54+
dataset OGBDataset:
55+
name => ogbn-arxiv
56+
metadata => Dict{String, Any} with 16 entries
57+
graphs => 1-element Vector{MLDatasets.Graph}
58+
targets => nothing
59+
split_idx => (train = "90941-element Vector{Int64}", val = "29799-element Vector{Int64}", test = "48603-element Vector{Int64}")
60+
61+
julia> data[:]
62+
Graph:
63+
num_nodes => 169343
64+
num_edges => 1166243
65+
edge_index => ("1166243-element Vector{Int64}", "1166243-element Vector{Int64}")
66+
node_data => (year = "1×169343 Matrix{Int64}", features = "128×169343 Matrix{Float32}", label = "1×169343 Matrix{Int64}")
67+
edge_data => nothing
6668
6769
julia> data.metadata
68-
Dict{String, Any} with 15 entries:
70+
Dict{String, Any} with 16 entries:
71+
"download_name" => "arxiv"
6972
"num classes" => 40
73+
"num tasks" => 1
7074
"binary" => false
75+
"url" => "http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip"
76+
"additional node files" => "node_year"
7177
"is hetero" => false
78+
"path" => "/home/carlo/.julia/datadeps/OGBDataset/arxiv"
7279
"eval metric" => "acc"
7380
"task type" => "multiclass classification"
74-
"version" => 1
75-
"split" => "time"
76-
"download_name" => "arxiv"
77-
"num tasks" => 1
78-
"url" => "http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip"
79-
"additional node files" => "node_year"
8081
"add_inverse_edge" => false
8182
"has_node_attr" => true
8283
"additional edge files" => nothing
84+
"version" => 1
8385
"has_edge_attr" => false
84-
85-
julia> data.split
86-
Dict{String, Any} with 3 entries:
87-
"test_idx" => [347, 399, 452, 481, 489, 491, 527, 538, 541, 603 … 169334, 169335, 169336, 169337, 169338, 169339, 169340, 169341, 169342, 169343]
88-
"train_idx" => [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 … 169110, 169112, 169113, 169114, 169115, 169116, 169118, 169146, 169149, 169252]
89-
"val_idx" => [350, 358, 367, 383, 394, 422, 430, 436, 468, 470 … 169089, 169096, 169108, 169111, 169128, 169156, 169177, 169186, 169262, 169297]
90-
91-
julia> length(data)
92-
1
93-
94-
julia> graph, labels = data[1];
95-
96-
julia> graph
97-
Dict{String, Any} with 6 entries:
98-
"edge_index" => [104448 13092; 15859 47284; … ; 45119 162538; 45119 72718]
99-
"edge_feat" => nothing
100-
"node_feat" => Float32[-0.057943 -0.1245 … -0.138236 -0.029875; -0.05253 -0.070665 … 0.040885 0.268417; … ; -0.172796 -0.372111 … -0.041253 0.077647; -0.140059 -0.301036 … -0.376132 -0.091018]
101-
"num_nodes" => 169343
102-
"node_year" => [2013 2015 … 2020 2020]
103-
"num_edges" => 1166243
104-
105-
julia> source, target = graph["edge_index][:,1], graph["edge_index][:,2];
86+
"split" => "time"
10687
```
10788
10889
## Edge prediction task
@@ -356,15 +337,16 @@ end
356337

357338
function ogbdict2graph(d::Dict)
358339
edge_index = d["edge_index"][:,1], d["edge_index"][:,2]
359-
num_nodes, num_edges = d["num_nodes"], d["num_edges"]
340+
num_nodes = d["num_nodes"]
360341
node_data = Dict(Symbol(k[6:end]) => v for (k,v) in d if startswith(k, "node_") && v !== nothing)
361342
edge_data = Dict(Symbol(k[6:end]) => v for (k,v) in d if startswith(k, "edge_") && k!="edge_index" && v !== nothing)
362343
node_data = isempty(node_data) ? nothing : (; node_data...)
363344
edge_data = isempty(edge_data) ? nothing : (; edge_data...)
364-
return Graph(; num_nodes, num_edges,
365-
edge_index, node_data, edge_data)
345+
return Graph(; num_nodes, edge_index, node_data, edge_data)
366346
end
367347

368348
Base.length(data::OGBDataset) = length(data.graphs)
349+
Base.getindex(data::OGBDataset{Nothing}, ::Colon) = data.graphs
350+
Base.getindex(data::OGBDataset, ::Colon) = (; data.graphs, data.targets)
369351
Base.getindex(data::OGBDataset{Nothing}, i) = getobs(data.graphs, i)
370352
Base.getindex(data::OGBDataset, i) = getobs((; data.graphs, data.targets), i)

src/datasets/graphs/polblogs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,5 @@ function PolBlogs(; dir=nothing)
5151
end
5252

5353
Base.length(d::PolBlogs) = length(d.graphs)
54-
Base.getindex(d::PolBlogs) = d.graphs[1]
54+
Base.getindex(d::PolBlogs, ::Colon) = d.graphs[1]
5555
Base.getindex(d::PolBlogs, i) = getindex(d.graphs, i)

src/datasets/graphs/pubmed.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function PubMed(; dir=nothing, reverse_edges=true)
4242
end
4343

4444
Base.length(d::PubMed) = length(d.graphs)
45-
Base.getindex(d::PubMed) = d.graphs[1]
45+
Base.getindex(d::PubMed, ::Colon) = d.graphs[1]
4646
Base.getindex(d::PubMed, i) = getindex(d.graphs, i)
4747

4848

src/datasets/graphs/reddit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,5 +122,5 @@ function Reddit(; full=true, dir=nothing)
122122
end
123123

124124
Base.length(d::Reddit) = length(d.graphs)
125-
Base.getindex(d::Reddit) = d.graphs
125+
Base.getindex(d::Reddit, ::Colon) = d.graphs
126126
Base.getindex(d::Reddit, i) = getindex(d.graphs, i)

0 commit comments

Comments
 (0)