Skip to content

Commit f988a12

Browse files
Add Omniglot Dataset (#152)
* Add Omniglot dataset * Fix typo * Add tests * Fix more typos * Fix tests and documentation * Fix test * Omniglot review fixes * Fix Omniglot test Co-authored-by: Christian <[email protected]>
1 parent 66edbf7 commit f988a12

File tree

5 files changed

+252
-0
lines changed

5 files changed

+252
-0
lines changed

docs/src/datasets/vision.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@ CIFAR100
2525
EMNIST
2626
FashionMNIST
2727
MNIST
28+
Omniglot
2829
SVHN2
2930
```

src/MLDatasets.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ include("datasets/vision/cifar100.jl")
8686
export CIFAR100
8787
include("datasets/vision/svhn2.jl")
8888
export SVHN2
89+
include("datasets/vision/omniglot.jl")
90+
export Omniglot
8991

9092
## Text
9193

@@ -149,6 +151,7 @@ function __init__()
149151
__init__emnist()
150152
__init__fashionmnist()
151153
__init__mnist()
154+
__init__omniglot()
152155
__init__svhn2()
153156
end
154157

src/datasets/vision/omniglot.jl

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Register the "Omniglot" DataDep: four MATLAB files hosted in the official
# GitHub repository, verified against a single checksum.
function __init__omniglot()
    depname = "Omniglot"
    train_file = "data_background.mat"
    test_file = "data_evaluation.mat"
    small1_file = "data_background_small1.mat"
    small2_file = "data_background_small2.mat"

    # Message attached to the DataDep registration.
    description = """
    Dataset: Omniglot data set for one-shot learning
    Authors: Brenden M. Lake, Ruslan Salakhutdinov, Joshua B. Tenenbaum
    Website: https://github.com/brendenlake/omniglot

    [Lake et al., 2015]
    Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015).
    Human-level concept learning through probabilistic program induction.
    Science, 350(6266), 1332-1338.

    The files are available for download at the official
    github repository linked above. Note that using the data
    responsibly and respecting copyright remains your
    responsibility. The authors of Omniglot aren't really
    explicit about any terms of use, so please read the
    website to make sure you want to download the dataset.
    """

    # One remote URL per MATLAB file, all under the same base path.
    base_url = "https://github.com/brendenlake/omniglot/raw/master/matlab/"
    remote_paths = [
        base_url * name for name in (train_file, test_file, small1_file, small2_file)
    ]

    # Single checksum supplied for the whole set of MATLAB downloads.
    checksum = "1cfb52d931d794a3dd2717da6c80ddb8f48b0a6f559916c60fcdcd908b65d3d8"

    register(DataDep(depname, description, remote_paths, checksum))
end
31+
32+
"""
    Omniglot(; Tx=Float32, split=:train, dir=nothing)
    Omniglot([Tx, split])

Omniglot data set for one-shot learning

- Authors: Brenden M. Lake, Ruslan Salakhutdinov, Joshua B. Tenenbaum
- Website: https://github.com/brendenlake/omniglot

The Omniglot data set is designed for developing more human-like learning
algorithms. It contains 1623 different handwritten characters from 50 different
alphabets. Each of the 1623 characters was drawn online via Amazon's
Mechanical Turk by 20 different people. Each image is paired with stroke data, a
sequence of [x,y,t] coordinates with time (t) in milliseconds.

# Arguments

$ARGUMENTS_SUPERVISED_ARRAY
- `split`: selects the data partition. Can take the values `:train`, `:test`, `:small1`, or `:small2`.

# Fields

$FIELDS_SUPERVISED_ARRAY
- `split`.

# Methods

$METHODS_SUPERVISED_ARRAY
- [`convert2image`](@ref) converts features to `Gray` images.

# Examples

The images are loaded as a multi-dimensional array of eltype `Tx`.
All values will be `0` or `1`. `Omniglot().features` is a 3D array
(i.e. a `Array{Tx,3}`), in WHN format (width, height, num_images).
Labels are stored as a vector of strings (the alphabet name of each image)
in `Omniglot().targets`.

```julia-repl
julia> using MLDatasets: Omniglot

julia> dataset = Omniglot(:train)
Omniglot:
  metadata => Dict{String, Any} with 3 entries
  split => :train
  features => 105×105×19280 Array{Float32, 3}
  targets => 19280-element Vector{String}

julia> dataset[1:5].targets
5-element Vector{String}:
 "Arcadian"
 "Arcadian"
 "Arcadian"
 "Arcadian"
 "Arcadian"

julia> X, y = dataset[:];

julia> dataset = Omniglot(UInt8, :test)
Omniglot:
  metadata => Dict{String, Any} with 3 entries
  split => :test
  features => 105×105×13180 Array{UInt8, 3}
  targets => 13180-element Vector{String}
```
"""
struct Omniglot <: SupervisedDataset
    metadata::Dict{String, Any}
    split::Symbol
    features::Array{<:Any,3}
    targets::Vector{String}
end
104+
105+
# Keyword-only form: forwards to the positional `Omniglot(Tx, split; dir)` method.
function Omniglot(; split=:train, Tx=Float32, dir=nothing)
    return Omniglot(Tx, split; dir=dir)
end

# Positional split only; any further keywords are passed through to the keyword form.
function Omniglot(split::Symbol; kws...)
    return Omniglot(; split=split, kws...)
end

# Positional element type only; any further keywords are passed through to the keyword form.
function Omniglot(Tx::Type; kws...)
    return Omniglot(; Tx=Tx, kws...)
end
108+
109+
# Load one Omniglot partition from its MATLAB file and assemble the dataset.
#
# Arguments:
# - `Tx`: element type of the returned `features` array.
# - `split`: one of `:train`, `:test`, `:small1`, `:small2`.
# - `dir`: forwarded unchanged to `datafile` when resolving the file location.
#
# Returns an `Omniglot` whose `features` is a 105×105×N array and whose `targets`
# holds, for every image, the entry of `names` belonging to the image's alphabet.
#
# Throws an `AssertionError` when `split` is not one of the supported partitions.
function Omniglot(Tx::Type, split::Symbol; dir=nothing)
    @assert split in [:train, :test, :small1, :small2]

    # Each partition lives in its own MATLAB file.
    IMAGESPATH = if split === :train
        "data_background.mat"
    elseif split === :test
        "data_evaluation.mat"
    elseif split === :small1
        "data_background_small1.mat"
    else  # split === :small2, guaranteed by the assertion above
        "data_background_small2.mat"
    end

    filename = datafile("Omniglot", IMAGESPATH, dir)

    # Close the file handle even if one of the reads throws.
    file = MAT.matopen(filename)
    images, names = try
        MAT.read(file, "images"), MAT.read(file, "names")
    finally
        MAT.close(file)
    end

    # `images` is nested as alphabet -> character -> variation; count the
    # variations first so the output arrays can be allocated once.
    image_count = 0
    for alphabet in images, character in alphabet
        image_count += length(character)
    end

    features = Array{Tx,3}(undef, 105, 105, image_count)
    targets = Vector{String}(undef, image_count)

    # Flatten the nesting in storage order; every image is labelled with the
    # name stored at the same index as its alphabet.
    curr_idx = 1
    for i in eachindex(images)
        alphabet_name = names[i]
        for character in images[i], variation in character
            targets[curr_idx] = alphabet_name
            features[:, :, curr_idx] = variation
            curr_idx += 1
        end
    end

    metadata = Dict{String,Any}()
    metadata["n_observations"] = size(features)[end]
    # Images and names come from the same file, so both paths are identical.
    metadata["features_path"] = IMAGESPATH
    metadata["targets_path"] = IMAGESPATH

    return Omniglot(metadata, split, features, targets)
end
157+
158+
159+
160+
# Integer-valued features: convert to `UInt8`, reinterpret the bytes as `N0f8`,
# then delegate to the generic `convert2image` method.
# NOTE(review): reinterpreting presumably maps the byte value 1 to 1/255 rather
# than to full intensity; with 0/1 features that would render almost black —
# confirm this is intended.
function convert2image(::Type{<:Omniglot}, x::AbstractArray{<:Integer})
    bytes = convert(Array{UInt8}, x)
    return convert2image(Omniglot, reinterpret(N0f8, bytes))
end
162+
163+
# Generic conversion of a 2D image or a 3D stack of images to a `Gray` color view.
# The first two dimensions are swapped before the view is created; a third
# dimension, when present, keeps its position.
function convert2image(::Type{<:Omniglot}, x::AbstractArray{T,N}) where {T,N}
    @assert N == 2 || N == 3
    perm = (2, 1, 3:N...)
    swapped = permutedims(x, perm)
    return ImageShow.ImageCore.colorview(ImageShow.ImageCore.Gray, swapped)
end

test/datasets/vision/omniglot.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Expected size of a single observation and number of target entries per observation.
n_features = (105, 105)
n_targets = 1

@testset "trainset" begin
    # The default constructor is expected to load the training partition.
    ds = Omniglot()

    @test ds.split == :train
    @test extrema(ds.features) == (0, 1)
    @test convert2image(ds, 1) isa AbstractMatrix{<:Gray}
    @test convert2image(ds, 1:2) isa AbstractArray{<:Gray, 3}

    test_supervised_array_dataset(
        ds;
        n_features=n_features,
        n_targets=n_targets,
        n_obs=19280,
        Tx=Float32,
        Ty=String,
        conv2img=true,
    )

    # Positional constructors: split only, then element type only.
    ds = Omniglot(:train)
    @test ds.split == :train
    ds = Omniglot(Int)
    @test ds.split == :train
    @test ds.features isa Array{Int}
end

@testset "testset" begin
    ds = Omniglot(split=:test, Tx=UInt8)

    @test ds.split == :test
    @test extrema(ds.features) == (0, 1)
    @test convert2image(ds, 1) isa AbstractMatrix{<:Gray}

    test_supervised_array_dataset(
        ds;
        n_features=n_features,
        n_targets=n_targets,
        n_obs=13180,
        Tx=UInt8,
        Ty=String,
        conv2img=true,
    )

    # Positional constructors: split only, then element type plus split.
    ds = Omniglot(:test)
    @test ds.split == :test
    ds = Omniglot(Int, :test)
    @test ds.split == :test
    @test ds.features isa Array{Int}
end

@testset "small1set" begin
    ds = Omniglot(split=:small1, Tx=Float32)

    @test ds.split == :small1
    @test extrema(ds.features) == (0, 1)
    @test convert2image(ds, 1) isa AbstractMatrix{<:Gray}

    test_supervised_array_dataset(
        ds;
        n_features=n_features,
        n_targets=n_targets,
        n_obs=2720,
        Tx=Float32,
        Ty=String,
        conv2img=true,
    )

    # Positional constructors: split only, then element type plus split.
    ds = Omniglot(:small1)
    @test ds.split == :small1
    ds = Omniglot(Int, :small1)
    @test ds.split == :small1
    @test ds.features isa Array{Int}
end

@testset "small2set" begin
    ds = Omniglot(split=:small2, Tx=UInt8)

    @test ds.split == :small2
    @test extrema(ds.features) == (0, 1)
    @test convert2image(ds, 1) isa AbstractMatrix{<:Gray}

    test_supervised_array_dataset(
        ds;
        n_features=n_features,
        n_targets=n_targets,
        n_obs=3120,
        Tx=UInt8,
        Ty=String,
        conv2img=true,
    )

    # Positional constructors: split only, then element type plus split.
    ds = Omniglot(:small2)
    @test ds.split == :small2
    ds = Omniglot(Int, :small2)
    @test ds.split == :small2
    @test ds.features isa Array{Int}
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ no_ci_dataset_tests = [
2525
"datasets/vision/cifar10.jl",
2626
"datasets/vision/cifar100.jl",
2727
"datasets/vision/emnist.jl",
28+
"datasets/vision/omniglot.jl",
2829
"datasets/vision/svhn2.jl",
2930
]
3031

0 commit comments

Comments
 (0)