@@ -480,18 +480,32 @@ def test_basic(self, xp: ModuleType):
480480 [1. , 0. , 0. ]])
481481 xp_assert_equal (actual , expected )
482482
483+ @pytest .mark .skip_xp_backend (
484+ Backend .TORCH_GPU , reason = "Puts Pytorch into a bad state."
485+ )
483486 def test_out_of_bound (self , xp : ModuleType ):
484487 # Undefined behavior. Either return zero, or raise.
485488 try :
486489 actual = one_hot (xp .asarray ([- 1 , 3 ]), 3 )
487- except ( IndexError , RuntimeError ) :
490+ except IndexError :
488491 return
489492 expected = xp .asarray ([[0. , 0. , 0. ],
490493 [0. , 0. , 0. ]])
491494 xp_assert_equal (actual , expected )
492495
496+ @pytest .mark .parametrize ("int_dtype" , ['int8' , 'int16' , 'int32' , 'int64' , 'uint8' ,
497+ 'uint16' , 'uint32' , 'uint64' ])
498+ def test_int_types (self , xp : ModuleType , int_dtype : str ):
499+ dtype = getattr (xp , int_dtype )
500+ x = xp .asarray ([0 , 1 , 2 ], dtype = dtype )
501+ actual = one_hot (x , 3 )
502+ expected = xp .asarray ([[1. , 0. , 0. ],
503+ [0. , 1. , 0. ],
504+ [0. , 0. , 1. ]])
505+ xp_assert_equal (actual , expected )
506+
493507 def test_custom_dtype (self , xp : ModuleType ):
494- actual = one_hot (xp .asarray ([0 , 1 , 2 ]), 3 , dtype = xp .bool )
508+ actual = one_hot (xp .asarray ([0 , 1 , 2 ], dtype = xp . int32 ), 3 , dtype = xp .bool )
495509 expected = xp .asarray ([[True , False , False ],
496510 [False , True , False ],
497511 [False , False , True ]])
0 commit comments