Skip to content

Commit 834d9b8

Browse files
committed
Initial port of FastAI dataset containers
1 parent f45ef65 commit 834d9b8

File tree

6 files changed

+167
-3
lines changed

6 files changed

+167
-3
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,22 @@ version = "0.5.15"
44

55
[deps]
66
BinDeps = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
7+
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
78
ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
89
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
10+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
911
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
12+
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
13+
FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f"
1014
FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
1115
GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63"
1216
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
1317
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
18+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1419
Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
1520
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1621
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
22+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1723

1824
[compat]
1925
BinDeps = "1"

src/MLDatasets.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ using DelimitedFiles: readdlm
77
using FixedPointNumbers, ColorTypes
88
using Pickle
99
using SparseArrays
10+
using DataFrames, CSV, Tables
11+
using FilePathsBase
12+
using FilePathsBase: AbstractPath
13+
14+
import MLUtils
15+
using MLUtils: getobs, numobs, AbstractDataContainer
16+
17+
export FileDataset, TableDataset
18+
export getobs, numobs
1019

1120
# Julia 1.0 compatibility
1221
if !isdefined(Base, :isnothing)
@@ -36,6 +45,8 @@ end
3645

3746
include("download.jl")
3847

48+
include("containers/filedataset.jl")
49+
include("containers/tabledataset.jl")
3950

4051
# Misc.
4152
include("BostonHousing/BostonHousing.jl")

src/containers/filedataset.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# FileDataset

"""
    rglob(filepattern = "*", dir = pwd(), depth = 4)

Recursively glob for `filepattern` inside `dir`, descending at most `depth`
directory levels (the first level is `dir` itself).

One pattern is built per level by prefixing `filepattern` with `i` copies of `"*/"`
for `i = 0, ..., depth - 1`. Each pattern is handed to `glob(pattern, dir)` and the
matches of all levels are concatenated in that order.

When `depth <= 0` no pattern is tried and an empty `Vector{String}` is returned.
"""
function rglob(filepattern = "*", dir = pwd(), depth = 4)
    patterns = [repeat("*/", i) * filepattern for i in 0:(depth - 1)]

    # Fold with an explicit `init` rather than splatting the per-level results into
    # `vcat`: this handles any number of levels uniformly and gives a typed empty
    # result when there are no patterns at all.
    return mapreduce(pattern -> glob(pattern, dir), vcat, patterns; init = String[])
end
13+
"""
    loadfile(file)

Read `file` from disk, picking the loader from the kind of file:

- files for which `isimagefile(file)` holds are read with `FileIO.load` using
  `view = true`;
- files whose name ends in `.csv` are parsed with `CSV.File` and returned as a
  `DataFrame`;
- every other file is passed to `FileIO.load` as is.

An `AbstractPath` argument is converted with `string` and then loaded as above.
"""
function loadfile(file::String)
    # `view = true` is the faster image loading path
    isimagefile(file) && return FileIO.load(file, view = true)
    endswith(file, ".csv") && return DataFrame(CSV.File(file))

    return FileIO.load(file)
end
function loadfile(file::AbstractPath)
    return loadfile(string(file))
end
30+
31+
"""
    FileDataset(dir, pattern = "*", depth = 4)

Data container over the files found below `dir`.

The constructor collects the paths returned by
`rglob(pattern, string(dir), depth)` and stores them in the `paths` field.
"""
struct FileDataset{T} <: AbstractDataContainer
    paths::T
end

function FileDataset(dir, pattern = "*", depth = 4)
    paths = rglob(pattern, string(dir), depth)

    # Wrap the result so that a `FileDataset` (not the bare path list) is returned.
    # The parametric constructor is used on purpose: the untyped one-argument form
    # of this method takes the place of the default `FileDataset(paths)`
    # constructor, so calling `FileDataset(paths)` here would re-enter this method.
    return FileDataset{typeof(paths)}(paths)
end
36+
37+
# Observation `i` is whatever `loadfile` returns for the `i`-th stored path.
function MLUtils.getobs(dataset::FileDataset, i)
    return loadfile(dataset.paths[i])
end

# The number of observations is the number of stored paths.
function MLUtils.numobs(dataset::FileDataset)
    return length(dataset.paths)
end

src/containers/tabledataset.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
    TableDataset{T}(table::T)

Data container wrapping `table`, which must satisfy `Tables.istable`.

