diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 3e8301df08b..5e3df491d4b 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -854,6 +854,9 @@ class ClassLabel: * `names`: List of label strings. * `names_file`: File containing the list of labels. + Under the hood the labels are stored as integers. + You can use negative integers to represent unknown/missing labels. + Args: num_classes (:obj:`int`, optional): Number of classes. All labels must be < `num_classes`. names (:obj:`list` of :obj:`str`, optional): String names for the integer classes. @@ -910,7 +913,7 @@ def __post_init__(self, names_file): def __call__(self): return self.pa_type - def str2int(self, values: Union[str, Iterable]): + def str2int(self, values: Union[str, Iterable]) -> Union[int, Iterable]: """Conversion class name string => integer. Example: @@ -934,7 +937,7 @@ def str2int(self, values: Union[str, Iterable]): output = [self._strval2int(value) for value in values] return output if return_list else output[0] - def _strval2int(self, value: str): + def _strval2int(self, value: str) -> int: failed_parse = False value = str(value) # first attempt - raw string value @@ -955,9 +958,11 @@ def _strval2int(self, value: str): raise ValueError(f"Invalid string class label {value}") return int_value - def int2str(self, values: Union[int, Iterable]): + def int2str(self, values: Union[int, Iterable]) -> Union[str, Iterable]: """Conversion integer => class name string. + Regarding unknown/missing labels: passing negative integers raises ValueError. + Example: ```py @@ -1014,8 +1019,6 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.IntegerArray]) -> pa.In """ if isinstance(storage, pa.IntegerArray): min_max = pc.min_max(storage).as_py() - if min_max["min"] < -1: - raise ValueError(f"Class label {min_max['min']} less than -1") if min_max["max"] >= self.num_classes: raise ValueError( f"Class label {min_max['max']} greater than configured num_classes {self.num_classes}" diff --git a/tests/features/test_features.py b/tests/features/test_features.py index a800ba3419f..1868e067aa8 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -288,6 +288,8 @@ def test_classlabel_str2int(): classlabel.str2int("__bad_label_name__") with pytest.raises(ValueError): classlabel.str2int(1) + with pytest.raises(ValueError): + classlabel.str2int(None) def test_classlabel_int2str(): @@ -297,6 +299,45 @@ def test_classlabel_int2str(): assert classlabel.int2str(i) == names[i] with pytest.raises(ValueError): classlabel.int2str(len(names)) + with pytest.raises(ValueError): + classlabel.int2str(-1) + with pytest.raises(ValueError): + classlabel.int2str(None) + + +def test_classlabel_cast_storage(): + names = ["negative", "positive"] + classlabel = ClassLabel(names=names) + # from integers + arr = pa.array([0, 1, -1, -100], type=pa.int64()) + result = classlabel.cast_storage(arr) + assert result.type == pa.int64() + assert result.to_pylist() == [0, 1, -1, -100] + arr = pa.array([0, 1, -1, -100], type=pa.int32()) + result = classlabel.cast_storage(arr) + assert result.type == pa.int64() + assert result.to_pylist() == [0, 1, -1, -100] + arr = pa.array([3]) + with pytest.raises(ValueError): + classlabel.cast_storage(arr) + # from strings + arr = pa.array(["negative", "positive"]) + result = classlabel.cast_storage(arr) + assert result.type == pa.int64() + assert result.to_pylist() == [0, 1] + arr = pa.array(["__label_that_doesnt_exist__"]) + with pytest.raises(ValueError): + classlabel.cast_storage(arr) + # from nulls + arr = pa.array([None]) + result = classlabel.cast_storage(arr) + assert result.type == pa.int64() + assert result.to_pylist() == [None] + # from empty + arr = pa.array([]) + result = classlabel.cast_storage(arr) + assert result.type == pa.int64() + assert result.to_pylist() == [] @pytest.mark.parametrize("class_label_arg", ["names", "names_file"])