|
# Download-on-demand access to the Fashion-MNIST dataset. The IDX parsing helpers
# are shared with the sibling `MNIST` module.
module FashionMNIST

using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel

# Directory in which the decompressed dataset files are cached.
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")

# Locations of the four decompressed dataset files inside `dir`.
const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte")
const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte")
const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte")
const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")

"""
    load()

Ensure that all four Fashion-MNIST files exist in `dir`, downloading and
decompressing any that are missing. Files that are already present are left
untouched. Returns `nothing`.
"""
function load()
    mkpath(dir)
    for path in (TRAINIMAGES, TRAINLABELS, TESTIMAGES, TESTLABELS)
        isfile(path) && continue
        @info "Downloading Fashion-MNIST dataset"
        # Work with absolute paths rather than `cd`-ing into `dir`: the working
        # directory is process-wide state and must not change under other tasks.
        archive = path * ".gz"
        url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" *
              basename(path) * ".gz"
        download(url, archive)
        # Decompress into a scratch file and rename it afterwards, so that an
        # interrupted run never leaves a truncated file that `isfile` would accept.
        partial = path * ".part"
        open(partial, "w") do io
            write(io, gzopen(read, archive))
        end
        mv(partial, path; force = true)
    end
    return nothing
end

"""
    images()
    images(:test)

Load the Fashion-MNIST images.

Each image is a 28×28 array of `Gray` colour values (see Colors.jl).

Returns the 60,000 training images by default; pass `:test` to retrieve the
10,000 test images. Any value other than `:train` selects the test set.
"""
function images(set = :train)
    load()
    io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES))
    # Only the image count is needed here; the remaining header fields are ignored.
    _, N = imageheader(io)
    return [rawimage(io) for _ in 1:N]
end

"""
    labels()
    labels(:test)

Load the labels corresponding to each of the images returned from `images()`.
Each label is a number from 0-9.

Returns the 60,000 training labels by default; pass `:test` to retrieve the
10,000 test labels. Any value other than `:train` selects the test set.
"""
function labels(set = :train)
    load()
    io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS))
    _, N = labelheader(io)
    return [rawlabel(io) for _ in 1:N]
end

end
0 commit comments