Skip to content

Commit 73a526b

Browse files
committed
reuse utils from mnist.jl
1 parent 95d72d7 commit 73a526b

File tree

1 file changed

+1
-52
lines changed

1 file changed

+1
-52
lines changed

src/data/fashion-mnist.jl

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
11
module FashionMNIST
22

3-
using CodecZlib, Colors
4-
5-
const Gray = Colors.Gray{Colors.N0f8}
3+
using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
64

75
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
86

9-
function gzopen(f, file)
10-
open(file) do io
11-
f(GzipDecompressorStream(io))
12-
end
13-
end
14-
157
function load()
168
mkpath(dir)
179
cd(dir) do
@@ -29,53 +21,11 @@ function load()
2921
end
3022
end
3123

32-
const IMAGEOFFSET = 16
33-
const LABELOFFSET = 8
34-
35-
const NROWS = 28
36-
const NCOLS = 28
37-
3824
const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte")
3925
const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte")
4026
const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte")
4127
const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
4228

43-
function imageheader(io::IO)
44-
magic_number = bswap(read(io, UInt32))
45-
total_items = bswap(read(io, UInt32))
46-
nrows = bswap(read(io, UInt32))
47-
ncols = bswap(read(io, UInt32))
48-
return magic_number, Int(total_items), Int(nrows), Int(ncols)
49-
end
50-
51-
function labelheader(io::IO)
52-
magic_number = bswap(read(io, UInt32))
53-
total_items = bswap(read(io, UInt32))
54-
return magic_number, Int(total_items)
55-
end
56-
57-
function rawimage(io::IO)
58-
img = Array{Gray}(undef, NCOLS, NROWS)
59-
for i in 1:NCOLS, j in 1:NROWS
60-
img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8))
61-
end
62-
return img
63-
end
64-
65-
function rawimage(io::IO, index::Integer)
66-
seek(io, IMAGEOFFSET + NROWS * NCOLS * (index - 1))
67-
return rawimage(io)
68-
end
69-
70-
rawlabel(io::IO) = Int(read(io, UInt8))
71-
72-
function rawlabel(io::IO, index::Integer)
73-
seek(io, LABELOFFSET + (index - 1))
74-
return rawlabel(io)
75-
end
76-
77-
getfeatures(io::IO, index::Integer) = vec(getimage(io, index))
78-
7929
"""
8030
images()
8131
images(:test)
@@ -111,5 +61,4 @@ function labels(set = :train)
11161
[rawlabel(io) for _ = 1:N]
11262
end
11363

114-
11564
end

0 commit comments

Comments
 (0)