Skip to content

Commit cbd1303

Browse files
committed
support None in cast_storage
1 parent 300da4a, commit cbd1303

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/datasets/features/features.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1024,7 +1024,9 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.IntegerArray]) -> pa.In
1024 1024
f"Class label {min_max['max']} greater than configured num_classes {self.num_classes}"
1025 1025
)
1026 1026
elif isinstance(storage, pa.StringArray):
1027-
storage = pa.array(self.str2int(storage.to_pylist()))
1027+
storage = pa.array(
1028+
[self._strval2int(label) if label is not None else None for label in storage.to_pylist()]
1029+
)
1028 1030
return array_cast(storage, self.pa_type)
1029 1031

1030 1032
@staticmethod

tests/features/test_features.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -328,6 +328,11 @@ def test_classlabel_cast_storage():
328 328
arr = pa.array(["__label_that_doesnt_exist__"])
329 329
with pytest.raises(ValueError):
330 330
classlabel.cast_storage(arr)
331+
# from nulls
332+
arr = pa.array([None])
333+
result = classlabel.cast_storage(arr)
334+
assert result.type == pa.int64()
335+
assert result.to_pylist() == [None]
331 336
# from empty
332 337
arr = pa.array([])
333 338
result = classlabel.cast_storage(arr)

0 commit comments

Comments
 (0)