# Throws
- `ArgumentError` if `Tables.istable(table)` is `false`.
"""
struct TableDataset{T} <: AbstractDataContainer
    table::T

    # Reject non-tables at construction time so that every `TableDataset` holds an
    # object for which `Tables.istable` is true.
    function TableDataset{T}(table::T) where {T}
        Tables.istable(table) ||
            throw(ArgumentError("TableDatasets must implement the Tables.jl interface"))

        return new{T}(table)
    end
end
12+
13+
# Wrap any table, taking the type parameter from the argument's type.
function TableDataset(table::T) where {T}
    return TableDataset{T}(table)
end

# Read the CSV file at `path` into a `DataFrame` and wrap that.
function TableDataset(path::Union{AbstractPath, AbstractString})
    csv = CSV.File(path)
    return TableDataset(DataFrame(csv))
end
16+
17+
# slow accesses based on Tables.jl

# Return observation (row) `i` of the wrapped table using only the generic
# Tables.jl API.
#
# - Row-access tables: the row iterator is advanced past the first `i - 1` rows and
#   the next row is returned as produced by `Tables.rows`. A `BoundsError` is thrown
#   when the iterator is exhausted before reaching row `i`.
# - Column-access tables: entry `i` of every column is collected into a `NamedTuple`
#   keyed by the column names.
# - Tables declaring neither access pattern raise an error.
function MLUtils.getobs(dataset::TableDataset, i)
    table = dataset.table
    if Tables.rowaccess(table)
        # Row iterators are not required to support indexing, so skip forward.
        peeled = Iterators.peel(Iterators.drop(Tables.rows(table), i - 1))
        # `peel` yields `nothing` for an exhausted iterator; report that as an
        # out-of-range index instead of failing while destructuring `nothing`.
        peeled === nothing && throw(BoundsError(table, i))
        return first(peeled)
    elseif Tables.columnaccess(table)
        # `columnnames` and `getcolumn` are specified for the object returned by
        # `Tables.columns`, which is not necessarily the table itself.
        columns = Tables.columns(table)
        colnames = Tables.columnnames(columns)
        rowvals = [Tables.getcolumn(columns, name)[i] for name in colnames]
        return (; zip(colnames, rowvals)...)
    else
        error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
    end
end
30+
# Return the number of observations (rows) of the wrapped table using only the
# generic Tables.jl API.
#
# - Column-access tables: the length of the first column, or `0` when the table has
#   no columns.
# - Row-access tables: `length` of the row iterator.
# - Tables declaring neither access pattern raise an error.
function MLUtils.numobs(dataset::TableDataset)
    table = dataset.table
    if Tables.columnaccess(table)
        # `columnnames` and `getcolumn` are specified for the object returned by
        # `Tables.columns`, which is not necessarily the table itself.
        columns = Tables.columns(table)
        colnames = Tables.columnnames(columns)
        # Without any column there is nothing to measure.
        isempty(colnames) && return 0
        return length(Tables.getcolumn(columns, first(colnames)))
    elseif Tables.rowaccess(table)
        # length might not be defined, but has to be for this to work.
        return length(Tables.rows(table))
    else
        error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
    end
end
40+
41+
# Specialised accessors for `DataFrame`: index the row directly and ask the frame
# for its row count instead of going through the generic Tables.jl methods.
function MLUtils.getobs(dataset::TableDataset{<:DataFrame}, i)
    return dataset.table[i, :]
end
function MLUtils.numobs(dataset::TableDataset{<:DataFrame})
    return nrow(dataset.table)
end

# Specialised accessors for `CSV.File`: index the file directly and use its
# `length` as the observation count.
function MLUtils.getobs(dataset::TableDataset{<:CSV.File}, i)
    return dataset.table[i]
end
function MLUtils.numobs(dataset::TableDataset{<:CSV.File})
    return length(dataset.table)
end

test/containers/tabledataset.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Tests for `TableDataset`: the generic row-access and column-access code paths,
# plus the specialised `DataFrame` and `CSV.File` methods.
@testset "TableDataset" begin
    @testset "TableDataset from rowaccess table" begin
        # Declare `MatrixTable` as row-access only for this scenario.
        Tables.columnaccess(::Type{<:Tables.MatrixTable}) = false
        Tables.rowaccess(::Type{<:Tables.MatrixTable}) = true

        testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
        td = TableDataset(testtable)

        @test all(getobs(td, 1) .== [1, 4.0, "7"])
        @test numobs(td) == 3
    end

    @testset "TableDataset from columnaccess table" begin
        # Declare `MatrixTable` as column-access only for this scenario.
        Tables.columnaccess(::Type{<:Tables.MatrixTable}) = true
        Tables.rowaccess(::Type{<:Tables.MatrixTable}) = false

        testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
        td = TableDataset(testtable)

        @test [data for data in getobs(td, 2)] == [2, 5.0, "8"]
        @test numobs(td) == 3

        @test getobs(td, 1) isa NamedTuple
    end

    @testset "TableDataset from DataFrames" begin
        testtable = DataFrame(
            col1 = [1, 2, 3, 4, 5],
            col2 = ["a", "b", "c", "d", "e"],
            col3 = [10, 20, 30, 40, 50],
            col4 = ["A", "B", "C", "D", "E"],
            col5 = [100.0, 200.0, 300.0, 400.0, 500.0],
            split = ["train", "train", "train", "valid", "valid"],
        )
        td = TableDataset(testtable)
        @test td isa TableDataset{<:DataFrame}

        @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"]
        @test numobs(td) == 5
    end

    @testset "TableDataset from CSV" begin
        # Keep the fixture in a throw-away directory: it is removed even when a
        # statement below throws, and nothing is written to the working directory.
        mktempdir() do dir
            csvpath = joinpath(dir, "test.csv")
            open(csvpath, "w") do io
                write(io, "col1,col2,col3,col4,col5, split\n1,a,10,A,100.,train")
            end
            testtable = CSV.File(csvpath)
            td = TableDataset(testtable)
            @test td isa TableDataset{<:CSV.File}
            @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"]
            @test numobs(td) == 1
        end
    end
end

test/runtests.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ using Test
22
using MLDatasets
33
using ImageCore
44
using DataDeps
5-
5+
using MLUtils: getobs, numobs
6+
using DataFrames, CSV, Tables
67

78
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
89

9-
tests = [
10+
dataset_tests = [
1011
# misc
1112
"tst_iris.jl",
1213
"tst_boston_housing.jl",
@@ -26,12 +27,20 @@ tests = [
2627
"tst_tudataset.jl",
2728
]
2829

29-
for t in tests
30+
container_tests = [
31+
"containers/tabledataset.jl",
32+
]
33+
34+
@testset "Datasets" for t in dataset_tests
3035
@testset "$t" begin
3136
include(t)
3237
end
3338
end
3439

40+
@testset "Containers" for t in container_tests
41+
include(t)
42+
end
43+
3544
#temporary to not stress CI
3645
if !parse(Bool, get(ENV, "CI", "false"))
3746
@testset "other tests" begin

0 commit comments

Comments
 (0)