Skip to content

Commit 50d4ee0

Browse files
remove numpy format when with_format("julia") (#15)
* work in prog * readme * fix docstring * no ci * cleanup
1 parent 6fecb5c commit 50d4ee0

File tree

15 files changed

+200
-93
lines changed

15 files changed

+200
-93
lines changed

CondaPkg.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
channels = ["conda-forge"]
22

33
[deps]
4+
h5py = ""
45
pillow = ">=9.1, <10"
56
numpy = ">=1.20, <2"
67
datasets = ">=2.7, <3"

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
name = "HuggingFaceDatasets"
22
uuid = "d94b9a45-fdf5-4270-b024-5cbb9ef7117d"
33
authors = ["Carlo Lucibello"]
4-
version = "0.2.1"
4+
version = "0.3.0"
55

66
[deps]
77
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
88
DLPack = "53c2dc0f-f7d5-43fd-8906-6c0220547083"
9+
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
910
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1011
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
1112

1213
[compat]
1314
CondaPkg = "0.2"
1415
DLPack = "0.1"
15-
MLUtils = "0.2, 0.3, 0.4"
16-
PythonCall = "0.8, 0.9"
16+
ImageCore = "0.9"
17+
MLUtils = "0.4.1"
18+
PythonCall = "0.9"
1719
julia = "1.7"
1820

1921
[extras]

README.md

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,30 @@ Check out the `examples/` folder for usage examples.
2525

2626
```julia
2727
julia> train_data = load_dataset("mnist", split = "train")
28-
Dataset(<py Dataset({
28+
Dataset({
2929
features: ['image', 'label'],
3030
num_rows: 60000
31-
})>, identity)
31+
})
3232

3333
# Indexing starts with 1.
34-
# By defaul, python types are returned.
34+
# Python types are returned by default.
3535
julia> train_data[1]
3636
Python dict: {'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x2B64E2E90>, 'label': 5}
3737

38-
julia> set_format!(train_data, "julia")
39-
Dataset(<py Dataset({
40-
features: ['image', 'label'],
41-
num_rows: 60000
42-
})>, HuggingFaceDatasets.py2jl)
38+
julia> length(train_data)
39+
60000
4340

44-
# Now we have julia types
41+
# Now we set the julia format
42+
julia> train_data = load_dataset("mnist", split = "train").with_format("julia");
43+
44+
# Returned observations are julia objects
4545
julia> train_data[1]
4646
Dict{String, Any} with 2 entries:
4747
"label" => 5
48-
"image" => UInt8[0x00 0x00 0x00 0x00; 0x00 0x00 0x00 0x00; ; 0x00 0x00 0x00 0x00; 0x00 0x00 0x00 0x00]
48+
"image" => ColorTypes.Gray{FixedPointNumbers.N0f8}[Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0); Gray{N
49+
50+
julia> train_data[1:2]
51+
Dict{String, Vector} with 2 entries:
52+
"label" => [5, 0]
53+
"image" => Base.ReinterpretArray{Gray{N0f8}, 2, UInt8, Matrix{UInt8}, false}[[Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gra
4954
```

examples/flux_mnist.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,23 @@ using Random, Statistics
33
using Flux.Losses: logitcrossentropy
44
using Flux: onecold
55
using HuggingFaceDatasets
6+
using MLUtils
7+
using ImageCore
68
# using ProfileView, BenchmarkTools
79

8-
function mnist_transform(x)
9-
x = py2jl(x)
10-
image = x["image"] ./ 255f0
11-
label = Flux.onehotbatch(x["label"], 0:9)
10+
function mnist_transform(batch)
11+
image = ImageCore.channelview.(batch["image"]) # from Matrix{Gray{N0f8}} to Matrix{UInt8}
12+
image = Flux.batch(image) ./ 255f0
13+
label = Flux.onehotbatch(batch["label"], 0:9)
1214
return (; image, label)
1315
end
1416

17+
# Remove when https://github.com/JuliaML/MLUtils.jl/pull/147 is merged and tagged
18+
Base.getindex(data::MLUtils.MappedData, idx::Int) = getobs(data.f(getobs(data.data, [idx])), 1)
19+
Base.getindex(data::MLUtils.MappedData, idxs::AbstractVector) = data.f(getobs(data.data, idxs))
20+
Base.getindex(data::MLUtils.MappedData, ::Colon) = data[1:length(data.data)]
21+
22+
1523
function loss_and_accuracy(data_loader, model, device)
1624
acc = 0
1725
ls = 0.0f0
@@ -29,18 +37,16 @@ end
2937
function train(epochs)
3038
batchsize = 128
3139
nhidden = 100
32-
device = gpu
33-
34-
dataset = load_dataset("mnist")
35-
set_format!(dataset, "julia")
36-
set_jltransform!(dataset, mnist_transform)
37-
38-
# We use [:] to materialize and transform the whole dataset.
39-
# This gives much faster iterations.
40-
# Omit the [:] if you don't want to load the whole dataset in-memory.
41-
train_loader = Flux.DataLoader(dataset["train"][:]; batchsize, shuffle=true)
42-
test_loader = Flux.DataLoader(dataset["test"][:]; batchsize)
40+
device = cpu
4341

42+
train_data = load_dataset("mnist", split="train").with_format("julia")
43+
test_data = load_dataset("mnist", split="test").with_format("julia")
44+
train_data = mapobs(mnist_transform, train_data)[:] # lazy apply transform then materialize
45+
test_data = mapobs(mnist_transform, test_data)[:]
46+
47+
train_loader = Flux.DataLoader(train_data; batchsize, shuffle=true)
48+
test_loader = Flux.DataLoader(test_data; batchsize)
49+
4450
model = Chain([Flux.flatten,
4551
Dense(28*28, nhidden, relu),
4652
Dense(nhidden, nhidden, relu),
@@ -57,7 +63,7 @@ function train(epochs)
5763
end
5864

5965
report(0)
60-
for epoch in 1:epochs
66+
@time for epoch in 1:epochs
6167
for (x, y) in train_loader
6268
x, y = x |> device, y |> device
6369
loss, grads = withgradient(model -> logitcrossentropy(model(x), y), model)

perf/perf.jl

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,61 @@
11
using HuggingFaceDatasets
22
using BenchmarkTools
3+
using MLDatasets
34

45
function f(ds)
5-
for i in 1:6000
6-
ds[i]
6+
for i in 1:numobs(ds)
7+
getobs(ds, i)
78
end
89
end
10+
function fbatch(ds)
11+
for i in 1:128:numobs(ds)-128
12+
getobs(ds, i:i+127)
13+
end
14+
end
15+
function fall(ds)
16+
getobs(ds, :)
17+
end
918

10-
ds_plain = load_dataset("mnist", split="train")
11-
@btime f(ds_plain)
12-
13-
14-
ds_julia = with_format(ds_plain, "julia")
15-
@btime f(ds_plain)
16-
17-
18-
ds_py2jl = with_jltransform(ds_plain, py2jl)
19-
20-
set_transform!(ds2, py2jl)
21-
22-
@btime f1(ds2)
23-
24-
ds2[1]["image"]
25-
26-
set_transform!(ds2, identity)
27-
@time ds2[1:10000]["image"];
28-
@time Flux.batch(ds2[1:10000]["image"])
29-
30-
@time Flux.batch(ds2[1:10000]["image"] |> py2jl)
31-
32-
ds[1]
33-
34-
ds2["label"]
35-
36-
ds2.set_format("numpy")
37-
ds2[1:10] |> py2jl
38-
39-
ds2["image"]
40-
41-
#####
42-
function set_jltransform!(ds, transform = identity)
43-
ds.pyset_format("numpy")
44-
ds.jltransform(transform)
19+
function bench()
20+
mld = MNIST(split=:test)
21+
ds_plain = load_dataset("mnist", split="test")
22+
ds_julia = with_format(ds_plain, "julia")
23+
ds_numpy = with_format(ds_plain, "numpy")
24+
ds_jnumpy = with_jltransform(py2jl, ds_numpy) # numpy + py2jl
25+
26+
for (name, ds) in [("mldatasets", mld),
27+
("plain", ds_plain),
28+
("julia", ds_julia),
29+
("numpy", ds_numpy),
30+
("jnumpy", ds_jnumpy)]
31+
println("# $name")
32+
@btime f($ds)
33+
@btime fbatch($ds)
34+
@btime fall($ds)
35+
end
4536
end
37+
38+
# HuggingFace datasets is slow at reading image datasets.
39+
# Pytorch vision is much faster (see the notebook in perf/)
40+
41+
bench()
42+
# # MLDatasets
43+
# 19.515 ms (120005 allocations: 34.64 MiB)
44+
# 4.671 ms (1097 allocations: 29.97 MiB)
45+
# 717.324 ns (6 allocations: 240 bytes)
46+
# # plain
47+
# 602.001 ms (668464 allocations: 18.06 MiB)
48+
# 266.483 ms (390 allocations: 6.09 KiB)
49+
# 265.651 ms (5 allocations: 80 bytes)
50+
# # julia
51+
# 985.251 ms (2398464 allocations: 93.28 MiB)
52+
# 379.270 ms (659256 allocations: 27.31 MiB)
53+
# 378.751 ms (650134 allocations: 27.01 MiB)
54+
# # numpy
55+
# 1.264 s (728464 allocations: 19.13 MiB)
56+
# 311.426 ms (390 allocations: 6.09 KiB)
57+
# 318.403 ms (5 allocations: 80 bytes)
58+
# # jnumpy
59+
# 1.527 s (2208464 allocations: 110.91 MiB)
60+
# 318.356 ms (13962 allocations: 637.41 KiB)
61+
# 335.109 ms (179 allocations: 8.17 KiB)

src/HuggingFaceDatasets.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using PythonCall
44
using MLUtils: getobs, numobs
55
import MLUtils
66
using DLPack
7+
using ImageCore
78

89
const datasets = PythonCall.pynew()
910
const PIL = PythonCall.pynew()

src/dataset.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ end
2828
function Base.getproperty(ds::Dataset, s::Symbol)
2929
if s in fieldnames(Dataset)
3030
return getfield(ds, s)
31+
elseif s === :with_format
32+
return format -> with_format(ds, format)
3133
else
3234
res = getproperty(getfield(ds, :pyds), s)
3335
if pycallable(res)
@@ -44,7 +46,7 @@ Base.getindex(ds::Dataset, ::Colon) = ds[1:length(ds)]
4446

4547
function Base.getindex(ds::Dataset, i::AbstractVector{<:Integer})
4648
@assert all(>(0), i)
47-
x = ds.pyds[i .- 1]
49+
x = getfield(ds, :pyds)[i .- 1]
4850
return ds.jltransform(x)
4951
end
5052

@@ -64,6 +66,8 @@ function Base.deepcopy(ds::Dataset)
6466
return Dataset(pyds, ds.jltransform)
6567
end
6668

69+
Base.show(io::IO, ds::Dataset) = print(io, ds.pyds)
70+
6771
"""
6872
with_format(ds::Dataset, format)
6973
@@ -103,7 +107,7 @@ version of [`with_format`](@ref).
103107
"""
104108
function set_format!(ds::Dataset, format)
105109
if format == "julia"
106-
ds.pyds.set_format("numpy")
110+
# ds.pyds.set_format("numpy")
107111
ds.jltransform = py2jl
108112
else
109113
ds.pyds.set_format(format)

src/datasetdict.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ end
2323
function Base.getproperty(d::DatasetDict, s::Symbol)
2424
if s in fieldnames(DatasetDict)
2525
return getfield(d, s)
26+
elseif s === :with_format
27+
return format -> with_format(d, format)
2628
else
2729
res = getproperty(getfield(d, :pyd), s)
2830
if pycallable(res)
@@ -44,6 +46,9 @@ function Base.deepcopy(d::DatasetDict)
4446
pyd = copy.deepcopy(d.pyd)
4547
return DatasetDict(pyd, d.jltransform)
4648
end
49+
50+
Base.show(io::IO, ds::DatasetDict) = print(io, ds.pyd)
51+
4752
""""
4853
with_jltransform(d::DatasetDict, transform)
4954
with_jltransform(transform, d::DatasetDict)

src/load_dataset.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11

22
"""
3-
load_dataset(args...; transform=py2jl, kws...)
3+
load_dataset(args...; kws...)
44
55
Load a dataset from the [HuggingFace Datasets](https://huggingface.co/datasets) library.
66
77
All arguments are passed to the python function `datasets.load_dataset`.
88
See the documentation [here](https://huggingface.co/docs/datasets/package_reference/loading_methods.html#datasets.load_dataset).
99
10+
Returns a [`DatasetDict`](@ref) or a [`Dataset`](@ref) depending on the `split` argument.
11+
12+
Use the `dataset.with_format("julia")` method to lazily convert the observations from the dataset
13+
to julia types.
14+
1015
# Examples
1116
17+
Without a `split` argument, a `DatasetDict` is returned:
18+
1219
```julia
1320
julia> d = load_dataset("glue", "sst2")
14-
DatasetDict(<py DatasetDict({
21+
DatasetDict({
1522
train: Dataset({
1623
features: ['sentence', 'label', 'idx'],
1724
num_rows: 67349
@@ -24,26 +31,29 @@ DatasetDict(<py DatasetDict({
2431
features: ['sentence', 'label', 'idx'],
2532
num_rows: 1821
2633
})
27-
})>, HuggingFaceDatasets.py2jl)
34+
})
2835
2936
julia> d["train"]
30-
Dataset(<py Dataset({
37+
Dataset({
3138
features: ['sentence', 'label', 'idx'],
3239
num_rows: 67349
33-
})>, HuggingFaceDatasets.py2jl)
40+
})
41+
```
3442
35-
mnist = load_dataset("mnist", split="train")
43+
Selecting a split returns a `Dataset` instead. We also
44+
apply the `"julia"` format.
3645
37-
julia> mnist = load_dataset("mnist", split="train")
38-
Dataset(<py Dataset({
46+
```julia
47+
julia> mnist = load_dataset("mnist", split="train").with_format("julia")
48+
Dataset({
3949
features: ['image', 'label'],
4050
num_rows: 60000
41-
})>, HuggingFaceDatasets.py2jl)
51+
})
4252
4353
julia> mnist[1]
4454
Dict{String, Any} with 2 entries:
4555
"label" => 5
46-
"image" => UInt8[0x00 0x000x00 0x00; 0x00 0x000x00 0x00; … ; 0x00 0x00 … 0x00 0x00; 0x00 0x00 … 0x00 0x00]
56+
"image" => Gray{N0f8}[Gray{N0f8}(0.0) Gray{N0f8}(0.0)Gray{N0f8}(0.0) Gray{N0f8}(0.0); Gray{N0f8}(0.0) Gray{N0f8}(0.0)Gray{N0f…
4757
```
4858
"""
4959
function load_dataset(args...; kws...)

src/observation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function MLUtils.getobs(py::Py, i::Integer)
55
elseif pyisinstance(py, pytype(pylist()))
66
# TODO do this only for lists containing numbers
77
return py[i-1]
8-
elseif pyisinstance(xpy, np.ndarray)
8+
elseif pyisinstance(py, np.ndarray)
99
return py[i-1]
1010
else
1111
return error("Py type $(pytype(py)) non supported yet")

0 commit comments

Comments
 (0)