Skip to content

Commit 34ae3e1

Browse files
add FashionMNIST dataset
1 parent 3bbc0ae commit 34ae3e1

File tree

15 files changed

+946
-9
lines changed

15 files changed

+946
-9
lines changed

src/CIFAR10.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module CIFAR10
33

44
using BinDeps
55

6-
const defdir = joinpath(Pkg.dir("MLDatasets"), "datasets/cifar10")
6+
const defdir = joinpath(Pkg.dir("MLDatasets"), "datasets", "cifar10")
77

88
function getdata(dir)
99
mkpath(dir)
@@ -25,7 +25,7 @@ function readdata(data::Vector{UInt8})
2525
end
2626

2727
function traindata(dir=defdir)
28-
files = ["$(dir)/cifar-10-batches-bin/data_batch_$(i).bin" for i=1:5]
28+
files = [joinpath(dir,"cifar-10-batches-bin","data_batch_$i.bin") for i=1:5]
2929
all(isfile, files) || getdata(dir)
3030
data = UInt8[]
3131
for file in files
@@ -35,7 +35,7 @@ function traindata(dir=defdir)
3535
end
3636

3737
function testdata(dir=defdir)
38-
file = "$(dir)/cifar-10-batches-bin/test_batch.bin"
38+
file = joinpath(dir,"cifar-10-batches-bin","test_batch.bin")
3939
isfile(file) || getdata(dir)
4040
readdata(open(read,file))
4141
end

src/CIFAR100.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module CIFAR100
33

44
using BinDeps
55

6-
const defdir = joinpath(Pkg.dir("MLDatasets"), "datasets/cifar100")
6+
const defdir = joinpath(Pkg.dir("MLDatasets"), "datasets","cifar100")
77

88
function getdata(dir)
99
mkpath(dir)
@@ -25,13 +25,13 @@ function readdata(data::Vector{UInt8})
2525
end
2626

2727
function traindata(dir=defdir)
28-
file = joinpath(dir, "cifar-100-binary/train.bin")
28+
file = joinpath(dir, "cifar-100-binary","train.bin")
2929
isfile(file) || getdata(dir)
3030
readdata(open(read,file))
3131
end
3232

3333
function testdata(dir=defdir)
34-
file = joinpath(dir, "cifar-100-binary/test.bin")
34+
file = joinpath(dir, "cifar-100-binary","test.bin")
3535
isfile(file) || getdata(dir)
3636
readdata(open(read,file))
3737
end

src/FashionMNIST/FashionMNIST.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
export FashionMNIST
2+
module FashionMNIST
# Front end for the FashionMNIST dataset: pulls in the low-level `Reader`
# submodule, the user-facing interface, and the utilities stored in the
# sibling MNIST directory.

using ImageCore
using ColorTypes

export traintensor, testtensor,
       trainlabels, testlabels,
       traindata, testdata,
       convert2image, convert2features,
       download_helper

# Default dataset directory: `datasets/fashion_mnist`, resolved two directory
# levels above the directory that contains this file.
const DEFAULT_DIR = abspath(joinpath(dirname(@__FILE__), "..", "..", "datasets", "fashion_mnist"))

include(joinpath("Reader", "Reader.jl"))
import .Reader: download_helper
include("interface.jl")
# NOTE(review): presumably this is where `convert2image` / `convert2features`
# come from -- confirm against MNIST/utils.jl.
include(joinpath("..", "MNIST", "utils.jl"))

# Convenience method: calling `download_helper` without a directory targets
# `DEFAULT_DIR`; keyword arguments are forwarded unchanged.
Reader.download_helper(; nargs...) = Reader.download_helper(DEFAULT_DIR; nargs...)

end

src/FashionMNIST/Reader/Reader.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
module Reader
# Low-level readers for the gzip-compressed FashionMNIST image and label
# files, plus the helper that downloads missing files.

using GZip
using BinDeps

export

    readtrainimages,
    readtestimages,
    # The original list named `readtrainimages` twice and never exported the
    # train-label reader.
    readtrainlabels,
    readtestlabels,

    download_helper

# Constants

# Byte offsets of the payload: an image header consists of four 32 bit
# integers (see `readimageheader`), a label header of two (see
# `readlabelheader`).
const IMAGEOFFSET = 16
const LABELOFFSET = 8

