Skip to content

Commit d22eaac

Browse files
Merge pull request #124 from JuliaML/cl/docs4
change dataset[] to dataset[:]
2 parents d660236 + e4d0c24 commit d22eaac

31 files changed

+333
-245
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Its functionality is built on top of the package
1616
## Available Datasets
1717

1818
**Warning**: this package is under heavy redesign. The link belows point to the documentation for the yet to be released version.
19-
For the tagged version instead, please consult the [stable docs](https://JuliaML.github.io/MLDatasets.jl/stable)
19+
For the tagged version instead, please consult the [stable docs](https://JuliaML.github.io/MLDatasets.jl/stable).
2020

2121
Datasets are grouped into different categories. Click on the links below for a full list of datasets available in each category.
2222

docs/src/datasets/graphs.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Graphs Datasets
22

3+
A collection of datasets with an underlying graph structure.
4+
Some of these datasets contain a single graph, that can be accessed
5+
with `dataset[:]` or `dataset[1]`. Others contain many graphs,
6+
accessed through `dataset[i]`. Graphs are represented by the [`MLDatasets.Graph`](@ref) type.
7+
38
## Index
49

510
```@index

docs/src/index.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Where possible, those types share a common interface (fields and methods).
3838

3939
Once a dataset has been instantiated, e.g. by `dataset = MNIST()`,
4040
an observation `i` can be retrieved using the indexing syntax `dataset[i]`.
41-
By indexing with no arguments, `dataset[]`, the whole set of observations is collected.
41+
By indexing with no arguments, `dataset[:]`, the whole set of observations is collected.
4242
The total number of observations is given by `length(dataset)`.
4343

4444
For example you can load the training set of the [`MNIST`](@ref)
@@ -60,17 +60,17 @@ julia> trainset[1] # return first observation as a NamedTuple
6060
(features = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0],
6161
targets = 5)
6262
63-
julia> X_train, y_train = trainset[] # return all observations
63+
julia> X_train, y_train = trainset[:] # return all observations
6464
(features = [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; … ;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0],
6565
targets = [5, 0, 4, 1, 9, 2, 1, 3, 1, 4 … 9, 2, 9, 5, 1, 8, 3, 5, 6, 8])
6666
6767
julia> summary(X_train)
6868
"28×28×60000 Array{Float32, 3}"
6969
```
7070

71-
Input features are commonly denoted by `features`, while classification labels or regression targets are denoted by `targets`.
71+
Input features are commonly denoted by `features`, while classification labels and regression targets are denoted by `targets`.
7272

73-
```julia
73+
```julia-repl
7474
julia> iris = Iris()
7575
dataset Iris:
7676
metadata => Dict{String, Any} with 4 entries

src/abstract_datasets.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
Super-type from which all datasets in MLDatasets.jl inherit.
55
66
Implements the following functionality:
7-
- `getobs(d)` and `getobs(d, i)` falling back to `d[]` and `d[i]`
7+
- `getobs(d)` and `getobs(d, i)` falling back to `d[:]` and `d[i]`
88
- Pretty printing.
99
"""
1010
abstract type AbstractDataset <: AbstractDataContainer end
1111

1212

13-
MLUtils.getobs(d::AbstractDataset) = d[]
13+
MLUtils.getobs(d::AbstractDataset) = d[:]
1414
MLUtils.getobs(d::AbstractDataset, i) = d[i]
1515

1616
function Base.show(io::IO, d::D) where D <: AbstractDataset
@@ -45,7 +45,8 @@ end
4545
_summary(x) = x
4646
_summary(x::Symbol) = ":$x"
4747
_summary(x::Union{Dict, AbstractArray, DataFrame}) = summary(x)
48-
_summary(x::Union{Tuple, NamedTuple}) = map(summary, x)
48+
_summary(x::Union{Tuple, NamedTuple}) = map(_summary, x)
49+
_summary(x::BitVector) = "$(count(x))-trues BitVector"
4950

5051
"""
5152
SupervisedDataset <: AbstractDataset
@@ -57,11 +58,10 @@ a `features` and a `targets` fields.
5758
abstract type SupervisedDataset <: AbstractDataset end
5859

5960

60-
6161
Base.length(d::SupervisedDataset) = numobs((d.features, d.targets))
6262

6363
# We return named tuples
64-
Base.getindex(d::SupervisedDataset) = getobs((; d.features, d.targets))
64+
Base.getindex(d::SupervisedDataset, ::Colon) = getobs((; d.features, d.targets))
6565
Base.getindex(d::SupervisedDataset, i) = getobs((; d.features, d.targets), i)
6666

6767
"""
@@ -75,7 +75,7 @@ abstract type UnsupervisedDataset <: AbstractDataset end
7575

7676
Base.length(d::UnsupervisedDataset) = numobs(d.features)
7777

78-
Base.getindex(d::UnsupervisedDataset) = getobs(d.features)
78+
Base.getindex(d::UnsupervisedDataset, ::Colon) = getobs(d.features)
7979
Base.getindex(d::UnsupervisedDataset, i) = getobs(d.features, i)
8080

8181

@@ -97,7 +97,7 @@ const FIELDS_SUPERVISED_TABLE = """
9797

9898
const METHODS_SUPERVISED_TABLE = """
9999
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets.
100-
- `dataset[]`: Return all observations as a named tuple of features and targets.
100+
- `dataset[:]`: Return all observations as a named tuple of features and targets.
101101
- `length(dataset)`: Number of observations.
102102
"""
103103

@@ -116,6 +116,6 @@ const FIELDS_SUPERVISED_ARRAY = """
116116

117117
const METHODS_SUPERVISED_ARRAY = """
118118
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets.
119-
- `dataset[]`: Return all observations as a named tuple of features and targets.
119+
- `dataset[:]`: Return all observations as a named tuple of features and targets.
120120
- `length(dataset)`: Number of observations.
121121
"""

src/datasets/graphs/citeseer.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ The dataset is retrieved from Ref. [2].
2727
2828
# References
2929
30-
[1]: [Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via Ranking](https://arxiv.org/abs/1707.03815)
30+
[1]: [Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via Ranking](https://arxiv.org/abs/1707.03815)
31+
3132
[2]: [Planetoid](https://github.com/kimiyoung/planetoid)
3233
"""
3334
struct CiteSeer <: AbstractDataset
@@ -41,9 +42,8 @@ function CiteSeer(; dir=nothing, reverse_edges=true)
4142
end
4243

4344
Base.length(d::CiteSeer) = length(d.graphs)
44-
Base.getindex(d::CiteSeer) = d.graphs[1]
45-
Base.getindex(d::CiteSeer, i) = getindex(d.graphs, i)
46-
45+
Base.getindex(d::CiteSeer, ::Colon) = d.graphs[1]
46+
Base.getindex(d::CiteSeer, i) = d.graphs[i]
4747

4848

4949
# DEPRECATED in v0.6.0

src/datasets/graphs/cora.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ doesn't consider all nodes.
4646
# References
4747
4848
[1]: [Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via Ranking](https://arxiv.org/abs/1707.03815)
49+
4950
[2]: [Planetoid](https://github.com/kimiyoung/planetoid
5051
"""
5152
struct Cora <: AbstractDataset
@@ -59,7 +60,7 @@ function Cora(; dir=nothing, reverse_edges=true)
5960
end
6061

6162
Base.length(d::Cora) = length(d.graphs)
62-
Base.getindex(d::Cora) = d.graphs[1]
63+
Base.getindex(d::Cora, ::Colon) = d.graphs[1]
6364
Base.getindex(d::Cora, i) = getindex(d.graphs, i)
6465

6566

src/datasets/graphs/karateclub.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
export KarateClub
22

33
"""
4-
Zachary's Karate Club
4+
KarateClub()
55
6-
The Karate Club dataset originally appeared in Ref [1].
6+
The Zachary's karate club dataset originally appeared in Ref [1].
77
88
The network contains 34 nodes (members of the karate club).
99
The nodes are connected by 78 undirected and unweighted edges.
@@ -18,8 +18,11 @@ One node per unique label is used as training data.
1818
# References
1919
2020
[1]: [An Information Flow Model for Conflict and Fission in Small Groups](http://www1.ind.ku.dk/complexLearning/zachary1977.pdf)
21+
2122
[2]: [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907)
23+
2224
[3]: [PyTorch Geometric Karate Club Dataset](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/karate.html#KarateClub)
25+
2326
[4]: [NetworkX Zachary's Karate Club Dataset](https://networkx.org/documentation/stable/_modules/networkx/generators/social.html#karate_club_graph)
2427
"""
2528
struct KarateClub
@@ -59,12 +62,12 @@ function KarateClub()
5962
0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0]
6063

6164
node_data = (; labels_clubs, labels_comm)
62-
g = Graph(; num_nodes=34, num_edges=156, edge_index=(src, target), node_data)
65+
g = Graph(; num_nodes=34, edge_index=(src, target), node_data)
6366

6467
metadata = Dict{String, Any}()
6568
return KarateClub(metadata, [g])
6669
end
6770

6871
Base.length(d::KarateClub) = length(d.graphs)
69-
Base.getindex(d::KarateClub) = d.graphs[1]
70-
Base.getindex(d::KarateClub, i) = getindex(d.graphs, i)
72+
Base.getindex(d::KarateClub, ::Colon) = d.graphs[1]
73+
Base.getindex(d::KarateClub, i) = d.graphs[i]

0 commit comments

Comments
 (0)