77
88
99class TestSearch :
10+
1011 @testing .for_all_dtypes (no_complex = True )
1112 @testing .numpy_cupy_allclose ()
1213 def test_argmax_all (self , xp , dtype ):
@@ -167,6 +168,13 @@ def test_argmin_int32_overflow(self):
167168 assert a .argmin ().item () == 2 ** 32
168169
169170
171+ # TODO(leofang): remove this once CUDA 9.0 is dropped
172+ def _skip_cuda90 (dtype ):
173+ ver = cupy .cuda .runtime .runtimeGetVersion ()
174+ if dtype == cupy .float16 and ver == 9000 :
175+ pytest .skip ("CUB does not support fp16 on CUDA 9.0" )
176+
177+
170178# This class compares CUB results against NumPy's
171179# TODO(leofang): test axis after support is added
172180@testing .parameterize (
@@ -180,6 +188,7 @@ def test_argmin_int32_overflow(self):
180188)
181189@pytest .mark .skip ("The CUB routine is not enabled" )
182190class TestCubReduction :
191+
183192 @pytest .fixture (autouse = True )
184193 def setUp (self ):
185194 self .order , self .axis = self .order_and_axis
@@ -200,6 +209,7 @@ def setUp(self):
200209 @testing .for_dtypes ("bhilBHILefdFD" )
201210 @testing .numpy_cupy_allclose (rtol = 1e-5 , contiguous_check = False )
202211 def test_cub_argmin (self , xp , dtype ):
212+ _skip_cuda90 (dtype )
203213 a = testing .shaped_random (self .shape , xp , dtype )
204214 if self .order == "C" :
205215 a = xp .ascontiguousarray (a )
@@ -220,7 +230,7 @@ def test_cub_argmin(self, xp, dtype):
220230 # this is the only function we can mock; the rest is cdef'd
221231 func_name = "cupy._core._cub_reduction."
222232 func_name += "_SimpleCubReductionKernel_get_cached_function"
223- # func = _cub_reduction._SimpleCubReductionKernel_get_cached_function
233+ func = _cub_reduction ._SimpleCubReductionKernel_get_cached_function
224234 if self .axis is not None and len (self .shape ) > 1 :
225235 times_called = 1 # one pass
226236 else :
@@ -235,7 +245,7 @@ def test_cub_argmin(self, xp, dtype):
235245 @testing .for_dtypes ("bhilBHILefdFD" )
236246 @testing .numpy_cupy_allclose (rtol = 1e-5 , contiguous_check = False )
237247 def test_cub_argmax (self , xp , dtype ):
238- # _skip_cuda90(dtype)
248+ _skip_cuda90 (dtype )
239249 a = testing .shaped_random (self .shape , xp , dtype )
240250 if self .order == "C" :
241251 a = xp .ascontiguousarray (a )
@@ -256,7 +266,7 @@ def test_cub_argmax(self, xp, dtype):
256266 # this is the only function we can mock; the rest is cdef'd
257267 func_name = "cupy._core._cub_reduction."
258268 func_name += "_SimpleCubReductionKernel_get_cached_function"
259- # func = _cub_reduction._SimpleCubReductionKernel_get_cached_function
269+ func = _cub_reduction ._SimpleCubReductionKernel_get_cached_function
260270 if self .axis is not None and len (self .shape ) > 1 :
261271 times_called = 1 # one pass
262272 else :
@@ -280,6 +290,7 @@ def test_cub_argmax(self, xp, dtype):
280290)
281291@pytest .mark .skip ("dtype is not supported" )
282292class TestArgMinMaxDtype :
293+
283294 @testing .for_dtypes (
284295 dtypes = [numpy .int8 , numpy .int16 , numpy .int32 , numpy .int64 ],
285296 name = "result_dtype" ,
@@ -304,6 +315,7 @@ def test_argminmax_dtype(self, in_dtype, result_dtype):
304315 {"cond_shape" : (3 , 4 ), "x_shape" : (2 , 3 , 4 ), "y_shape" : (4 ,)},
305316)
306317class TestWhereTwoArrays :
318+
307319 @testing .for_all_dtypes_combination (names = ["cond_type" , "x_type" , "y_type" ])
308320 @testing .numpy_cupy_allclose (type_check = has_support_aspect64 ())
309321 def test_where_two_arrays (self , xp , cond_type , x_type , y_type ):
@@ -323,6 +335,7 @@ def test_where_two_arrays(self, xp, cond_type, x_type, y_type):
323335 {"cond_shape" : (3 , 4 )},
324336)
325337class TestWhereCond :
338+
326339 @testing .for_all_dtypes ()
327340 @testing .numpy_cupy_array_equal ()
328341 def test_where_cond (self , xp , dtype ):
@@ -332,6 +345,7 @@ def test_where_cond(self, xp, dtype):
332345
333346
334347class TestWhereError :
348+
335349 def test_one_argument (self ):
336350 for xp in (numpy , cupy ):
337351 cond = testing .shaped_random ((3 , 4 ), xp , dtype = xp .bool_ )
@@ -349,6 +363,7 @@ def test_one_argument(self):
349363 _ids = False , # Do not generate ids from randomly generated params
350364)
351365class TestNonzero :
366+
352367 @testing .for_all_dtypes ()
353368 @testing .numpy_cupy_array_equal ()
354369 def test_nonzero (self , xp , dtype ):
@@ -360,15 +375,21 @@ def test_nonzero(self, xp, dtype):
360375 {"array" : numpy .array (0 )},
361376 {"array" : numpy .array (1 )},
362377)
363- @pytest .mark .skip ("Only positive rank is supported" )
364378@testing .with_requires ("numpy>=1.17.0" )
365379class TestNonzeroZeroDimension :
380+
381+ @testing .with_requires ("numpy>=2.1" )
382+ @testing .for_all_dtypes ()
383+ def test_nonzero (self , dtype ):
384+ array = cupy .array (self .array , dtype = dtype )
385+ with pytest .raises (ValueError ):
386+ cupy .nonzero (array )
387+
366388 @testing .for_all_dtypes ()
367389 @testing .numpy_cupy_array_equal ()
368- def test_nonzero (self , xp , dtype ):
390+ def test_nonzero_explicit (self , xp , dtype ):
369391 array = xp .array (self .array , dtype = dtype )
370- with testing .assert_warns (DeprecationWarning ):
371- return xp .nonzero (array )
392+ return xp .nonzero (xp .atleast_1d (array ))
372393
373394
374395@testing .parameterize (
@@ -382,6 +403,7 @@ def test_nonzero(self, xp, dtype):
382403 _ids = False , # Do not generate ids from randomly generated params
383404)
384405class TestFlatNonzero :
406+
385407 @testing .for_all_dtypes ()
386408 @testing .numpy_cupy_array_equal ()
387409 def test_flatnonzero (self , xp , dtype ):
@@ -398,6 +420,7 @@ def test_flatnonzero(self, xp, dtype):
398420 _ids = False , # Do not generate ids from randomly generated params
399421)
400422class TestArgwhere :
423+
401424 @testing .for_all_dtypes ()
402425 @testing .numpy_cupy_array_equal ()
403426 def test_argwhere (self , xp , dtype ):
@@ -411,6 +434,7 @@ def test_argwhere(self, xp, dtype):
411434)
412435@testing .with_requires ("numpy>=1.18" )
413436class TestArgwhereZeroDimension :
437+
414438 @testing .for_all_dtypes ()
415439 @testing .numpy_cupy_array_equal ()
416440 def test_argwhere (self , xp , dtype ):
@@ -419,6 +443,7 @@ def test_argwhere(self, xp, dtype):
419443
420444
421445class TestNanArgMin :
446+
422447 @testing .for_all_dtypes (no_complex = True )
423448 @testing .numpy_cupy_allclose ()
424449 def test_nanargmin_all (self , xp , dtype ):
@@ -509,6 +534,7 @@ def test_nanargmin_zero_size_axis1(self, xp, dtype):
509534
510535
511536class TestNanArgMax :
537+
512538 @testing .for_all_dtypes (no_complex = True )
513539 @testing .numpy_cupy_allclose ()
514540 def test_nanargmax_all (self , xp , dtype ):
@@ -620,6 +646,7 @@ def test_nanargmax_zero_size_axis1(self, xp, dtype):
620646 )
621647)
622648class TestSearchSorted :
649+
623650 @testing .for_all_dtypes (no_bool = True )
624651 @testing .numpy_cupy_array_equal ()
625652 def test_searchsorted (self , xp , dtype ):
@@ -639,6 +666,7 @@ def test_ndarray_searchsorted(self, xp, dtype):
639666
640667@testing .parameterize ({"side" : "left" }, {"side" : "right" })
641668class TestSearchSortedNanInf :
669+
642670 @testing .numpy_cupy_array_equal ()
643671 def test_searchsorted_nanbins (self , xp ):
644672 x = testing .shaped_arange ((10 ,), xp , xp .float64 )
@@ -704,6 +732,7 @@ def test_searchsorted_minf(self, xp):
704732
705733
706734class TestSearchSortedInvalid :
735+
707736 # Can't test unordered bins due to numpy undefined
708737 # behavior for searchsorted
709738
@@ -723,6 +752,7 @@ def test_ndarray_searchsorted_ndbins(self):
723752
724753
725754class TestSearchSortedWithSorter :
755+
726756 @testing .numpy_cupy_array_equal ()
727757 def test_sorter (self , xp ):
728758 x = testing .shaped_arange ((12 ,), xp , xp .float64 )
@@ -741,16 +771,16 @@ def test_invalid_sorter(self):
741771
742772 def test_nonint_sorter (self ):
743773 for xp in (numpy , cupy ):
744- dt = cupy .default_float_type ()
745- x = testing .shaped_arange ((12 ,), xp , dt )
774+ x = testing .shaped_arange ((12 ,), xp , xp .float32 )
746775 bins = xp .array ([10 , 4 , 2 , 1 , 8 ])
747- sorter = xp .array ([], dtype = dt )
776+ sorter = xp .array ([], dtype = xp . float32 )
748777 with pytest .raises ((TypeError , ValueError )):
749778 xp .searchsorted (bins , x , sorter = sorter )
750779
751780
752781@testing .parameterize ({"side" : "left" }, {"side" : "right" })
753782class TestNdarraySearchSortedNanInf :
783+
754784 @testing .numpy_cupy_array_equal ()
755785 def test_searchsorted_nanbins (self , xp ):
756786 x = testing .shaped_arange ((10 ,), xp , xp .float64 )
@@ -816,6 +846,7 @@ def test_searchsorted_minf(self, xp):
816846
817847
818848class TestNdarraySearchSortedWithSorter :
849+
819850 @testing .numpy_cupy_array_equal ()
820851 def test_sorter (self , xp ):
821852 x = testing .shaped_arange ((12 ,), xp , xp .float64 )
@@ -834,9 +865,8 @@ def test_invalid_sorter(self):
834865
835866 def test_nonint_sorter (self ):
836867 for xp in (numpy , cupy ):
837- dt = cupy .default_float_type ()
838- x = testing .shaped_arange ((12 ,), xp , dt )
868+ x = testing .shaped_arange ((12 ,), xp , xp .float32 )
839869 bins = xp .array ([10 , 4 , 2 , 1 , 8 ])
840- sorter = xp .array ([], dtype = dt )
870+ sorter = xp .array ([], dtype = xp . float32 )
841871 with pytest .raises ((TypeError , ValueError )):
842872 bins .searchsorted (x , sorter = sorter )
0 commit comments