# Names of the four files that make up the full dataset.
const TRAINIMAGES = "train-images-idx3-ubyte.gz"
const TRAINLABELS = "train-labels-idx1-ubyte.gz"
const TESTIMAGES  = "t10k-images-idx3-ubyte.gz"
const TESTLABELS  = "t10k-labels-idx1-ubyte.gz"

# Includes

include("readheader.jl")
include("readimages.jl")
include("readlabels.jl")
include("download.jl")

end

src/FashionMNIST/Reader/download.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Build the message reported when the dataset file `filename` is absent
# from the directory `dir`.
function msg_notfound(dir, filename)
    return string(
        "The FashionMNIST file \"", filename, "\" was not found in \"", dir, "\".\n",
        "You can download the dataset at https://github.com/zalandoresearch/fashion-mnist,\n",
        "or alternatively use FashionMNIST.download_helper(directory) to do it for you.")
end
4+
5+
# Build the interactive prompt that lists the dataset files in `files` that
# are missing from `dir` and asks whether they should be downloaded.
# The file names are quoted and joined as `"a", "b" and "c"`.
msg_prompt(dir, files) = """
Interactive session detected. FashionMNIST.download_helper initiated.

Dataset: THE FashionMNIST DATABASE of fashion products
Authors: Han Xiao, Kashif Rasul, Roland Vollgraf
Website: https://github.com/zalandoresearch/fashion-mnist

Paper: Han Xiao, Kashif Rasul, Roland Vollgraf "Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms."

The specified directory \"$dir\" is missing the files $(join(map(f->"\"$f\"", files), ", ", " and ")) of the full data set.

The files are available for download at the official website linked above.
We can download these files for you if you wish.
You want to download the dataset to \"$dir\"? [y/n] """
19+
20+
"""
    downloaded_file(dir, filename)

Return the path of `filename` inside `dir`.

If the file does not exist, the behaviour depends on the session: in an
interactive session a warning is issued and `download_helper(dir)` is
invoked; otherwise an error carrying the same message is raised.
"""
function downloaded_file(dir, filename)
    path = joinpath(dir, filename)
    isfile(path) && return path
    # Outside of an interactive session there is nobody to prompt: abort.
    isinteractive() || error(msg_notfound(dir, filename))
    warn(msg_notfound(dir, filename))
    download_helper(dir)
    return path
end
32+
"""
    download_helper(dir; i_accept_the_terms_of_use = true)

Check if the FashionMNIST dataset is contained in the specified `dir`,
or if any of the four files are missing. If `dir` is omitted it
will default to `MLDatasets/datasets/fashion_mnist`.

In the case that any of the four files is missing they are downloaded
to the specified `dir`, which is created if necessary.

# Keywords

- `i_accept_the_terms_of_use`: if `true` (the default) the missing files
  are downloaded without asking. If `false`, an interactive session is
  prompted for confirmation (an answer starting with `y` accepts), while
  a non-interactive session raises an error.

Returns `nothing`. Raises an error if the download was not accepted.
"""
function download_helper(dir; i_accept_the_terms_of_use = true)
    missing_files = filter(file -> !isfile(joinpath(dir, file)),
                           [TRAINIMAGES, TRAINLABELS, TESTIMAGES, TESTLABELS])
    if isempty(missing_files)
        info("Nothing to do.")
        return nothing
    end

    accepted = i_accept_the_terms_of_use
    if !accepted && isinteractive()
        print(msg_prompt(dir, missing_files))
        # The answer can be empty (plain Enter, or end of input), in which
        # case taking `first` of it would throw; treat that as "no".
        answer = strip(readline())
        accepted = startswith(answer, "y")
    end

    accepted || error("Unable to download the dataset. Please visit the website at https://github.com/zalandoresearch/fashion-mnist and download the files manually.")

    mkpath(dir)
    for file in missing_files
        url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file"
        path = joinpath(dir, file)
        info("downloading $file from $url to $dir")
        run(download_cmd(url, path))
    end
    return nothing
end

