Skip to content

Commit 6a3bd21

Browse files
fix datasetdict
1 parent 861a2f6 commit 6a3bd21

File tree

7 files changed

+12
-7
lines changed

7 files changed

+12
-7
lines changed

CondaPkg.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ channels = ["conda-forge"]
33
[deps]
44
# h5py = ""
55
# pillow = ">=9.1, <10"
6-
# numpy = ">=1.20, <2"
7-
pillow = ""
86
datasets = ">=2.12, <3"
7+
numpy = ">=1.20, <2"
8+
pillow = ""
99

1010
# pyarrow = "==6.0.0"

src/dataset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ version of [`with_format`](@ref).
107107
"""
108108
function set_format!(ds::Dataset, format)
109109
if format == "julia"
110-
# ds.pyds.set_format("numpy")
110+
ds.pyds.reset_format() # or d.pyd.set_format("python")
111111
ds.jltransform = py2jl
112112
else
113113
ds.pyds.set_format(format)

src/datasetdict.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ version of [`with_format`](@ref).
102102
"""
103103
function set_format!(d::DatasetDict, format)
104104
if format == "julia"
105-
d.pyd.set_format("numpy")
105+
d.pyd.reset_format()
106106
d.jltransform = py2jl
107107
else
108108
d.pyd.set_format(format)

src/transforms.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ function _pyconvert(x::Py)
2121
end
2222
end
2323

24+
# Do nothing on a non-Py object.
2425
_pyconvert(x) = x
2526

2627
"""
@@ -30,6 +31,7 @@ Convert Python types to Julia types applying `pyconvert` recursively.
3031
"""
3132
py2jl
3233

34+
# py2jl recurses through pycanonicalize and converts through _pyconvert
3335
py2jl(x) = pycanonicalize(_pyconvert(x))
3436

3537
pycanonicalize(x) = x

test.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
using Test
2+
tmnist = load_dataset("mnist", split="test").with_format("julia")
3+
@test size(tmnist[1]["image"]) == (28, 28)
4+

test/dataset.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ end
7474
@test x isa Dict
7575
@test x["label"] == -1
7676
@test x["idx"] == 0
77-
@show x["premise"] |> typeof
7877
@test x["premise"] isa AbstractString
7978
@test x["premise"] == "The cat sat on the mat."
8079
@test x["hypothesis"] isa AbstractString

test/datasetdict.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ end
1919
@testset "with_format(julia)" begin
2020
d = with_format(mnist, "julia")
2121
ds = d["test"]
22-
@test ds.format["type"] == "numpy"
22+
@test ds.format["type"] == nothing
2323
x = ds[1]
2424
@test x isa Dict
2525
@test x["label"] isa Int
2626
@test x["label"] == 7
27-
@test x["image"] isa Matrix{UInt8}
27+
@test x["image"] isa AbstractMatrix{<:Gray}
2828
@test size(x["image"]) == (28, 28)
2929
end
3030

0 commit comments

Comments
 (0)