Skip to content

Commit d184e5c

Browse files
Merge pull request #83 from ShuhuaGao/MNIST-display
render black digits on a whiteboard
2 parents 3e15931 + 055218e commit d184e5c

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

src/MNIST/utils.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""
2-
convert2image(array) -> Array{Gray}
2+
convert2image(array; black_digits=false) -> Array{Gray}
33
44
Convert the given MNIST horizontal-major tensor (or feature matrix)
5-
to a vertical-major `Colorant` array. The values are also color
5+
to a vertical-major `Colorant` array. If `black_digits` is `true`, the values are also color
66
corrected according to the website's description, which means that
77
the digits are black on a white background.
88
@@ -16,17 +16,23 @@ julia> MNIST.convert2image(MNIST.traintensor(1)) # first training image
1616
[...]
1717
```
1818
"""
19-
function convert2image(array::AbstractArray{T}) where {T<:Number}
19+
function convert2image(array::AbstractArray{T}; black_digits::Bool=false) where {T<:Number}
2020
nlast = size(array)[end]
2121
array = reshape(array, 28, 28, :)
2222
array = permutedims(array, (2, 1, 3))
2323
if size(array)[end] == 1 && nlast != 1
2424
array = dropdims(array, dims=3)
2525
end
2626
if any(x -> x > 1, array) # simple check if x in [0,1]
27-
img = _colorview(Gray, array ./ T(255))
27+
array = array ./ T(255) # avoid changing the input array
28+
if black_digits
29+
array .= one(eltype(array)) .- array
30+
end
2831
else
29-
img = _colorview(Gray, array)
32+
if black_digits
33+
array = one(eltype(array)) .- array
34+
end
3035
end
31-
img
36+
37+
return _colorview(Gray, array)
3238
end

test/tst_mnist.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ end
4141
@test size(A) == (28,28,2)
4242
@test eltype(A) == Gray{N0f8}
4343
@test MNIST.convert2image(vec(data)) == A
44+
45+
# test black digits and white background
46+
data = rand(N0f8,28,28,2)
47+
data[1] = 0
48+
data[3, 3, 2] = 0.4
49+
A = MNIST.convert2image(data; black_digits=true)
50+
@test A[1] == 1
51+
@test A[3, 3, 2] == 0.6
52+
@test size(A) == (28,28,2)
53+
@test eltype(A) == Gray{N0f8}
54+
@test MNIST.convert2image(vec(data); black_digits=true) == A
4455
end
4556

4657
# NOT executed on CI. only executed locally.

0 commit comments

Comments
 (0)