@@ -288,6 +288,8 @@ def test_classlabel_str2int():
288288 classlabel .str2int ("__bad_label_name__" )
289289 with pytest .raises (ValueError ):
290290 classlabel .str2int (1 )
291+ with pytest .raises (ValueError ):
292+ classlabel .str2int (None )
291293
292294
293295def test_classlabel_int2str ():
@@ -297,6 +299,45 @@ def test_classlabel_int2str():
297299 assert classlabel .int2str (i ) == names [i ]
298300 with pytest .raises (ValueError ):
299301 classlabel .int2str (len (names ))
302+ with pytest .raises (ValueError ):
303+ classlabel .int2str (- 1 )
304+ with pytest .raises (ValueError ):
305+ classlabel .int2str (None )
306+
307+
308+ def test_classlabel_cast_storage ():
309+ names = ["negative" , "positive" ]
310+ classlabel = ClassLabel (names = names )
311+ # from integers
312+ arr = pa .array ([0 , 1 , - 1 , - 100 ], type = pa .int64 ())
313+ result = classlabel .cast_storage (arr )
314+ assert result .type == pa .int64 ()
315+ assert result .to_pylist () == [0 , 1 , - 1 , - 100 ]
316+ arr = pa .array ([0 , 1 , - 1 , - 100 ], type = pa .int32 ())
317+ result = classlabel .cast_storage (arr )
318+ assert result .type == pa .int64 ()
319+ assert result .to_pylist () == [0 , 1 , - 1 , - 100 ]
320+ arr = pa .array ([3 ])
321+ with pytest .raises (ValueError ):
322+ classlabel .cast_storage (arr )
323+ # from strings
324+ arr = pa .array (["negative" , "positive" ])
325+ result = classlabel .cast_storage (arr )
326+ assert result .type == pa .int64 ()
327+ assert result .to_pylist () == [0 , 1 ]
328+ arr = pa .array (["__label_that_doesnt_exist__" ])
329+ with pytest .raises (ValueError ):
330+ classlabel .cast_storage (arr )
331+ # from nulls
332+ arr = pa .array ([None ])
333+ result = classlabel .cast_storage (arr )
334+ assert result .type == pa .int64 ()
335+ assert result .to_pylist () == [None ]
336+ # from empty
337+ arr = pa .array ([])
338+ result = classlabel .cast_storage (arr )
339+ assert result .type == pa .int64 ()
340+ assert result .to_pylist () == []
300341
301342
302343@pytest .mark .parametrize ("class_label_arg" , ["names" , "names_file" ])
0 commit comments