@@ -854,6 +854,9 @@ class ClassLabel:
854854 * `names`: List of label strings.
855855 * `names_file`: File containing the list of labels.
856856
857+ Under the hood the labels are stored as integers.
858+ You can use negative integers to represent unknown/missing labels.
859+
857860 Args:
858861 num_classes (:obj:`int`, optional): Number of classes. All labels must be < `num_classes`.
859862 names (:obj:`list` of :obj:`str`, optional): String names for the integer classes.
@@ -910,7 +913,7 @@ def __post_init__(self, names_file):
910913 def __call__ (self ):
911914 return self .pa_type
912915
913- def str2int (self , values : Union [str , Iterable ]):
916+ def str2int (self , values : Union [str , Iterable ]) -> Union [ int , Iterable ] :
914917 """Conversion class name string => integer.
915918
916919 Example:
@@ -934,7 +937,7 @@ def str2int(self, values: Union[str, Iterable]):
934937 output = [self ._strval2int (value ) for value in values ]
935938 return output if return_list else output [0 ]
936939
937- def _strval2int (self , value : str ):
940+ def _strval2int (self , value : str ) -> int :
938941 failed_parse = False
939942 value = str (value )
940943 # first attempt - raw string value
@@ -955,9 +958,11 @@ def _strval2int(self, value: str):
955958 raise ValueError (f"Invalid string class label { value } " )
956959 return int_value
957960
958- def int2str (self , values : Union [int , Iterable ]):
961+ def int2str (self , values : Union [int , Iterable ]) -> Union [ str , Iterable ] :
959962 """Conversion integer => class name string.
960963
964+ Regarding unknown/missing labels: passing negative integers raises ValueError.
965+
961966 Example:
962967
963968 ```py
@@ -1014,16 +1019,12 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.IntegerArray]) -> pa.In
10141019 """
10151020 if isinstance (storage , pa .IntegerArray ):
10161021 min_max = pc .min_max (storage ).as_py ()
1017- if min_max ["min" ] < - 1 :
1018- raise ValueError (f"Class label { min_max ['min' ]} less than -1" )
10191022 if min_max ["max" ] >= self .num_classes :
10201023 raise ValueError (
10211024 f"Class label { min_max ['max' ]} greater than configured num_classes { self .num_classes } "
10221025 )
10231026 elif isinstance (storage , pa .StringArray ):
1024- storage = pa .array (
1025- [self ._strval2int (label ) if label is not None else None for label in storage .to_pylist ()]
1026- )
1027+ storage = pa .array (self .str2int (storage .to_pylist ()))
10271028 return array_cast (storage , self .pa_type )
10281029
10291030 @staticmethod
0 commit comments