|
| 1 | +function __init__omniglot() |
| 2 | + DEPNAME = "Omniglot" |
| 3 | + TRAIN = "data_background.mat" |
| 4 | + TEST = "data_evaluation.mat" |
| 5 | + SMALL1 = "data_background_small1.mat" |
| 6 | + SMALL2 = "data_background_small2.mat" |
| 7 | + |
| 8 | + register(DataDep( |
| 9 | + DEPNAME, |
| 10 | + """ |
| 11 | + Dataset: Omniglot data set for one-shot learning |
| 12 | + Authors: Brenden M. Lake, Ruslan Salakhutdinov, Joshua B. Tenenbaum |
| 13 | + Website: https://github.com/brendenlake/omniglot |
| 14 | +
|
| 15 | + [Lake et al., 2015] |
| 16 | + Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). |
| 17 | + Human-level concept learning through probabilistic program induction. |
| 18 | + Science, 350(6266), 1332-1338. |
| 19 | +
|
| 20 | + The files are available for download at the official |
| 21 | + github repository linked above. Note that using the data |
| 22 | + responsibly and respecting copyright remains your |
| 23 | + responsibility. The authors of Omniglot aren't really |
| 24 | + explicit about any terms of use, so please read the |
| 25 | + website to make sure you want to download the dataset. |
| 26 | + """, |
| 27 | + "https://github.com/brendenlake/omniglot/raw/master/matlab/" .* [TRAIN, TEST, SMALL1, SMALL2], |
| 28 | + "1cfb52d931d794a3dd2717da6c80ddb8f48b0a6f559916c60fcdcd908b65d3d8", #matlab links |
| 29 | + )) |
| 30 | +end |
| 31 | + |
| 32 | + |
| 33 | +""" |
| 34 | + Omniglot(; Tx=Float32, split=:train, dir=nothing) |
| 35 | + Omniglot([Tx, split]) |
| 36 | +
|
| 37 | +Omniglot data set for one-shot learning |
| 38 | +
|
| 39 | +- Authors: Brenden M. Lake, Ruslan Salakhutdinov, Joshua B. Tenenbaum |
| 40 | +- Website: https://github.com/brendenlake/omniglot |
| 41 | +
|
| 42 | +The Omniglot data set is designed for developing more human-like learning |
| 43 | +algorithms. It contains 1623 different handwritten characters from 50 different |
| 44 | +alphabets. Each of the 1623 characters was drawn online via Amazon's |
| 45 | +Mechanical Turk by 20 different people. Each image is paired with stroke data, a |
| 46 | +sequences of [x,y,t] coordinates with time (t) in milliseconds. |
| 47 | +
|
| 48 | +# Arguments |
| 49 | +
|
| 50 | +$ARGUMENTS_SUPERVISED_ARRAY |
| 51 | +- `split`: selects the data partition. Can take the values `:train`, `:test`, `:small1`, or `:small2`. |
| 52 | +
|
| 53 | +# Fields |
| 54 | +
|
| 55 | +$FIELDS_SUPERVISED_ARRAY |
| 56 | +- `split`. |
| 57 | +
|
| 58 | +# Methods |
| 59 | +
|
| 60 | +$METHODS_SUPERVISED_ARRAY |
| 61 | +- [`convert2image`](@ref) converts features to `Gray` images. |
| 62 | +
|
| 63 | +# Examples |
| 64 | +
|
| 65 | +The images are loaded as a multi-dimensional array of eltype `Tx`. |
| 66 | +All values will be `0` or `1`. `Omniglot().features` is a 3D array |
| 67 | +(i.e. a `Array{Tx,3}`), in WHN format (width, height, num_images). |
| 68 | +Labels are stored as a vector of strings in `Omniglot().targets`. |
| 69 | +
|
| 70 | +```julia-repl |
| 71 | +julia> using MLDatasets: Omniglot |
| 72 | +
|
| 73 | +julia> dataset = Omniglot(:train) |
| 74 | +Omniglot: |
| 75 | + metadata => Dict{String, Any} with 3 entries |
| 76 | + split => :train |
| 77 | + features => 105×105×19280 Array{Float32, 3} |
| 78 | + targets => 19280-element Vector{Int64} |
| 79 | +
|
| 80 | +julia> dataset[1:5].targets |
| 81 | +5-element Vector{String}: |
| 82 | + "Arcadian" |
| 83 | + "Arcadian" |
| 84 | + "Arcadian" |
| 85 | + "Arcadian" |
| 86 | + "Arcadian" |
| 87 | +
|
| 88 | +julia> X, y = dataset[:]; |
| 89 | +
|
| 90 | +julia> dataset = Omniglot(UInt8, :test) |
| 91 | +Omniglot: |
| 92 | + metadata => Dict{String, Any} with 3 entries |
| 93 | + split => :test |
| 94 | + features => 105×105×13180 Array{UInt8, 3} |
| 95 | + targets => 13180-element Vector{Int64} |
| 96 | +``` |
| 97 | +""" |
| 98 | +struct Omniglot <: SupervisedDataset |
| 99 | + metadata::Dict{String, Any} |
| 100 | + split::Symbol |
| 101 | + features::Array{<:Any,3} |
| 102 | + targets::Vector{String} |
| 103 | +end |
| 104 | + |
| 105 | +Omniglot(; split=:train, Tx=Float32, dir=nothing) = Omniglot(Tx, split; dir) |
| 106 | +Omniglot(split::Symbol; kws...) = Omniglot(; split, kws...) |
| 107 | +Omniglot(Tx::Type; kws...) = Omniglot(; Tx, kws...) |
| 108 | + |
| 109 | +function Omniglot(Tx::Type, split::Symbol; dir=nothing) |
| 110 | + @assert split ∈ [:train, :test, :small1, :small2] |
| 111 | + if split === :train |
| 112 | + IMAGESPATH = "data_background.mat" |
| 113 | + elseif split === :test |
| 114 | + IMAGESPATH = "data_evaluation.mat" |
| 115 | + elseif split === :small1 |
| 116 | + IMAGESPATH = "data_background_small1.mat" |
| 117 | + elseif split === :small2 |
| 118 | + IMAGESPATH = "data_background_small2.mat" |
| 119 | + end |
| 120 | + |
| 121 | + filename = datafile("Omniglot", IMAGESPATH, dir) |
| 122 | + |
| 123 | + file = MAT.matopen(filename) |
| 124 | + images = MAT.read(file, "images") |
| 125 | + names = MAT.read(file, "names") |
| 126 | + MAT.close(file) |
| 127 | + |
| 128 | + image_count = 0 |
| 129 | + for alphabet in images |
| 130 | + for character in alphabet |
| 131 | + image_count += length(character) |
| 132 | + end |
| 133 | + end |
| 134 | + |
| 135 | + features = Array{Tx}(undef, 105, 105, image_count) |
| 136 | + targets = Vector{String}(undef, image_count) |
| 137 | + |
| 138 | + curr_idx = 1 |
| 139 | + for i in range(1, length(images)) |
| 140 | + alphabet = images[i] |
| 141 | + for character in alphabet |
| 142 | + for variation in character |
| 143 | + targets[curr_idx] = names[i] |
| 144 | + features[:,:,curr_idx] = variation |
| 145 | + curr_idx += 1 |
| 146 | + end |
| 147 | + end |
| 148 | + end |
| 149 | + |
| 150 | + metadata = Dict{String,Any}() |
| 151 | + metadata["n_observations"] = size(features)[end] |
| 152 | + metadata["features_path"] = IMAGESPATH |
| 153 | + metadata["targets_path"] = IMAGESPATH |
| 154 | + |
| 155 | + return Omniglot(metadata, split, features, targets) |
| 156 | +end |
| 157 | + |
| 158 | + |
| 159 | + |
| 160 | +convert2image(::Type{<:Omniglot}, x::AbstractArray{<:Integer}) = |
| 161 | + convert2image(Omniglot, reinterpret(N0f8, convert(Array{UInt8}, x))) |
| 162 | + |
| 163 | +function convert2image(::Type{<:Omniglot}, x::AbstractArray{T,N}) where {T,N} |
| 164 | + @assert N == 2 || N == 3 |
| 165 | + x = permutedims(x, (2, 1, 3:N...)) |
| 166 | + ImageCore = ImageShow.ImageCore |
| 167 | + return ImageCore.colorview(ImageCore.Gray, x) |
| 168 | +end |
0 commit comments