Commit 42cc48f

Merge pull request #96 from darsnack/low-level-api:
Redesign package to be built on top of reusable dataset containers

2 parents: f45ef65 + 2ff0d29

16 files changed: +538 / -8 lines

.github/workflows/UnitTest.yml

Lines changed: 2 additions & 2 deletions

```diff
@@ -16,12 +16,12 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        julia-version: ['1.3', '1', 'nightly']
+        julia-version: ['1.6', '1', 'nightly']
         os: [ubuntu-latest, windows-latest, macOS-latest, macos-11]
     env:
       PYTHON: ""
     steps:
-      - uses: actions/checkout@v1.0.0
+      - uses: actions/checkout@v2
       - name: "Set up Julia"
         uses: julia-actions/setup-julia@v1
         with:
```

Project.toml

Lines changed: 17 additions & 1 deletion

```diff
@@ -4,29 +4,45 @@ version = "0.5.15"
 
 [deps]
 BinDeps = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
 DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
+FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
 FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
 GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63"
+Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
+HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
+JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
 BinDeps = "1"
+CSV = "0.10.2"
 ColorTypes = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11"
 DataDeps = "0.3, 0.4, 0.5, 0.6, 0.7"
+DataFrames = "1.3"
+FileIO = "1.13"
 FixedPointNumbers = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8"
 GZip = "0.5"
+Glob = "1.3"
+HDF5 = "0.16.2"
 ImageCore = "0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8"
+JLD2 = "0.4.21"
 JSON3 = "1"
 MAT = "0.7, 0.8, 0.9, 0.10"
+MLUtils = "0.2.0"
 Pickle = "0.2, 0.3"
 Requires = "1"
-julia = "1.3"
+Tables = "1.6"
+julia = "1.6"
 
 [extras]
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
```

docs/make.jl

Lines changed: 3 additions & 2 deletions

```diff
@@ -37,7 +37,6 @@ makedocs(
             "Mutagenesis" => "datasets/Mutagenesis.md",
             "Titanic" => "datasets/Titanic.md",
         ],
-
         "Text" => Any[
             "PTBLM" => "datasets/PTBLM.md",
             "UD_English" => "datasets/UD_English.md",
@@ -52,9 +51,11 @@ makedocs(
 
         ],
         "Utils" => "utils.md",
+        "Data Containers" => "containers/overview.md",
         "LICENSE.md",
     ],
-    strict = true
+    strict = true,
+    checkdocs = :exports
 )
```
docs/src/containers/overview.md

Lines changed: 14 additions & 0 deletions (new file)

````markdown
# Dataset Containers

MLDatasets.jl contains several reusable data containers for accessing datasets in common storage formats. This feature is a work-in-progress and subject to change.

```@docs
FileDataset
TableDataset
HDF5Dataset
Base.close(::HDF5Dataset)
JLD2Dataset
Base.close(::JLD2Dataset)
CachedDataset
MLDatasets.make_cache
```
````

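All of these containers implement `Base.getindex` and `Base.length`, so the exported `getobs`/`numobs` apply to them uniformly. A minimal sketch of that contract, using a toy container that is not part of this commit and assuming MLUtils' generic fallbacks of `getobs` to `getindex` and `numobs` to `length`:

```julia
using MLDatasets  # re-exports getobs and numobs from MLUtils
using MLUtils: AbstractDataContainer

# Toy container: observation i is simply i squared.
struct Squares <: AbstractDataContainer
    n::Int
end
Base.getindex(d::Squares, i::Integer) = i^2
Base.length(d::Squares) = d.n

ds = Squares(10)
numobs(ds)     # 10, via the length fallback
getobs(ds, 3)  # 9, via the getindex fallback
```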
src/MLDatasets.jl

Lines changed: 19 additions & 0 deletions

```diff
@@ -7,6 +7,15 @@ using DelimitedFiles: readdlm
 using FixedPointNumbers, ColorTypes
 using Pickle
 using SparseArrays
+using FileIO
+using DataFrames, CSV, Tables
+using Glob
+using HDF5
+using JLD2
+
+import MLUtils
+using MLUtils: getobs, numobs, AbstractDataContainer
+export getobs, numobs
 
 # Julia 1.0 compatibility
 if !isdefined(Base, :isnothing)
@@ -36,6 +45,16 @@ end
 
 include("download.jl")
 
+include("containers/filedataset.jl")
+export FileDataset
+include("containers/tabledataset.jl")
+export TableDataset
+include("containers/hdf5dataset.jl")
+export HDF5Dataset
+include("containers/jld2dataset.jl")
+export JLD2Dataset
+include("containers/cacheddataset.jl")
+export CachedDataset
 
 # Misc.
 include("BostonHousing/BostonHousing.jl")
```

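Since `getobs` and `numobs` come from MLUtils and are now re-exported, they also work on plain arrays, where MLUtils treats the last dimension as the observation dimension. A quick sketch:

```julia
using MLDatasets

X = rand(Float32, 28, 28, 100)  # 100 image-like observations
numobs(X)           # 100
size(getobs(X, 1))  # (28, 28): indexes along the last dimension
```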
src/containers/cacheddataset.jl

Lines changed: 38 additions & 0 deletions (new file)

```julia
"""
    make_cache(source, cacheidx)

Return an in-memory copy of `source` at observation indices `cacheidx`.
Defaults to `getobs(source, cacheidx)`.
"""
make_cache(source, cacheidx) = getobs(source, cacheidx)

"""
    CachedDataset(source, cachesize = numobs(source))
    CachedDataset(source, cacheidx = 1:numobs(source))
    CachedDataset(source, cacheidx, cache)

Wrap a `source` data container and cache `cachesize` samples in memory.
This can be useful for improving read speeds when `source` is a lazy data container,
but your system memory is large enough to store a sizeable chunk of it.

By default the observation indices `1:cachesize` are cached.
You can manually pass in a set of `cacheidx` as well.

See also [`make_cache`](@ref) for customizing the default cache creation for `source`.
"""
struct CachedDataset{T, S}
    source::T
    cacheidx::Vector{Int}
    cache::S
end

CachedDataset(source, cacheidx::AbstractVector{<:Integer} = 1:numobs(source)) =
    CachedDataset(source, collect(cacheidx), make_cache(source, cacheidx))
CachedDataset(source, cachesize::Int = numobs(source)) = CachedDataset(source, 1:cachesize)

function Base.getindex(dataset::CachedDataset, i::Integer)
    _i = findfirst(==(i), dataset.cacheidx)

    return isnothing(_i) ? getobs(dataset.source, i) : getobs(dataset.cache, _i)
end
Base.length(dataset::CachedDataset) = numobs(dataset.source)
```

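A usage sketch for `CachedDataset` (the directory and file layout here are hypothetical): wrapping a lazy container keeps the chosen indices in memory, while out-of-cache indices fall back to the source.

```julia
using MLDatasets

# Hypothetical: a lazy FileDataset over image files on disk.
files = FileDataset("data/images", "*.png")
cached = CachedDataset(files, 1_000)  # cache observations 1:1000 in memory

getobs(cached, 1)      # served from the in-memory cache
getobs(cached, 5_000)  # cache miss: loaded from disk via `files`
```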
src/containers/filedataset.jl

Lines changed: 35 additions & 0 deletions (new file)

```julia
"""
    rglob(filepattern, dir = pwd(), depth = 4)

Recursive glob up to `depth` layers deep within `dir`.
"""
function rglob(filepattern = "*", dir = pwd(), depth = 4)
    patterns = [repeat("*/", i) * filepattern for i in 0:(depth - 1)]

    return vcat([glob(pattern, dir) for pattern in patterns]...)
end

"""
    FileDataset([loadfn = FileIO.load,] paths)
    FileDataset([loadfn = FileIO.load,] dir, pattern = "*", depth = 4)

Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`).
Alternatively, specify a `dir` and collect all paths that match a glob `pattern`
(recursively globbing by `depth`). The glob order determines the traversal order.
"""
struct FileDataset{F, T<:AbstractString} <: AbstractDataContainer
    loadfn::F
    paths::Vector{T}
end

FileDataset(paths) = FileDataset(FileIO.load, paths)
FileDataset(loadfn,
            dir::AbstractString,
            pattern::AbstractString = "*",
            depth = 4) = FileDataset(loadfn, rglob(pattern, string(dir), depth))
FileDataset(dir::AbstractString, pattern::AbstractString = "*", depth = 4) =
    FileDataset(FileIO.load, dir, pattern, depth)

Base.getindex(dataset::FileDataset, i::Integer) = dataset.loadfn(dataset.paths[i])
Base.getindex(dataset::FileDataset, is::AbstractVector) = map(Base.Fix1(getobs, dataset), is)
Base.length(dataset::FileDataset) = length(dataset.paths)
```

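A sketch of the two construction styles (the paths are hypothetical). `FileIO.load` infers the format from the file extension; any unary loader can be swapped in.

```julia
using MLDatasets

# From a directory plus glob pattern (recursive up to the default depth of 4):
ds = FileDataset("data/images", "*.png")
numobs(ds)           # number of matched files
img = getobs(ds, 1)  # decoded via FileIO.load

# With a custom loader, e.g. raw bytes:
raw = FileDataset(read, "data/blobs", "*.bin")
getobs(raw, 1)       # a Vector{UInt8}
```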
src/containers/hdf5dataset.jl

Lines changed: 60 additions & 0 deletions (new file)

```julia
function _check_hdf5_shapes(shapes)
    nobs = map(last, filter(!isempty, shapes))

    return all(==(first(nobs)), nobs[2:end])
end

"""
    HDF5Dataset(file::AbstractString, paths)
    HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}})
    HDF5Dataset(fid::HDF5.File, paths::Union{AbstractString, Vector{<:AbstractString}})
    HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}, shapes)

Wrap several HDF5 datasets (`paths`) as a single dataset container.
Each dataset `p` in `paths` should be accessible as `fid[p]`.
Calling `getobs` on an `HDF5Dataset` returns a tuple with each element corresponding
to the observation from each dataset in `paths`.
See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer.

For array datasets, the last dimension is assumed to be the observation dimension.
For scalar datasets, the stored value is returned by `getobs` for any index.
"""
struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer
    fid::HDF5.File
    paths::T
    shapes::Vector{Tuple}

    function HDF5Dataset(fid::HDF5.File, paths::T, shapes::Vector) where T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}
        _check_hdf5_shapes(shapes) ||
            throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations."))

        new{T}(fid, paths, shapes)
    end
end

HDF5Dataset(fid::HDF5.File, path::HDF5.Dataset) = HDF5Dataset(fid, path, [size(path)])
HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) =
    HDF5Dataset(fid, paths, map(size, paths))
HDF5Dataset(fid::HDF5.File, path::AbstractString) = HDF5Dataset(fid, fid[path])
HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) =
    HDF5Dataset(fid, map(p -> fid[p], paths))
HDF5Dataset(file::AbstractString, paths) = HDF5Dataset(h5open(file, "r"), paths)

_getobs_hdf5(dataset::HDF5.Dataset, ::Tuple{}, i) = read(dataset)
function _getobs_hdf5(dataset::HDF5.Dataset, shape, i)
    I = map(s -> 1:s, shape[1:(end - 1)])

    return dataset[I..., i]
end
Base.getindex(dataset::HDF5Dataset{HDF5.Dataset}, i) =
    _getobs_hdf5(dataset.paths, only(dataset.shapes), i)
Base.getindex(dataset::HDF5Dataset{<:Vector}, i) =
    Tuple(map((p, s) -> _getobs_hdf5(p, s, i), dataset.paths, dataset.shapes))
Base.length(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes)))

"""
    close(dataset::HDF5Dataset)

Close the underlying HDF5 file pointer for `dataset`.
"""
Base.close(dataset::HDF5Dataset) = close(dataset.fid)
```

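A usage sketch, assuming a hypothetical `data.h5` containing a `d × n` array dataset "features" and a length-`n` dataset "targets":

```julia
using MLDatasets

ds = HDF5Dataset("data.h5", ["features", "targets"])
x, y = getobs(ds, 1)  # the tuple (features[:, 1], targets[1])
numobs(ds)            # n, read off the last dimension of the shapes
close(ds)             # releases the underlying HDF5.File handle
```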
src/containers/jld2dataset.jl

Lines changed: 39 additions & 0 deletions (new file)

```julia
_check_jld2_nobs(nobs) = all(==(first(nobs)), nobs[2:end])

"""
    JLD2Dataset(file::AbstractString, paths)
    JLD2Dataset(fid::JLD2.JLDFile, paths::Union{String, Vector{String}})

Wrap several JLD2 datasets (`paths`) as a single dataset container.
Each dataset `p` in `paths` should be accessible as `fid[p]`.
Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on
each dataset in `paths`.
See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer.
"""
struct JLD2Dataset{T<:JLD2.JLDFile, S<:Tuple} <: AbstractDataContainer
    fid::T
    paths::S

    function JLD2Dataset(fid::JLD2.JLDFile, paths)
        _paths = Tuple(map(p -> fid[p], paths))
        nobs = map(numobs, _paths)
        _check_jld2_nobs(nobs) ||
            throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations (got $nobs)."))

        new{typeof(fid), typeof(_paths)}(fid, _paths)
    end
end

JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,))
JLD2Dataset(file::AbstractString, paths) = JLD2Dataset(jldopen(file, "r"), paths)

Base.getindex(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i)
Base.getindex(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths)
Base.length(dataset::JLD2Dataset) = numobs(dataset.paths[1])

"""
    close(dataset::JLD2Dataset)

Close the underlying JLD2 file pointer for `dataset`.
"""
Base.close(dataset::JLD2Dataset) = close(dataset.fid)
```

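A usage sketch with a hypothetical JLD2 file holding two arrays with a shared observation count:

```julia
using MLDatasets
using JLD2

# Hypothetical file, written e.g. with:
#   jldsave("data.jld2"; features = rand(2, 100), targets = rand(100))
ds = JLD2Dataset("data.jld2", ["features", "targets"])
x, y = getobs(ds, 1)  # one observation from each stored array
numobs(ds)            # 100
close(ds)             # releases the underlying JLD2 file handle
```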
src/containers/tabledataset.jl

Lines changed: 65 additions & 0 deletions (new file)

```julia
"""
    TableDataset(table)
    TableDataset(path::AbstractString)

Wrap a Tables.jl-compatible `table` as a dataset container.
Alternatively, specify the `path` to a CSV file directly
to load it with CSV.jl + DataFrames.jl.
"""
struct TableDataset{T} <: AbstractDataContainer
    table::T

    # TableDatasets must implement the Tables.jl interface
    function TableDataset{T}(table::T) where {T}
        Tables.istable(table) ||
            throw(ArgumentError("TableDatasets must implement the Tables.jl interface"))

        new{T}(table)
    end
end

TableDataset(table::T) where {T} = TableDataset{T}(table)
TableDataset(path::AbstractString) = TableDataset(DataFrame(CSV.File(path)))

# slow accesses based on Tables.jl
_getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1)))
function _getobs_column(x, i)
    colnames = Tuple(Tables.columnnames(x))
    rowvals = ntuple(j -> Tables.getcolumn(x, j)[i], length(colnames))

    return NamedTuple{colnames}(rowvals)
end
function Base.getindex(dataset::TableDataset, i)
    if Tables.rowaccess(dataset.table)
        return _getobs_row(Tables.rows(dataset.table), i)
    elseif Tables.columnaccess(dataset.table)
        return _getobs_column(dataset.table, i)
    else
        error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
    end
end
function Base.length(dataset::TableDataset)
    if Tables.columnaccess(dataset.table)
        return length(Tables.getcolumn(dataset.table, 1))
    elseif Tables.rowaccess(dataset.table)
        # length might not be defined, but has to be for this to work.
        return length(Tables.rows(dataset.table))
    else
        error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
    end
end

# fast access for DataFrame
Base.getindex(dataset::TableDataset{<:DataFrame}, i) = dataset.table[i, :]
Base.length(dataset::TableDataset{<:DataFrame}) = nrow(dataset.table)

# fast access for CSV.File
Base.getindex(dataset::TableDataset{<:CSV.File}, i) = dataset.table[i]
Base.length(dataset::TableDataset{<:CSV.File}) = length(dataset.table)

## Tables.jl interface

Tables.istable(::TableDataset) = true
for fn in (:rowaccess, :rows, :columnaccess, :columns, :schema, :materializer)
    @eval Tables.$fn(dataset::TableDataset) = Tables.$fn(dataset.table)
end
```

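A usage sketch: with a `DataFrame`, the fast path returns a `DataFrameRow`; the CSV path shown is hypothetical.

```julia
using MLDatasets
using DataFrames

df = DataFrame(a = 1:4, b = ["w", "x", "y", "z"])
ds = TableDataset(df)
numobs(ds)     # 4, via the nrow fast path
getobs(ds, 2)  # row 2 as a DataFrameRow

# Or point directly at a CSV file (hypothetical path):
# ds = TableDataset("data/train.csv")  # loaded with CSV.jl into a DataFrame
```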