src/FashionMNIST/Reader/readheader.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
"""
    readimageheader(io::IO)

Reads four 32 bit integers at the current position of `io` and
interprets them as a FashionMNIST-image-file header, which is described
in detail in the table below

            ║     First    │  Second  │  Third  │   Fourth
    ════════╬══════════════╪══════════╪═════════╪════════════
    offset  ║     0000     │   0004   │  0008   │    0012
    descr   ║ magic number │ # images │ # rows  │ # columns

These four numbers are returned as a Tuple in the same storage order,
typed as `(UInt32, Int, Int, Int)`.
"""
function readimageheader(io::IO)
    # The header fields are stored most-significant byte first. `ntoh`
    # converts them to host byte order, which (unlike an unconditional
    # `bswap`) also yields the right values on a big-endian host.
    magic_number = ntoh(read(io, UInt32))
    total_items  = ntoh(read(io, UInt32))
    nrows        = ntoh(read(io, UInt32))
    ncols        = ntoh(read(io, UInt32))
    return magic_number, Int(total_items), Int(nrows), Int(ncols)
end
22+
"""
    readimageheader(file::AbstractString)

Opens the gzip-compressed `file`, reads its first four 32 bit values
and returns them interpreted as a FashionMNIST-image-file header,
i.e. as a `Tuple{UInt32,Int,Int,Int}`.
"""
function readimageheader(file::AbstractString)
    header = gzopen(file, "r") do io
        readimageheader(io)
    end
    return header::Tuple{UInt32,Int,Int,Int}
end
32+
"""
    readlabelheader(io::IO)

Reads two 32 bit integers at the current position of `io` and
interprets them as a FashionMNIST-label-file header, which consists of a
*magic number* and the *total number of labels* stored in the
file. These two numbers are returned as a Tuple in the same
storage order, typed as `(UInt32, Int)`.
"""
function readlabelheader(io::IO)
    # The header fields are stored most-significant byte first. `ntoh`
    # converts them to host byte order, which (unlike an unconditional
    # `bswap`) also yields the right values on a big-endian host.
    magic_number = ntoh(read(io, UInt32))
    total_items  = ntoh(read(io, UInt32))
    return magic_number, Int(total_items)
end
47+
"""
    readlabelheader(file::AbstractString)

Opens the gzip-compressed `file`, reads its first two 32 bit values
and returns them interpreted as a FashionMNIST-label-file header,
i.e. as a `Tuple{UInt32,Int}`.
"""
function readlabelheader(file::AbstractString)
    header = gzopen(file, "r") do io
        readlabelheader(io)
    end
    return header::Tuple{UInt32,Int}
end

src/FashionMNIST/Reader/readimages.jl

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Position `io` at the first byte of the `index`'th image (images are
# `nrows * ncols` bytes each and start `IMAGEOFFSET` bytes into the stream)
# and fill `buffer` from there. The number of bytes read is determined by
# the size of `buffer`. Returns the result of `read!`.
function readimages!(buffer::Matrix{UInt8}, io::IO, index::Integer, nrows::Integer, ncols::Integer)
    bytes_per_image = nrows * ncols
    offset = IMAGEOFFSET + bytes_per_image * (index - 1)
    seek(io, offset)
    return read!(io, buffer)
end
5+
"""
    readimages(io::IO, index::Integer, nrows::Integer, ncols::Integer)

Seeks `io` to the location of the `index`'th image and reads the
`nrows` * `ncols` bytes that follow. The bytes are returned as a
`Matrix{UInt8}` of size `(nrows, ncols)`.
"""
function readimages(io::IO, index::Integer, nrows::Integer, ncols::Integer)
    image = Matrix{UInt8}(nrows, ncols)
    return readimages!(image, io, index, nrows, ncols)
end
17+
"""
    readimages(io::IO, indices::AbstractVector, nrows::Integer, ncols::Integer)

Reads the `nrows` * `ncols` bytes of every image whose index is listed
in `indices` and collects them in an `Array{UInt8,3}` of size
`(nrows, ncols, length(indices))`. The images appear along the third
dimension in the same order as they are listed in `indices`.
"""
function readimages(io::IO, indices::AbstractVector, nrows::Integer, ncols::Integer)
    npixels = nrows * ncols
    images = Array{UInt8,3}(nrows, ncols, length(indices))
    # One scratch matrix is reused for every image that is read.
    scratch = Matrix{UInt8}(nrows, ncols)
    for (slot, src_index) in enumerate(indices)
        readimages!(scratch, io, src_index, nrows, ncols)
        # Linear copy of the scratch matrix into the `slot`'th slice.
        copy!(images, 1 + npixels * (slot - 1), scratch, 1, npixels)
    end
    return images
