@@ -260,7 +260,7 @@ def test_array_unique_torch():
260260 result , indices = array_unique (arr , return_index = True )
261261 assert is_tensor (result )
262262 assert torch .equal (result , torch .tensor ([1 , 2 , 3 ]))
263- assert torch . equal (indices , torch . tensor ([0 , 1 , 3 ]))
263+ np . testing . assert_array_equal (indices , np . array ([0 , 1 , 3 ]))
264264
265265
266266def test_array_concatenate ():
@@ -467,18 +467,18 @@ def test_array_nonzero_torch():
467467 result = array_nonzero (array )
468468 assert isinstance (result , tuple )
469469 assert len (result ) == 1
470- assert is_tensor (result [0 ])
471- assert torch . equal (result [0 ], torch . tensor ([1 , 3 ]))
470+ assert is_numpy (result [0 ])
471+ np . testing . assert_array_equal (result [0 ], np . array ([1 , 3 ]))
472472
473473 # 2D tensor
474474 array_2d = torch .tensor ([[0 , 1 ], [2 , 0 ]])
475475 result = array_nonzero (array_2d )
476476 assert isinstance (result , tuple )
477477 assert len (result ) == 2
478- assert is_tensor (result [0 ])
479- assert is_tensor (result [1 ])
480- assert torch . equal (result [0 ], torch . tensor ([0 , 1 ]))
481- assert torch . equal (result [1 ], torch . tensor ([1 , 0 ]))
478+ assert is_numpy (result [0 ])
479+ assert is_numpy (result [1 ])
480+ np . testing . assert_array_equal (result [0 ], np . array ([0 , 1 ]))
481+ np . testing . assert_array_equal (result [1 ], np . array ([1 , 0 ]))
482482
483483
484484def test_stratified_split_indices ():
0 commit comments