@@ -643,7 +643,6 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
643643 input array.
644644 test_elements (Union[usm_ndarray, bool, int, float, complex]):
645645 elements against which to test each value of `x`.
646- Default: `None`.
647646 assume_unique (Optional[bool]):
648647 if `True`, the input arrays are both assumed to be unique, which
649648 currently has no effect.
@@ -681,20 +680,25 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
681680 dpctl .utils .validate_usm_type (res_usm_type , allow_none = False )
682681 sycl_dev = exec_q .sycl_device
683682
683+ if isinstance (test_elements , dpt .usm_ndarray ) and test_elements .size == 0 :
684+ if invert :
685+ return dpt .ones_like (x , dtype = dpt .bool , usm_type = res_usm_type )
686+ else :
687+ return dpt .zeros_like (x , dtype = dpt .bool , usm_type = res_usm_type )
688+
684689 x_dt = x .dtype
685690 test_dt = _get_dtype (test_elements , sycl_dev )
686691 if not _validate_dtype (test_dt ):
687692 raise ValueError ("`test_elements` has unsupported dtype" )
688693
689- dt = dpt .result_type (
690- * _resolve_weak_types_all_py_ints (x_dt , test_dt , sycl_dev )
691- )
692-
693694 _manager = du .SequentialOrderManager [exec_q ]
695+ dep_evs = _manager .submitted_events
696+
697+ dt1 , dt2 = _resolve_weak_types_all_py_ints (x_dt , test_dt , sycl_dev )
698+ dt = dpt .result_type (dt1 , dt2 )
694699
695700 if x_dt != dt :
696701 x_buf = _empty_like_orderK (x , dt )
697- dep_evs = _manager .submitted_events
698702 ht_ev , ev = _copy_usm_ndarray_into_usm_ndarray (
699703 src = x , dst = x_buf , sycl_queue = exec_q , depends = dep_evs
700704 )
@@ -703,11 +707,12 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
703707 x_buf = x
704708
705709 if not isinstance (test_elements , dpt .usm_ndarray ):
706- test_buf = dpt .asarray (test_elements , dtype = dt , sycl_queue = exec_q )
710+ test_buf = dpt .asarray (
711+ test_elements , dtype = dt , usm_type = res_usm_type , sycl_queue = exec_q
712+ )
707713 elif test_dt != dt :
708714 # copy into C-contiguous memory, because the array will be flattened
709- test_buf = dpt .empty_like (test_elements , dt , order = "C" )
710- dep_evs = _manager .submitted_events
715+ test_buf = dpt .empty_like (test_elements , dtype = dt , order = "C" )
711716 ht_ev , ev = _copy_usm_ndarray_into_usm_ndarray (
712717 src = test_elements , dst = test_buf , sycl_queue = exec_q , depends = dep_evs
713718 )
@@ -718,7 +723,9 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
718723 test_buf = dpt .reshape (test_buf , - 1 )
719724 test_buf = dpt .sort (test_buf )
720725
721- dst = _empty_like_orderK (x_buf , dpt .bool , usm_type = res_usm_type )
726+ dst = dpt .empty_like (
727+ x_buf , dtype = dpt .bool , usm_type = res_usm_type , order = "C"
728+ )
722729
723730 dep_evs = _manager .submitted_events
724731 ht_ev , s_ev = _isin (
0 commit comments