@@ -469,15 +469,11 @@ def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
469469
470470 def test_basic (self , xp : ModuleType ):
471471 actual = one_hot (xp .asarray ([0 , 1 , 2 ]), 3 )
472- expected = xp .asarray ([[1. , 0. , 0. ],
473- [0. , 1. , 0. ],
474- [0. , 0. , 1. ]])
472+ expected = xp .asarray ([[1.0 , 0.0 , 0.0 ], [0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ]])
475473 xp_assert_equal (actual , expected )
476474
477475 actual = one_hot (xp .asarray ([1 , 2 , 0 ]), 3 )
478- expected = xp .asarray ([[0. , 1. , 0. ],
479- [0. , 0. , 1. ],
480- [1. , 0. , 0. ]])
476+ expected = xp .asarray ([[0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [1.0 , 0.0 , 0.0 ]])
481477 xp_assert_equal (actual , expected )
482478
483479 @pytest .mark .skip_xp_backend (
@@ -489,32 +485,29 @@ def test_out_of_bound(self, xp: ModuleType):
489485 actual = one_hot (xp .asarray ([- 1 , 3 ]), 3 )
490486 except IndexError :
491487 return
492- expected = xp .asarray ([[0. , 0. , 0. ],
493- [0. , 0. , 0. ]])
488+ expected = xp .asarray ([[0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 0.0 ]])
494489 xp_assert_equal (actual , expected )
495490
496- @pytest .mark .parametrize ("int_dtype" , ['int8' , 'int16' , 'int32' , 'int64' , 'uint8' ,
497- 'uint16' , 'uint32' , 'uint64' ])
491+ @pytest .mark .parametrize (
492+ "int_dtype" ,
493+ ["int8" , "int16" , "int32" , "int64" , "uint8" , "uint16" , "uint32" , "uint64" ],
494+ )
498495 def test_int_types (self , xp : ModuleType , int_dtype : str ):
499496 dtype = getattr (xp , int_dtype )
500497 x = xp .asarray ([0 , 1 , 2 ], dtype = dtype )
501498 actual = one_hot (x , 3 )
502- expected = xp .asarray ([[1. , 0. , 0. ],
503- [0. , 1. , 0. ],
504- [0. , 0. , 1. ]])
499+ expected = xp .asarray ([[1.0 , 0.0 , 0.0 ], [0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ]])
505500 xp_assert_equal (actual , expected )
506501
507502 def test_custom_dtype (self , xp : ModuleType ):
508503 actual = one_hot (xp .asarray ([0 , 1 , 2 ], dtype = xp .int32 ), 3 , dtype = xp .bool )
509- expected = xp .asarray ([[ True , False , False ],
510- [False , True , False ],
511- [ False , False , True ]] )
504+ expected = xp .asarray (
505+ [[ True , False , False ], [False , True , False ], [ False , False , True ]]
506+ )
512507 xp_assert_equal (actual , expected )
513508
514509 def test_axis (self , xp : ModuleType ):
515- expected = xp .asarray ([[0. , 1. , 0. ],
516- [0. , 0. , 1. ],
517- [1. , 0. , 0. ]]).T
510+ expected = xp .asarray ([[0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [1.0 , 0.0 , 0.0 ]]).T
518511 actual = one_hot (xp .asarray ([1 , 2 , 0 ]), 3 , axis = 0 )
519512 xp_assert_equal (actual , expected )
520513
0 commit comments