
Commit b7ba9c4

Special case single path HDF5 and JLD2 datasets and add @inferred tests
1 parent a30bfab commit b7ba9c4

7 files changed: +92 -45 lines changed

src/containers/hdf5dataset.jl

Lines changed: 19 additions & 15 deletions
@@ -6,9 +6,9 @@ end

 """
     HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths)
-    HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset})
-    HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString})
-    HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes)
+    HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}})
+    HDF5Dataset(fid::HDF5.File, paths::Union{AbstractString, Vector{<:AbstractString}})
+    HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}, shapes)

 Wrap several HDF5 datasets (`paths`) as a single dataset container.
 Each dataset `p` in `paths` should be accessible as `fid[p]`.
@@ -19,34 +19,38 @@ See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer.
 For array datasets, the last dimension is assumed to be the observation dimension.
 For scalar datasets, the stored value is returned by `getobs` for any index.
 """
-struct HDF5Dataset <: AbstractDataContainer
+struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer
     fid::HDF5.File
-    paths::Vector{HDF5.Dataset}
+    paths::T
     shapes::Vector{Tuple}

-    function HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes::Vector)
+    function HDF5Dataset(fid::HDF5.File, paths::T, shapes::Vector) where T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}
         _check_hdf5_shapes(shapes) ||
             throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations."))

-        new(fid, paths, shapes)
+        new{T}(fid, paths, shapes)
     end
 end

+HDF5Dataset(fid::HDF5.File, path::HDF5.Dataset) = HDF5Dataset(fid, path, [size(path)])
 HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) =
     HDF5Dataset(fid, paths, map(size, paths))
+HDF5Dataset(fid::HDF5.File, path::AbstractString) = HDF5Dataset(fid, fid[path])
 HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) =
     HDF5Dataset(fid, map(p -> fid[p], paths))
 HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths) =
     HDF5Dataset(h5open(file, "r"), paths)

-MLUtils.getobs(dataset::HDF5Dataset, i) = Tuple(map(dataset.paths, dataset.shapes) do path, shape
-    if isempty(shape)
-        return read(path)
-    else
-        I = map(s -> 1:s, shape[1:(end - 1)])
-        return path[I..., i]
-    end
-end)
+_getobs_hdf5(dataset::HDF5.Dataset, ::Tuple{}, i) = read(dataset)
+function _getobs_hdf5(dataset::HDF5.Dataset, shape, i)
+    I = map(s -> 1:s, shape[1:(end - 1)])
+
+    return dataset[I..., i]
+end
+MLUtils.getobs(dataset::HDF5Dataset{HDF5.Dataset}, i) =
+    _getobs_hdf5(dataset.paths, only(dataset.shapes), i)
+MLUtils.getobs(dataset::HDF5Dataset{<:Vector}, i) =
+    Tuple(map((p, s) -> _getobs_hdf5(p, s, i), dataset.paths, dataset.shapes))
 MLUtils.numobs(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes)))

 """

src/containers/jld2dataset.jl

Lines changed: 12 additions & 9 deletions
@@ -2,32 +2,35 @@ _check_jld2_nobs(nobs) = all(==(first(nobs)), nobs[2:end])

 """
     JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths)
-    JLD2Dataset(fid::JLD2.JLDFile, paths)
+    JLD2Dataset(fid::JLD2.JLDFile, paths::Union{String, Vector{String}})

 Wrap several JLD2 datasets (`paths`) as a single dataset container.
 Each dataset `p` in `paths` should be accessible as `fid[p]`.
 Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on
 each dataset in `paths`.
 See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer.
 """
-struct JLD2Dataset{T<:JLD2.JLDFile} <: AbstractDataContainer
+struct JLD2Dataset{T<:JLD2.JLDFile, S<:Tuple} <: AbstractDataContainer
     fid::T
-    paths::Vector{String}
+    paths::S

-    function JLD2Dataset(fid::JLD2.JLDFile, paths::Vector{String})
-        nobs = map(p -> numobs(fid[p]), paths)
+    function JLD2Dataset(fid::JLD2.JLDFile, paths)
+        _paths = Tuple(map(p -> fid[p], paths))
+        nobs = map(numobs, _paths)
         _check_jld2_nobs(nobs) ||
-            throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations."))
+            throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations (got $nobs)."))

-        new{typeof(fid)}(fid, paths)
+        new{typeof(fid), typeof(_paths)}(fid, _paths)
     end
 end

+JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,))
 JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths) =
     JLD2Dataset(jldopen(file, "r"), paths)

-MLUtils.getobs(dataset::JLD2Dataset, i) = Tuple(map(p -> getobs(dataset.fid[p], i), dataset.paths))
-MLUtils.numobs(dataset::JLD2Dataset) = numobs(dataset.fid[dataset.paths[1]])
+MLUtils.getobs(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i)
+MLUtils.getobs(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths)
+MLUtils.numobs(dataset::JLD2Dataset) = numobs(dataset.paths[1])

 """
     close(dataset::JLD2Dataset)
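
JLD2Dataset now resolves its paths eagerly into a tuple of in-file datasets at construction time, and a lone String path gets its own NTuple{1} method, so getobs returns the bare observation for a single path and a tuple otherwise. A minimal usage sketch (file name and keys are made up, and it assumes JLD2Dataset is the one provided by MLDatasets.jl, used together with JLD2.jl and MLUtils.jl):

using JLD2, MLUtils
using MLDatasets: JLD2Dataset   # assumed home of JLD2Dataset

jldsave("example.jld2"; d1 = rand(2, 10), d2 = rand(10))  # hypothetical contents

dsingle = JLD2Dataset("example.jld2", "d1")
getobs(dsingle, 3)             # 2-element Vector{Float64}, no wrapping Tuple
close(dsingle)

dmulti = JLD2Dataset("example.jld2", ["d1", "d2"])
x1, x2 = getobs(dmulti, 3)     # 2-tuple: one observation per stored path
numobs(dmulti)                 # 10
close(dmulti)

rm("example.jld2")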

src/containers/tabledataset.jl

Lines changed: 9 additions & 5 deletions
@@ -23,14 +23,18 @@ TableDataset(path::Union{AbstractPath, AbstractString}) =
     TableDataset(DataFrame(CSV.File(path)))

 # slow accesses based on Tables.jl
+_getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1)))
+function _getobs_column(x, i)
+    colnames = Tuple(Tables.columnnames(x))
+    rowvals = ntuple(j -> Tables.getcolumn(x, j)[i], length(colnames))
+
+    return NamedTuple{colnames}(rowvals)
+end
 function MLUtils.getobs(dataset::TableDataset, i)
     if Tables.rowaccess(dataset.table)
-        row, _ = Iterators.peel(Iterators.drop(Tables.rows(dataset.table), i - 1))
-        return row
+        return _getobs_row(Tables.rows(dataset.table), i)
     elseif Tables.columnaccess(dataset.table)
-        colnames = Tables.columnnames(dataset.table)
-        rowvals = [Tables.getcolumn(dataset.table, j)[i] for j in 1:length(colnames)]
-        return (; zip(colnames, rowvals)...)
+        return _getobs_column(dataset.table, i)
     else
         error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
     end
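
The refactor splits the row and column access paths into small helpers; the column path now materializes the names as a Tuple and the values with ntuple, then calls the NamedTuple{colnames} constructor directly instead of splatting a zip over a Vector. A standalone sketch of that construction, outside the package (the helper name is made up, and Tables.table is only used to get a small column-access table):

using Tables

tbl = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])  # columns Column1..Column3

# Same construction as _getobs_column above, written as a standalone helper.
function row_as_namedtuple(x, i)
    colnames = Tuple(Tables.columnnames(x))
    rowvals = ntuple(j -> Tables.getcolumn(x, j)[i], length(colnames))
    return NamedTuple{colnames}(rowvals)
end

row_as_namedtuple(tbl, 2)   # (Column1 = 2, Column2 = 5.0, Column3 = "8")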

test/containers/cacheddataset.jl

Lines changed: 17 additions & 1 deletion
@@ -13,13 +13,29 @@

     @testset "CachedDataset(::HDF5Dataset)" begin
         paths, datas = setup_hdf5dataset_test()
-        hdataset = HDF5Dataset("test.h5", ["d1"])
+        hdataset = HDF5Dataset("test.h5", "d1")
         cdataset = CachedDataset(hdataset, 5)

         @test numobs(cdataset) == numobs(hdataset)
+        @test cdataset.cache isa Matrix{Float64}
         @test cdataset.cache == getobs(hdataset, 1:5)
         @test all(getobs(cdataset, i) == getobs(hdataset, i) for i in 1:10)

+        close(hdataset)
         cleanup_hdf5dataset_test()
     end
+
+    @testset "CachedDataset(::JLD2Dataset)" begin
+        paths, datas = setup_jld2dataset_test()
+        jdataset = JLD2Dataset("test.jld2", "d1")
+        cdataset = CachedDataset(jdataset, 5)
+
+        @test numobs(cdataset) == numobs(jdataset)
+        @test cdataset.cache isa Matrix{Float64}
+        @test cdataset.cache == getobs(jdataset, 1:5)
+        @test all(@inferred(getobs(cdataset, i)) == getobs(jdataset, i) for i in 1:10)
+
+        close(jdataset)
+        cleanup_jld2dataset_test()
+    end
 end

test/containers/hdf5dataset.jl

Lines changed: 16 additions & 6 deletions
@@ -18,12 +18,22 @@ cleanup_hdf5dataset_test() = rm("test.h5")

 @testset "HDF5Dataset" begin
     paths, datas = setup_hdf5dataset_test()
-    dataset = HDF5Dataset("test.h5", paths)
-    for i in 1:10
-        data = Tuple(map(x -> (x isa String) ? x : getobs(x, i), datas))
-        @test getobs(dataset, i) == data
+    @testset "Single path" begin
+        dataset = HDF5Dataset("test.h5", "d1")
+        for i in 1:10
+            @test getobs(dataset, i) == getobs(datas[1], i)
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
+    end
+    @testset "Multiple paths" begin
+        dataset = HDF5Dataset("test.h5", paths)
+        for i in 1:10
+            data = Tuple(map(x -> (x isa String) ? x : getobs(x, i), datas))
+            @test @inferred(Tuple, getobs(dataset, i)) == data
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
     end
-    @test numobs(dataset) == 10
-    close(dataset)
     cleanup_hdf5dataset_test()
 end
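
The @inferred checks added here and in the other test files come from the Test standard library: plain @inferred expr fails unless the compiler infers exactly the concrete type of the returned value, while the two-argument form @inferred(AllowedType, expr) also accepts a result whose type is a subtype of AllowedType (or whose inferred type matches the actual one once AllowedType is split out of a Union). A tiny standalone illustration, unrelated to the dataset containers (function names are made up):

using Test

f_stable(x) = (x, 2x)                        # inferred as Tuple{Int, Int}
f_union(x)  = x > 0 ? (x, 2x) : (x, "neg")   # inferred as a Union of two tuple types

@test @inferred(f_stable(1)) == (1, 2)       # passes: concrete return type inferred
@test @inferred(Tuple, f_union(1)) == (1, 2) # passes: result is a Tuple, so the allowed type applies
# @inferred f_union(1)                       # would error: inferred type is a Union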

test/containers/jld2dataset.jl

Lines changed: 15 additions & 5 deletions
@@ -18,11 +18,21 @@ cleanup_jld2dataset_test() = rm("test.jld2")

 @testset "JLD2Dataset" begin
     paths, datas = setup_jld2dataset_test()
-    dataset = JLD2Dataset("test.jld2", paths)
-    for i in 1:10
-        @test getobs(dataset, i) == getobs(datas, i)
+    @testset "Single path" begin
+        dataset = JLD2Dataset("test.jld2", "d1")
+        for i in 1:10
+            @test @inferred(getobs(dataset, i)) == getobs(datas[1], i)
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
+    end
+    @testset "Multiple paths" begin
+        dataset = JLD2Dataset("test.jld2", paths)
+        for i in 1:10
+            @test @inferred(getobs(dataset, i)) == getobs(datas, i)
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
     end
-    @test numobs(dataset) == 10
-    close(dataset)
     cleanup_jld2dataset_test()
 end

test/containers/tabledataset.jl

Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,7 @@
         testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
         td = TableDataset(testtable)

-        @test all(getobs(td, 1) .== [1, 4.0, "7"])
+        @test collect(@inferred(getobs(td, 1))) == [1, 4.0, "7"]
         @test numobs(td) == 3
     end

@@ -17,7 +17,7 @@
         testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
         td = TableDataset(testtable)

-        @test [data for data in getobs(td, 2)] == [2, 5.0, "8"]
+        @test collect(@inferred(NamedTuple, getobs(td, 2))) == [2, 5.0, "8"]
         @test numobs(td) == 3

         @test getobs(td, 1) isa NamedTuple
@@ -35,7 +35,7 @@
         td = TableDataset(testtable)
         @test td isa TableDataset{<:DataFrame}

-        @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"]
+        @test collect(@inferred(getobs(td, 1))) == [1, "a", 10, "A", 100.0, "train"]
         @test numobs(td) == 5
     end

@@ -46,7 +46,7 @@
         testtable = CSV.File("test.csv")
         td = TableDataset(testtable)
         @test td isa TableDataset{<:CSV.File}
-        @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"]
+        @test collect(@inferred(getobs(td, 1))) == [1, "a", 10, "A", 100.0, "train"]
         @test numobs(td) == 1
         rm("test.csv")
     end
