From 300da4a4bbeb9fa6a6c3ad441b92cffb92fd3515 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 16 Jun 2022 11:57:08 +0200 Subject: [PATCH 1/2] support all negative values in ClassLabel --- src/datasets/features/features.py | 17 ++++++++------- tests/features/test_features.py | 36 +++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 3e8301df08b..6efaff65f8b 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -854,6 +854,9 @@ class ClassLabel: * `names`: List of label strings. * `names_file`: File containing the list of labels. + Under the hood the labels are stored as integers. + You can use negative integers to represent unknown/missing labels. + Args: num_classes (:obj:`int`, optional): Number of classes. All labels must be < `num_classes`. names (:obj:`list` of :obj:`str`, optional): String names for the integer classes. @@ -910,7 +913,7 @@ def __post_init__(self, names_file): def __call__(self): return self.pa_type - def str2int(self, values: Union[str, Iterable]): + def str2int(self, values: Union[str, Iterable]) -> Union[int, Iterable]: """Conversion class name string => integer. Example: @@ -934,7 +937,7 @@ def str2int(self, values: Union[str, Iterable]): output = [self._strval2int(value) for value in values] return output if return_list else output[0] - def _strval2int(self, value: str): + def _strval2int(self, value: str) -> int: failed_parse = False value = str(value) # first attempt - raw string value @@ -955,9 +958,11 @@ def _strval2int(self, value: str): raise ValueError(f"Invalid string class label {value}") return int_value - def int2str(self, values: Union[int, Iterable]): + def int2str(self, values: Union[int, Iterable]) -> Union[str, Iterable]: """Conversion integer => class name string. + Regarding unknown/missing labels: passing negative integers raises ValueError. 
+ Example: ```py @@ -1014,16 +1019,12 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.IntegerArray]) -> pa.In """ if isinstance(storage, pa.IntegerArray): min_max = pc.min_max(storage).as_py() - if min_max["min"] < -1: - raise ValueError(f"Class label {min_max['min']} less than -1") if min_max["max"] >= self.num_classes: raise ValueError( f"Class label {min_max['max']} greater than configured num_classes {self.num_classes}" ) elif isinstance(storage, pa.StringArray): - storage = pa.array( - [self._strval2int(label) if label is not None else None for label in storage.to_pylist()] - ) + storage = pa.array(self.str2int(storage.to_pylist())) return array_cast(storage, self.pa_type) @staticmethod diff --git a/tests/features/test_features.py b/tests/features/test_features.py index a800ba3419f..cb4423ebc44 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -288,6 +288,8 @@ def test_classlabel_str2int(): classlabel.str2int("__bad_label_name__") with pytest.raises(ValueError): classlabel.str2int(1) + with pytest.raises(ValueError): + classlabel.str2int(None) def test_classlabel_int2str(): @@ -297,6 +299,40 @@ def test_classlabel_int2str(): assert classlabel.int2str(i) == names[i] with pytest.raises(ValueError): classlabel.int2str(len(names)) + with pytest.raises(ValueError): + classlabel.int2str(-1) + with pytest.raises(ValueError): + classlabel.int2str(None) + + +def test_classlabel_cast_storage(): + names = ["negative", "positive"] + classlabel = ClassLabel(names=names) + # from integers + arr = pa.array([0, 1, -1, -100], type=pa.int64()) + result = classlabel.cast_storage(arr) + assert result.type == pa.int64() + assert result.to_pylist() == [0, 1, -1, -100] + arr = pa.array([0, 1, -1, -100], type=pa.int32()) + result = classlabel.cast_storage(arr) + assert result.type == pa.int64() + assert result.to_pylist() == [0, 1, -1, -100] + arr = pa.array([3]) + with pytest.raises(ValueError): + classlabel.cast_storage(arr) + # from 
strings + arr = pa.array(["negative", "positive"]) + result = classlabel.cast_storage(arr) + assert result.type == pa.int64() + assert result.to_pylist() == [0, 1] + arr = pa.array(["__label_that_doesnt_exist__"]) + with pytest.raises(ValueError): + classlabel.cast_storage(arr) + # from empty + arr = pa.array([]) + result = classlabel.cast_storage(arr) + assert result.type == pa.int64() + assert result.to_pylist() == [] @pytest.mark.parametrize("class_label_arg", ["names", "names_file"]) From cbd1303890784b62459da0efe97268e83f3f64a1 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 16 Jun 2022 14:16:39 +0200 Subject: [PATCH 2/2] support None in cast_storage --- src/datasets/features/features.py | 4 +++- tests/features/test_features.py | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 6efaff65f8b..5e3df491d4b 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1024,7 +1024,9 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.IntegerArray]) -> pa.In f"Class label {min_max['max']} greater than configured num_classes {self.num_classes}" ) elif isinstance(storage, pa.StringArray): - storage = pa.array(self.str2int(storage.to_pylist())) + storage = pa.array( + [self._strval2int(label) if label is not None else None for label in storage.to_pylist()] + ) return array_cast(storage, self.pa_type) @staticmethod diff --git a/tests/features/test_features.py b/tests/features/test_features.py index cb4423ebc44..1868e067aa8 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -328,6 +328,11 @@ def test_classlabel_cast_storage(): arr = pa.array(["__label_that_doesnt_exist__"]) with pytest.raises(ValueError): classlabel.cast_storage(arr) + # from nulls + arr = pa.array([None]) + result = classlabel.cast_storage(arr) + assert result.type == pa.int64() + assert result.to_pylist() == [None] # from empty 
arr = pa.array([]) result = classlabel.cast_storage(arr)