Skip to content

Commit a30bfab

Browse files
committed
Add CachedDataset
1 parent 3d48893 commit a30bfab

File tree

8 files changed

+79
-7
lines changed

8 files changed

+79
-7
lines changed

src/MLDatasets.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ using JLD2
1616

1717
import MLUtils
1818
using MLUtils: getobs, numobs, AbstractDataContainer
19-
20-
export FileDataset, TableDataset, HDF5Dataset, JLD2Dataset
2119
export getobs, numobs
2220

2321
# Julia 1.0 compatibility
@@ -49,9 +47,15 @@ end
4947
include("download.jl")
5048

5149
include("containers/filedataset.jl")
50+
export FileDataset
5251
include("containers/tabledataset.jl")
52+
export TableDataset
5353
include("containers/hdf5dataset.jl")
54+
export HDF5Dataset
5455
include("containers/jld2dataset.jl")
56+
export JLD2Dataset
57+
include("containers/cacheddataset.jl")
58+
export CachedDataset
5559

5660
# Misc.
5761
include("BostonHousing/BostonHousing.jl")

src/containers/cacheddataset.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""
2+
make_cache(source, cacheidx)
3+
4+
Return a in-memory copy of `source` at observation indices `cacheidx`.
5+
Defaults to `getobs(source, cacheidx)`.
6+
"""
7+
make_cache(source, cacheidx) = getobs(source, cacheidx)
8+
9+
"""
10+
CachedDataset(source, cachesize = numbobs(source))
11+
CachedDataset(source, cacheidx, cache)
12+
13+
Wrap a `source` data container and cache `cachesize` samples in memory.
14+
This can be useful for improving read speeds when `source` is a lazy data container,
15+
but your system memory is large enough to store a sizeable chunk of it.
16+
17+
By default the observation indices `1:cachesize` are cached.
18+
You can manually pass in a `cache` and set of `cacheidx` as well.
19+
20+
See also [`make_cache`](@ref) for customizing the default cache creation for `source`.
21+
"""
22+
struct CachedDataset{T, S}
23+
source::T
24+
cacheidx::Vector{Int}
25+
cache::S
26+
end
27+
28+
function CachedDataset(source, cachesize::Int = numobs(source))
29+
cacheidx = 1:cachesize
30+
31+
CachedDataset(source, collect(cacheidx), make_cache(source, cacheidx))
32+
end
33+
34+
function MLUtils.getobs(dataset::CachedDataset, i::Integer)
35+
_i = findfirst(==(i), dataset.cacheidx)
36+
37+
return isnothing(_i) ? getobs(dataset.source, i) : getobs(dataset.cache, _i)
38+
end
39+
MLUtils.numobs(dataset::CachedDataset) = numobs(dataset.source)

src/containers/filedataset.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,6 @@ FileDataset(loadfn,
5252
FileDataset(dir::Union{AbstractPath, AbstractString}, pattern::AbstractString = "*", depth = 4) =
5353
FileDataset(loadfile, dir, pattern, depth)
5454

55-
MLUtils.getobs(dataset::FileDataset, i) = loadfile(dataset.paths[i])
55+
MLUtils.getobs(dataset::FileDataset, i::Integer) = loadfile(dataset.paths[i])
56+
MLUtils.getobs(dataset::FileDataset, is::AbstractVector) = map(Base.Fix1(getobs, dataset), is)
5657
MLUtils.numobs(dataset::FileDataset) = length(dataset.paths)

test/containers/cacheddataset.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
@testset "CachedDataset" begin
2+
@testset "CachedDataset(::FileDataset)" begin
3+
files = setup_filedataset_test()
4+
fdataset = FileDataset("root", "*.csv")
5+
cdataset = CachedDataset(fdataset)
6+
7+
@test numobs(cdataset) == numobs(fdataset)
8+
@test cdataset.cache == getobs(fdataset, 1:numobs(fdataset))
9+
@test all(getobs(cdataset, i) == getobs(fdataset, i) for i in 1:numobs(fdataset))
10+
11+
cleanup_filedataset_test()
12+
end
13+
14+
@testset "CachedDataset(::HDF5Dataset)" begin
15+
paths, datas = setup_hdf5dataset_test()
16+
hdataset = HDF5Dataset("test.h5", ["d1"])
17+
cdataset = CachedDataset(hdataset, 5)
18+
19+
@test numobs(cdataset) == numobs(hdataset)
20+
@test cdataset.cache == getobs(hdataset, 1:5)
21+
@test all(getobs(cdataset, i) == getobs(hdataset, i) for i in 1:10)
22+
23+
cleanup_hdf5dataset_test()
24+
end
25+
end

test/containers/filedataset.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ function setup_filedataset_test()
2323

2424
return files
2525
end
26+
cleanup_filedataset_test() = rm("root"; recursive = true)
2627

2728
@testset "FileDataset" begin
2829
files = setup_filedataset_test()
@@ -32,5 +33,5 @@ end
3233
true_obs = MLDatasets.loadfile(file)
3334
@test getobs(dataset, i) == true_obs
3435
end
35-
rm("root"; recursive = true)
36+
cleanup_filedataset_test()
3637
end

test/containers/hdf5dataset.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ function setup_hdf5dataset_test()
1414

1515
return first.(datasets), last.(datasets)
1616
end
17+
cleanup_hdf5dataset_test() = rm("test.h5")
1718

1819
@testset "HDF5Dataset" begin
1920
paths, datas = setup_hdf5dataset_test()
@@ -24,5 +25,5 @@ end
2425
end
2526
@test numobs(dataset) == 10
2627
close(dataset)
27-
rm("test.h5")
28+
cleanup_hdf5dataset_test()
2829
end

test/containers/jld2dataset.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ function setup_jld2dataset_test()
1414

1515
return first.(datasets), Tuple(last.(datasets))
1616
end
17+
cleanup_jld2dataset_test() = rm("test.jld2")
1718

1819
@testset "JLD2Dataset" begin
1920
paths, datas = setup_jld2dataset_test()
@@ -23,5 +24,5 @@ end
2324
end
2425
@test numobs(dataset) == 10
2526
close(dataset)
26-
rm("test.jld2")
27+
cleanup_jld2dataset_test()
2728
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using Test
22
using MLDatasets
33
using ImageCore
44
using DataDeps
5-
using MLUtils: getobs, numobs
65
using DataFrames, CSV, Tables
76
using HDF5
87
using JLD2
@@ -34,6 +33,7 @@ container_tests = [
3433
"containers/tabledataset.jl",
3534
"containers/hdf5dataset.jl",
3635
"containers/jld2dataset.jl",
36+
"containers/cacheddataset.jl",
3737
]
3838

3939
@testset "Datasets" for t in dataset_tests

0 commit comments

Comments
 (0)