Skip to content

Commit 303f14f

Browse files
committed
add keyword argument black_digits
1 parent 79aed4d commit 303f14f

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

src/MNIST/utils.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""
2-
convert2image(array) -> Array{Gray}
2+
convert2image(array; black_digits=false) -> Array{Gray}
33
44
Convert the given MNIST horizontal-major tensor (or feature matrix)
5-
to a vertical-major `Colorant` array. The values are also color
5+
to a vertical-major `Colorant` array. If `black_digits` is `true`, the values are also color
66
corrected according to the website's description, which means that
77
the digits are black on a white background.
88
@@ -16,17 +16,23 @@ julia> MNIST.convert2image(MNIST.traintensor(1)) # first training image
1616
[...]
1717
```
1818
"""
19-
function convert2image(array::AbstractArray{T}) where {T<:Number}
19+
function convert2image(array::AbstractArray{T}; black_digits::Bool=false) where {T<:Number}
2020
nlast = size(array)[end]
2121
array = reshape(array, 28, 28, :)
2222
array = permutedims(array, (2, 1, 3))
2323
if size(array)[end] == 1 && nlast != 1
2424
array = dropdims(array, dims=3)
2525
end
2626
if any(x -> x > 1, array) # simple check if x in [0,1]
27-
img = _colorview(Gray, array ./ T(255))
27+
array = array ./ T(255) # avoid changing the input array
28+
if black_digits
29+
array .= one(eltype(array)) .- array
30+
end
2831
else
29-
img = _colorview(Gray, array)
32+
if black_digits
33+
array = one(eltype(array)) .- array
34+
end
3035
end
31-
return one(eltype(img)) .- img
36+
37+
return _colorview(Gray, array)
3238
end

0 commit comments

Comments
 (0)