Skip to content

Commit 32d9a3d

Browse files
cleanup
1 parent fa6b9cc commit 32d9a3d

File tree

6 files changed

+45
-68
lines changed

6 files changed

+45
-68
lines changed

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ julia> length(train_data)
3939

4040
julia> train_data = load_dataset("mnist", split = "train").with_format("julia");
4141

42-
# Returned observations are now julia objects
43-
44-
julia> train_data[1]
42+
julia> train_data[1] # Returned observations are now julia objects
4543
Dict{String, Any} with 2 entries:
4644
"label" => 5
4745
"image" => Gray{N0f8}[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; ; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ makedocs(;
1515
),
1616
pages=[
1717
"Home" => "index.md",
18+
"API" => "api.md",
1819
],
1920
)
2021

docs/src/api.md

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
# API
2-
3-
## Index
4-
5-
```@index
6-
Pages = ["api.md"]
1+
```@meta
2+
CurrentModule = HuggingFaceDatasets
3+
CollapsedDocStrings = true
74
```
85

9-
## Docs
6+
# API
107

118
```@autodocs
129
Modules = [HuggingFaceDatasets]

docs/src/index.md

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,30 @@ HuggingFaceDatasets.jl provides wrappers around types from the `datasets` python
2626
Check out the `examples/` folder for usage examples.
2727

2828
```julia
29+
# Returned observations are now julia objects
30+
julia> using HuggingFaceDatasets
31+
2932
julia> train_data = load_dataset("mnist", split = "train")
30-
Dataset(<py Dataset({
33+
Dataset({
3134
features: ['image', 'label'],
3235
num_rows: 60000
33-
})>, identity)
36+
})
3437

35-
# Indexing starts with 1.
36-
# By default, python types are returned.
3738
julia> train_data[1]
38-
Python dict: {'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x2B64E2E90>, 'label': 5}
39+
Python: {'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x3340B0290>, 'label': 5}
3940

40-
julia> set_format!(train_data, "julia")
41-
Dataset(<py Dataset({
42-
features: ['image', 'label'],
43-
num_rows: 60000
44-
})>, HuggingFaceDatasets.py2jl)
41+
julia> length(train_data)
42+
60000
4543

46-
# Now we have julia types
47-
julia> train_data[1]
44+
julia> train_data = load_dataset("mnist", split = "train").with_format("julia");
45+
46+
julia> train_data[1] # Returned observations are now julia objects
4847
Dict{String, Any} with 2 entries:
4948
"label" => 5
50-
"image" => UInt8[0x00 0x00 0x00 0x00; 0x00 0x00 0x00 0x00; ; 0x00 0x00 0x00 0x00; 0x00 0x00 0x00 0x00]
49+
"image" => Gray{N0f8}[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; ; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
50+
51+
julia> train_data[1:2]
52+
Dict{String, Vector} with 2 entries:
53+
"label" => [5, 0]
54+
"image" => ReinterpretArray{Gray{N0f8}, 2, UInt8, Matrix{UInt8}, false}[[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; ; 0
5155
```

src/HuggingFaceDatasets.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module HuggingFaceDatasets
33
using PythonCall
44
using MLUtils: getobs, numobs
55
import MLUtils
6-
using DLPack
6+
using DLPack: DLPack
77
using ImageCore
88

99
const datasets = PythonCall.pynew()

src/transforms.jl

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,11 @@
11

2-
# # See https://github.com/cjdoris/PythonCall.jl/issues/172.
3-
# function _pyconvert(x::Py)
4-
# @show x
5-
# if pyisinstance(x, datasets.Dataset)
6-
# return Dataset(x)
7-
# elseif pyisinstance(x, datasets.DatasetDict)
8-
# return DatasetDict(x)
9-
# elseif pyisinstance(x, PIL.PngImagePlugin.PngImageFile) || pyisinstance(x, PIL.JpegImagePlugin.JpegImageFile)
10-
# @show x
11-
# a = numpy2jl(np.array(x))
12-
# if ndims(a) == 3 && size(a, 1) == 3
13-
# return colorview(RGB{N0f8}, a)
14-
# elseif ndims(a) == 2
15-
# return reinterpret(Gray{N0f8}, a)
16-
# else
17-
# error("Unknown image format")
18-
# end
19-
# elseif pyisinstance(x, np.ndarray)
20-
# return numpy2jl(x)
21-
# else
22-
# return pyconvert(Any, x)
23-
# end
24-
# end
25-
26-
# # # Do nothing on a non-Py object.
27-
# # _pyconvert(x) = x
28-
292
"""
303
py2jl(x)
314
32-
Convert Python types to Julia types applying `pyconvert` recursively.
5+
Convert Python types to Julia types. It will recursively traverse built-in Python
6+
containers such as lists, tuples, dicts, and sets, and convert all nested objects.
7+
On the leaves, it will call either `pyconvert(Any, x)` or [`numpy2jl`](@ref).
338
"""
34-
py2jl
35-
36-
# py2jl recurses through pycanonicalize and converts through _pyconvert
379
py2jl(x) = pyconvert(Any, x)
3810

3911
function py2jl(x::Py)
@@ -74,21 +46,26 @@ end
7446
"""
7547
numpy2jl(x)
7648
77-
Convert a numpy array to a Julia array using DLPack.
49+
Convert a numpy array to a Julia array using DLPack.jl.
7850
The conversion is copyless, and mutations to the Julia array are reflected in the numpy array.
51+
For row-major Python arrays, the returned Julia array has permuted dimensions.
52+
53+
This function is called by [`py2jl`](@ref).
54+
See also [`jl2numpy`](@ref).
7955
"""
8056
function numpy2jl(x::Py)
81-
# pyconvert(Any, x)
82-
# PyArray(x, copy=false)
83-
if Bool(x.dtype.type == np.str_)
84-
return PyArray(x, copy=false)
85-
else
86-
return DLPack.wrap(x, x -> x.__dlpack__())
87-
end
57+
return DLPack.from_dlpack(x)
8858
end
8959

90-
## TODO this doesn't work yet.
91-
## https://github.com/pabloferz/DLPack.jl/issues/32
92-
# function jl2numpy(x::AbstractArray)
93-
# return DLPack.share(x, np.from_dlpack)
94-
# end
60+
"""
61+
jl2numpy(x)
62+
63+
Convert a Julia array to a numpy array using DLPack.jl.
64+
The conversion is copyless, and mutations to the numpy array are reflected in the Julia array.
65+
The returned numpy array has permuted dimensions with respect to the input Julia array.
66+
67+
See also [`numpy2jl`](@ref).
68+
"""
69+
function jl2numpy(x::AbstractArray)
70+
return DLPack.share(x, np.from_dlpack)
71+
end

0 commit comments

Comments
 (0)