Skip to content

Commit 92e0d06

Browse files
committed
Add JLD2Dataset
1 parent 7bba084 commit 92e0d06

File tree

6 files changed

+72
-3
lines changed

6 files changed

+72
-3
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
1515
GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63"
1616
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
1717
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
18+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1819
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
1920
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
2021
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"

src/MLDatasets.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@ using FilePathsBase
1212
using FilePathsBase: AbstractPath
1313
using Glob
1414
using HDF5
15+
using JLD2
1516

1617
import MLUtils
1718
using MLUtils: getobs, numobs, AbstractDataContainer
1819

19-
export FileDataset, TableDataset, HDF5Dataset
20+
export FileDataset, TableDataset, HDF5Dataset, JLD2Dataset
2021
export getobs, numobs
2122

2223
# Julia 1.0 compatibility
@@ -50,6 +51,7 @@ include("download.jl")
5051
include("containers/filedataset.jl")
5152
include("containers/tabledataset.jl")
5253
include("containers/hdf5dataset.jl")
54+
include("containers/jld2dataset.jl")
5355

5456
# Misc.
5557
include("BostonHousing/BostonHousing.jl")

src/containers/hdf5dataset.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer.
1919
For array datasets, the last dimension is assumed to be the observation dimension.
2020
For scalar datasets, the stored value is returned by `getobs` for any index.
2121
"""
22-
struct HDF5Dataset
22+
struct HDF5Dataset <: AbstractDataContainer
2323
fid::HDF5.File
2424
paths::Vector{HDF5.Dataset}
2525
shapes::Vector{Tuple}
2626

2727
function HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes::Vector)
2828
_check_hdf5_shapes(shapes) ||
29-
throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatch number of observations."))
29+
throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations."))
3030

3131
new(fid, paths, shapes)
3232
end

src/containers/jld2dataset.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
_check_jld2_nobs(nobs) = all(==(first(nobs)), nobs[2:end])
2+
3+
"""
4+
JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths)
5+
JLD2Dataset(fid::JLD2.JLDFile, paths)
6+
7+
Wrap several JLD2 datasets (`paths`) as a single dataset container.
8+
Each dataset `p` in `paths` should be accessible as `fid[p]`.
9+
Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on
10+
each dataset in `paths`.
11+
See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer.
12+
"""
13+
struct JLD2Dataset{T<:JLD2.JLDFile} <: AbstractDataContainer
14+
fid::T
15+
paths::Vector{String}
16+
17+
function JLD2Dataset(fid::JLD2.JLDFile, paths::Vector{String})
18+
nobs = map(p -> numobs(fid[p]), paths)
19+
_check_jld2_nobs(nobs) ||
20+
throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations."))
21+
22+
new{typeof(fid)}(fid, paths)
23+
end
24+
end
25+
26+
JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths) =
27+
JLD2Dataset(jldopen(file, "r"), paths)
28+
29+
MLUtils.getobs(dataset::JLD2Dataset, i) = Tuple(map(p -> getobs(dataset.fid[p], i), dataset.paths))
30+
MLUtils.numobs(dataset::JLD2Dataset) = numobs(dataset.fid[dataset.paths[1]])
31+
32+
"""
33+
close(dataset::JLD2Dataset)
34+
35+
Close the underlying JLD2 file pointer for `dataset`.
36+
"""
37+
Base.close(dataset::JLD2Dataset) = close(dataset.fid)

test/containers/jld2dataset.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
function setup_jld2dataset_test()
2+
datasets = [
3+
("d1", rand(2, 10)),
4+
("g1/d1", rand(10)),
5+
("g1/d2", string.('a':'j')),
6+
("g2/g1/d1", rand(Bool, 3, 3, 10))
7+
]
8+
9+
fid = jldopen("test.jld2", "w")
10+
for (path, data) in datasets
11+
fid[path] = data
12+
end
13+
close(fid)
14+
15+
return first.(datasets), Tuple(last.(datasets))
16+
end
17+
18+
@testset "JLD2Dataset" begin
19+
paths, datas = setup_jld2dataset_test()
20+
dataset = JLD2Dataset("test.jld2", paths)
21+
for i in 1:10
22+
@test getobs(dataset, i) == getobs(datas, i)
23+
end
24+
@test numobs(dataset) == 10
25+
close(dataset)
26+
rm("test.jld2")
27+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using DataDeps
55
using MLUtils: getobs, numobs
66
using DataFrames, CSV, Tables
77
using HDF5
8+
using JLD2
89

910
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
1011

@@ -32,6 +33,7 @@ container_tests = [
3233
"containers/filedataset.jl",
3334
"containers/tabledataset.jl",
3435
"containers/hdf5dataset.jl",
36+
"containers/jld2dataset.jl",
3537
]
3638

3739
@testset "Datasets" for t in dataset_tests

0 commit comments

Comments
 (0)