Skip to content

Commit 757b706

Browse files
committed
Add some docstrings and test for FileDataset
1 parent 83452ed commit 757b706

File tree

6 files changed

+66
-5
lines changed

6 files changed

+66
-5
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
1313
FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f"
1414
FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
1515
GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63"
16+
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
1617
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
1718
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
1819
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -23,11 +24,12 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2324

2425
[compat]
2526
BinDeps = "1"
26-
ColorTypes = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11"
2727
CSV = "0.10.2"
28+
ColorTypes = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11"
2829
DataDeps = "0.3, 0.4, 0.5, 0.6, 0.7"
2930
DataFrames = "1.3"
3031
FixedPointNumbers = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8"
32+
Glob = "1.3"
3133
GZip = "0.5"
3234
ImageCore = "0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8"
3335
JSON3 = "1"

src/MLDatasets.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using SparseArrays
1010
using DataFrames, CSV, Tables
1111
using FilePathsBase
1212
using FilePathsBase: AbstractPath
13+
using Glob
1314

1415
import MLUtils
1516
using MLUtils: getobs, numobs, AbstractDataContainer

src/containers/filedataset.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
# FileDataset
1+
# Case-insensitive pattern for file names that end in a common image extension.
const RE_IMAGEFILE = r".*\.(gif|jpe?g|tiff?|png|webp|bmp)$"i

"""
    matches(re::Regex, f) -> Bool
    matches(re::Regex) -> Function

Return `true` if `re` matches `f`. The one-argument form returns the predicate
`f -> matches(re, f)`, which is convenient for `filter` and similar functions.
"""
matches(re::Regex, f) = occursin(re, f)  # boolean test; no `RegexMatch` is built
matches(re::Regex) = f -> matches(re, f)

"""
    isimagefile(f) -> Bool

Return `true` if `f` matches `RE_IMAGEFILE`, i.e. ends in one of the extensions
gif, jpg/jpeg, tif/tiff, png, webp or bmp (compared case-insensitively).
"""
isimagefile(f) = matches(RE_IMAGEFILE, f)
25

36
"""
47
rglob(filepattern, dir = pwd(), depth = 4)
@@ -28,11 +31,19 @@ function loadfile(file::String)
2831
end
2932
# Path objects are converted to their string form and forwarded to the
# `loadfile(file::String)` method.
loadfile(file::AbstractPath) = loadfile(string(file))
3033

34+
"""
35+
FileDataset(paths)
36+
FileDataset(dir, pattern = "*", depth = 4)
37+
38+
Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`).
39+
Alternatively, specify a `dir` and collect all paths that match a glob `pattern`
40+
(recursively globbing by `depth`). The glob order determines the traversal order.
41+
"""
3142
struct FileDataset{T} <: AbstractDataContainer
    # Collection of file paths; observation `i` is loaded from `paths[i]`.
    paths::T
end

# Directory-based constructor: collect every path under `dir` matching the glob
# `pattern` (searched recursively via `rglob` with the given `depth`) and wrap it.
#
# `dir` is restricted to strings and path objects on purpose. With an untyped
# `dir`, the default arguments generate a one-argument method `FileDataset(::Any)`
# whose signature is identical to the struct's implicit constructor
# `FileDataset(paths::T) where T`. That method would overwrite the implicit
# constructor, so `FileDataset(paths)` (including the call below) would re-enter
# this method instead of wrapping the paths.
function FileDataset(
    dir::Union{AbstractString, AbstractPath}, pattern = "*", depth = 4
)
    return FileDataset(rglob(pattern, string(dir), depth))
end
3647

3748
# Observation `i` is whatever `loadfile` returns for the `i`-th stored path.
function MLUtils.getobs(dataset::FileDataset, i)
    path = dataset.paths[i]
    return loadfile(path)
end

# The dataset holds one observation per stored path.
function MLUtils.numobs(dataset::FileDataset)
    return length(dataset.paths)
end

src/containers/tabledataset.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
"""
2+
TableDataset(table)
3+
TableDataset(path::Union{AbstractPath, AbstractString})
4+
5+
Wrap a Tables.jl-compatible `table` as a dataset container.
6+
Alternatively, specify the `path` to a CSV file directly
7+
to load it with CSV.jl + DataFrames.jl.
8+
"""
19
struct TableDataset{T} <: AbstractDataContainer
210
table::T
311

test/containers/filedataset.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""
    setup_filedataset_test()

Create a small tree of CSV fixture files below `root/` (relative to the current
working directory) and return their relative paths in creation order.

The `i`-th file contains the header `a,b,c` followed by the single row
`i, 2i, 3i` (comma-separated, no trailing newline).
"""
function setup_filedataset_test()
    files = [
        "root/p1/f1.csv",
        "root/p2/f2.csv",
        "root/p2/p2p1/f2.csv",
        "root/p3/p3p1/f1.csv",
    ]

    for (i, file) in enumerate(files)
        # `mkpath` creates all missing parent directories and is a no-op for
        # directories that already exist.
        mkpath(dirname(file))
        write(file, "a,b,c\n" * join(i .* [1, 2, 3], ","))
    end

    return files
end
26+
27+
@testset "FileDataset" begin
    # try/finally guarantees the fixture tree under "root" is removed even when
    # something below throws (e.g. `loadfile` raising outside of an `@test`).
    try
        files = setup_filedataset_test()
        dataset = FileDataset("root", "*.csv")
        @test numobs(dataset) == length(files)
        # NOTE(review): this compares by position, so it assumes the glob order
        # of the dataset matches the order of `files` — confirm against `rglob`.
        for (i, file) in enumerate(files)
            true_obs = MLDatasets.loadfile(file)
            @test getobs(dataset, i) == true_obs
        end
    finally
        # `force = true` keeps cleanup from masking the original error when
        # setup failed before "root" was created.
        rm("root"; recursive = true, force = true)
    end
end

test/runtests.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dataset_tests = [
2828
]
2929

3030
container_tests = [
31+
"containers/filedataset.jl",
3132
"containers/tabledataset.jl",
3233
]
3334

@@ -37,8 +38,10 @@ container_tests = [
3738
end
3839
end
3940

40-
@testset "Containers" for t in container_tests
41-
include(t)
41+
# Run every container test file inside a single enclosing "Containers" testset.
@testset "Containers" begin
    foreach(include, container_tests)
end
4346

4447
#temporary to not stress CI

0 commit comments

Comments
 (0)