@@ -394,6 +394,24 @@ def test_none_shape_bool(self, xp: ModuleType):
394394 a = a [a ]
395395 xp_assert_equal (isclose (a , b ), xp .asarray ([True , False ]))
396396
397+ @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" )
398+ @pytest .mark .skip_xp_backend (Backend .TORCH , reason = "Array API 2024.12 support" )
399+ def test_python_scalar (self , xp : ModuleType ):
400+ a = xp .asarray ([0.0 , 0.1 ], dtype = xp .float32 )
401+ xp_assert_equal (isclose (a , 0.0 ), xp .asarray ([True , False ]))
402+ xp_assert_equal (isclose (0.0 , a ), xp .asarray ([True , False ]))
403+
404+ a = xp .asarray ([0 , 1 ], dtype = xp .int16 )
405+ xp_assert_equal (isclose (a , 0 ), xp .asarray ([True , False ]))
406+ xp_assert_equal (isclose (0 , a ), xp .asarray ([True , False ]))
407+
408+ xp_assert_equal (isclose (0 , 0 , xp = xp ), xp .asarray (True ))
409+ xp_assert_equal (isclose (0 , 1 , xp = xp ), xp .asarray (False ))
410+
411+ def test_all_python_scalars (self ):
412+ with pytest .raises (TypeError , match = "Unrecognized" ):
413+ isclose (0 , 0 )
414+
397415 def test_xp (self , xp : ModuleType ):
398416 a = xp .asarray ([0.0 , 0.0 ])
399417 b = xp .asarray ([1e-9 , 1e-4 ])
@@ -406,30 +424,22 @@ def test_basic(self, xp: ModuleType):
406424 # Using 0-dimensional array
407425 a = xp .asarray (1 )
408426 b = xp .asarray ([[1 , 2 ], [3 , 4 ]])
409- k = xp .asarray ([[1 , 2 ], [3 , 4 ]])
410- xp_assert_equal (kron (a , b ), k )
411- a = xp .asarray ([[1 , 2 ], [3 , 4 ]])
412- b = xp .asarray (1 )
413- xp_assert_equal (kron (a , b ), k )
427+ xp_assert_equal (kron (a , b ), b )
428+ xp_assert_equal (kron (b , a ), b )
414429
415430 # Using 1-dimensional array
416431 a = xp .asarray ([3 ])
417432 b = xp .asarray ([[1 , 2 ], [3 , 4 ]])
418433 k = xp .asarray ([[3 , 6 ], [9 , 12 ]])
419434 xp_assert_equal (kron (a , b ), k )
420- a = xp .asarray ([[1 , 2 ], [3 , 4 ]])
421- b = xp .asarray ([3 ])
422- xp_assert_equal (kron (a , b ), k )
435+ xp_assert_equal (kron (b , a ), k )
423436
424437 # Using 3-dimensional array
425438 a = xp .asarray ([[[1 ]], [[2 ]]])
426439 b = xp .asarray ([[1 , 2 ], [3 , 4 ]])
427440 k = xp .asarray ([[[1 , 2 ], [3 , 4 ]], [[2 , 4 ], [6 , 8 ]]])
428441 xp_assert_equal (kron (a , b ), k )
429- a = xp .asarray ([[1 , 2 ], [3 , 4 ]])
430- b = xp .asarray ([[[1 ]], [[2 ]]])
431- k = xp .asarray ([[[1 , 2 ], [3 , 4 ]], [[2 , 4 ], [6 , 8 ]]])
432- xp_assert_equal (kron (a , b ), k )
442+ xp_assert_equal (kron (b , a ), k )
433443
434444 def test_kron_smoke (self , xp : ModuleType ):
435445 a = xp .ones ((3 , 3 ))
@@ -467,6 +477,18 @@ def test_kron_shape(
467477 k = kron (a , b )
468478 assert k .shape == expected_shape
469479
480+ def test_python_scalar (self , xp : ModuleType ):
481+ a = 1
482+ # Test no dtype promotion to xp.asarray(a); use b.dtype
483+ b = xp .asarray ([[1 , 2 ], [3 , 4 ]], dtype = xp .int16 )
484+ xp_assert_equal (kron (a , b ), b )
485+ xp_assert_equal (kron (b , a ), b )
486+ xp_assert_equal (kron (1 , 1 , xp = xp ), xp .asarray (1 ))
487+
488+ def test_all_python_scalars (self ):
489+ with pytest .raises (TypeError , match = "Unrecognized" ):
490+ kron (1 , 1 )
491+
470492 def test_device (self , xp : ModuleType , device : Device ):
471493 x1 = xp .asarray ([1 , 2 , 3 ], device = device )
472494 x2 = xp .asarray ([4 , 5 ], device = device )
@@ -594,6 +616,28 @@ def test_shapes(
594616 actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
595617 xp_assert_equal (actual , xp .empty ((0 ,)))
596618
619+ @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" )
620+ @pytest .mark .parametrize ("assume_unique" , [True , False ])
621+ def test_python_scalar (self , xp : ModuleType , assume_unique : bool ):
622+ # Test no dtype promotion to xp.asarray(x2); use x1.dtype
623+ x1 = xp .asarray ([3 , 1 , 2 ], dtype = xp .int16 )
624+ x2 = 3
625+ actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
626+ xp_assert_equal (actual , xp .asarray ([1 , 2 ], dtype = xp .int16 ))
627+
628+ actual = setdiff1d (x2 , x1 , assume_unique = assume_unique )
629+ xp_assert_equal (actual , xp .asarray ([], dtype = xp .int16 ))
630+
631+ xp_assert_equal (
632+ setdiff1d (0 , 0 , assume_unique = assume_unique , xp = xp ),
633+ xp .asarray ([0 ])[:0 ], # Default int dtype for backend
634+ )
635+
636+ @pytest .mark .parametrize ("assume_unique" , [True , False ])
637+ def test_all_python_scalars (self , assume_unique : bool ):
638+ with pytest .raises (TypeError , match = "Unrecognized" ):
639+ setdiff1d (0 , 0 , assume_unique = assume_unique )
640+
597641 def test_device (self , xp : ModuleType , device : Device ):
598642 x1 = xp .asarray ([3 , 8 , 20 ], device = device )
599643 x2 = xp .asarray ([2 , 3 , 4 ], device = device )
0 commit comments