Skip to content

Commit 3d48893

Browse files
committed
Support custom loading function in FileDataset
1 parent 92e0d06 commit 3d48893

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/containers/filedataset.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,25 @@ end
3232
# Path objects are converted to their string form and forwarded to the
# string-based `loadfile` method.
function loadfile(file::AbstractPath)
    return loadfile(string(file))
end
3333

3434
"""
35-
FileDataset(paths)
36-
FileDataset(dir, pattern = "*", depth = 4)
35+
FileDataset([loadfn = loadfile,] paths)
36+
FileDataset([loadfn = loadfile,] dir, pattern = "*", depth = 4)
3737
3838
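Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`). The optional first argument `loadfn` (default `loadfile`) is stored as the dataset's file-loading function.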
3939
Alternatively, specify a `dir` and collect all paths that match a glob `pattern`
4040
(recursively globbing by `depth`). The glob order determines the traversal order.
4141
"""
42-
struct FileDataset{T<:Union{AbstractPath, AbstractString}} <: AbstractDataContainer
42+
struct FileDataset{F, T<:Union{AbstractPath, AbstractString}} <: AbstractDataContainer
43+
loadfn::F
4344
paths::Vector{T}
4445
end
4546

46-
FileDataset(dir, pattern = "*", depth = 4) = FileDataset(rglob(pattern, string(dir), depth))
47+
# Convenience constructor: when no loading function is given, fall back to `loadfile`.
function FileDataset(paths)
    return FileDataset(loadfile, paths)
end
48+
FileDataset(loadfn,
49+
dir::Union{AbstractPath, AbstractString},
50+
pattern::AbstractString = "*",
51+
depth = 4) = FileDataset(loadfn, rglob(pattern, string(dir), depth))
52+
FileDataset(dir::Union{AbstractPath, AbstractString}, pattern::AbstractString = "*", depth = 4) =
53+
FileDataset(loadfile, dir, pattern, depth)
4754

4855
# Load observation `i` by applying the dataset's own `loadfn` to the `i`-th path.
# Calling `loadfile` directly here would silently ignore a custom loader supplied
# to the `FileDataset` constructor.
MLUtils.getobs(dataset::FileDataset, i) = dataset.loadfn(dataset.paths[i])
4956
# The number of observations equals the number of stored paths.
function MLUtils.numobs(dataset::FileDataset)
    return length(dataset.paths)
end

0 commit comments

Comments
 (0)