Skip to content

Commit 325d2ce

Browse files
authored
Merge pull request #418 from c-p-murphy/add-fashion-mnist
Add FashionMNIST
2 parents 61fb6cd + 73a526b commit 325d2ce

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

src/data/Data.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ end
1313
include("mnist.jl")
1414
export MNIST
1515

16+
include("fashion-mnist.jl")
17+
export FashionMNIST
18+
1619
include("cmudict.jl")
1720
using .CMUDict
1821

src/data/fashion-mnist.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
module FashionMNIST
2+
3+
using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
4+
5+
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
6+
7+
function load()
8+
mkpath(dir)
9+
cd(dir) do
10+
for file in ["train-images-idx3-ubyte",
11+
"train-labels-idx1-ubyte",
12+
"t10k-images-idx3-ubyte",
13+
"t10k-labels-idx1-ubyte"]
14+
isfile(file) && continue
15+
@info "Downloading Fashion-MNIST dataset"
16+
download("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz")
17+
open(file, "w") do io
18+
write(io, gzopen(read, "$file.gz"))
19+
end
20+
end
21+
end
22+
end
23+
24+
const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte")
25+
const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte")
26+
const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte")
27+
const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
28+
29+
"""
30+
images()
31+
images(:test)
32+
33+
Load the Fashion-MNIST images.
34+
35+
Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
36+
37+
Returns the 60,000 training images by default; pass `:test` to retreive the
38+
10,000 test images.
39+
"""
40+
function images(set = :train)
41+
load()
42+
io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES))
43+
_, N, nrows, ncols = imageheader(io)
44+
[rawimage(io) for _ in 1:N]
45+
end
46+
47+
"""
48+
labels()
49+
labels(:test)
50+
51+
Load the labels corresponding to each of the images returned from `images()`.
52+
Each label is a number from 0-9.
53+
54+
Returns the 60,000 training labels by default; pass `:test` to retreive the
55+
10,000 test labels.
56+
"""
57+
function labels(set = :train)
58+
load()
59+
io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS))
60+
_, N = labelheader(io)
61+
[rawlabel(io) for _ = 1:N]
62+
end
63+
64+
end

test/data.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,7 @@ using Test
1010
@test MNIST.images()[1] isa Matrix
1111
@test MNIST.labels() isa Vector{Int64}
1212

13+
@test FashionMNIST.images()[1] isa Matrix
14+
@test FashionMNIST.labels() isa Vector{Int64}
15+
1316
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}

0 commit comments

Comments
 (0)