end
37+
"""
    readimages(file, [indices])

Reads the images denoted by `indices` from `file`. The given
`file` can either be specified using an IO-stream or a string
that denotes the fully qualified path. The content of `file` is
assumed to be in the FashionMNIST image-file format, which is the
same format as used by MNIST; see the dataset's homepage at
https://github.com/zalandoresearch/fashion-mnist

- if `indices` is an `Integer`, the single image is returned as
  `Matrix{UInt8}` in horizontal major layout, which means that
  the first dimension denotes the pixel *rows* (x), and the
  second dimension denotes the pixel *columns* (y) of the image.

- if `indices` is an `AbstractVector`, the images are returned as
  a 3D array (i.e. an `Array{UInt8,3}`), in which the first
  dimension corresponds to the pixel *rows* (x) of the image, the
  second dimension to the pixel *columns* (y) of the image, and
  the third dimension denotes the index of the image. An empty
  vector yields an array whose third dimension has length 0.

- if `indices` is omitted all images are returned
  (as 3D array described above)

Throws an `AssertionError` if any index lies outside of the range
of images announced by the file header.
"""
function readimages(io::IO, indices)
    _, nimages, nrows, ncols = readimageheader(io)
    # Explicit check instead of `@assert`, so that the validation of the
    # caller-supplied indices can never be compiled out. The exception type
    # is kept as `AssertionError` for backward compatibility. Empty `indices`
    # are skipped because there is no minimum/maximum to compare.
    if !isempty(indices)
        lo, hi = extrema(indices)
        if lo < 1 || hi > nimages
            throw(AssertionError("image indices must lie within 1:$nimages, but span $lo:$hi"))
        end
    end
    return readimages(io, indices, nrows, ncols)
end
66+
67+
# Read the single image `index` from the gzip-compressed `file`;
# the result is asserted to be a `Matrix{UInt8}`.
function readimages(file::AbstractString, index::Integer)
    image = gzopen(io -> readimages(io, index), file, "r")
    return image::Matrix{UInt8}
end
72+
73+
# Read the images listed in `indices` from the gzip-compressed `file`;
# the result is asserted to be an `Array{UInt8,3}`.
function readimages(file::AbstractString, indices::AbstractVector)
    images = gzopen(io -> readimages(io, indices), file, "r")
    return images::Array{UInt8,3}
end
78+
79+
# Read every image stored in the gzip-compressed `file`. The image count
# and dimensions are taken from the file header; the result is asserted
# to be an `Array{UInt8,3}`.
function readimages(file::AbstractString)
    images = gzopen(file, "r") do stream
        header = readimageheader(stream)
        nimages, nrows, ncols = header[2], header[3], header[4]
        readimages(stream, 1:nimages, nrows, ncols)
    end
    return images::Array{UInt8,3}
end
85+
"""
    readtrainimages(dir::AbstractString, [indices]) -> Array{UInt8}

Reads the images of the given `indices` from the file
\"$TRAINIMAGES\" using the function [`readimages`](@ref), and
returns them as a multi-dimensional array of `UInt8`. If `indices`
is omitted, all images of the file are read.

The file is looked up in the given directory `dir` by means of
`downloaded_file`, which takes care of the case that the file
does not exist yet.
"""
function readtrainimages(dir)
    return readimages(downloaded_file(dir, TRAINIMAGES))
end

function readtrainimages(dir, indices)
    return readimages(downloaded_file(dir, TRAINIMAGES), indices)
end
99+
"""
    readtestimages(dir::AbstractString, [indices]) -> Array{UInt8}

Reads the images of the given `indices` from the file
\"$TESTIMAGES\" using the function [`readimages`](@ref), and
returns them as a multi-dimensional array of `UInt8`. If `indices`
is omitted, all images of the file are read.

The file is looked up in the given directory `dir` by means of
`downloaded_file`, which takes care of the case that the file
does not exist yet.
"""
function readtestimages(dir)
    return readimages(downloaded_file(dir, TESTIMAGES))
end

function readtestimages(dir, indices)
    return readimages(downloaded_file(dir, TESTIMAGES), indices)
end

0 commit comments

Comments
 (0)