2121import dpctl .utils as du
2222
2323from ._copy_utils import _empty_like_orderK
24- from ._scalar_utils import _get_dtype , _get_queue_usm_type , _validate_dtype
24+ from ._scalar_utils import (
25+ _get_dtype ,
26+ _get_queue_usm_type ,
27+ _get_shape ,
28+ _validate_dtype ,
29+ )
2530from ._tensor_elementwise_impl import _not_equal , _subtract
2631from ._tensor_impl import (
2732 _copy_usm_ndarray_into_usm_ndarray ,
3843 _searchsorted_left ,
3944 _sort_ascending ,
4045)
41- from ._type_utils import _resolve_weak_types_all_py_ints
46+ from ._type_utils import (
47+ _resolve_weak_types_all_py_ints ,
48+ _to_device_supported_dtype ,
49+ )
4250
4351__all__ = [
4452 "isin" ,
@@ -632,21 +640,17 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
632640 )
633641
634642
635- def isin (x , test_elements , / , * , assume_unique = False , invert = False ):
643+ def isin (x , test_elements , / , * , invert = False ):
636644 """
637645 Tests `x in test_elements` for each element of `x`. Returns a boolean array
638646 with the same shape as `x` that is `True` where the element is in
639647 `test_elements`, `False` otherwise.
640648
641649 Args:
642- x (usm_ndarray):
643- input array .
650+ x (Union[ usm_ndarray, bool, int, float, complex] ):
651+ input element or elements .
644652 test_elements (Union[usm_ndarray, bool, int, float, complex]):
645653 elements against which to test each value of `x`.
646- assume_unique (Optional[bool]):
647- if `True`, the input arrays are both assumed to be unique, which
648- currently has no effect.
649- Default: `False`.
650654 invert (Optional[bool]):
651655 if `True`, the output results are inverted, i.e., are equivalent to
652656 testing `x not in test_elements` for each element of `x`.
@@ -657,11 +661,19 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
657661 an array of the inclusion test results. The returned array has a
658662 boolean data type and the same shape as `x`.
659663 """
660- if not isinstance (x , dpt .usm_ndarray ):
661- raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
662- q1 , x_usm_type = x .sycl_queue , x .usm_type
664+ q1 , x_usm_type = _get_queue_usm_type (x )
663665 q2 , test_usm_type = _get_queue_usm_type (test_elements )
664- if q2 is None :
666+ if q1 is None and q2 is None :
667+ raise du .ExecutionPlacementError (
668+ "Execution placement can not be unambiguously inferred "
669+ "from input arguments. "
670+ "One of the arguments must represent USM allocation and "
671+ "expose `__sycl_usm_array_interface__` property"
672+ )
673+ if q1 is None :
674+ exec_q = q2
675+ res_usm_type = test_usm_type
676+ elif q2 is None :
665677 exec_q = q1
666678 res_usm_type = x_usm_type
667679 else :
@@ -680,45 +692,60 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
680692 dpctl .utils .validate_usm_type (res_usm_type , allow_none = False )
681693 sycl_dev = exec_q .sycl_device
682694
695+ x_dt = _get_dtype (x , sycl_dev )
696+ test_dt = _get_dtype (test_elements , sycl_dev )
697+ if not all (_validate_dtype (dt ) for dt in (x_dt , test_dt )):
698+ raise ValueError ("Operands have unsupported data types" )
699+
700+ x_sh = _get_shape (x )
683701 if isinstance (test_elements , dpt .usm_ndarray ) and test_elements .size == 0 :
684702 if invert :
685- return dpt .ones_like (x , dtype = dpt .bool , usm_type = res_usm_type )
703+ return dpt .ones (
704+ x_sh , dtype = dpt .bool , usm_type = res_usm_type , sycl_queue = exec_q
705+ )
686706 else :
687- return dpt .zeros_like (x , dtype = dpt .bool , usm_type = res_usm_type )
707+ return dpt .zeros (
708+ x_sh , dtype = dpt .bool , usm_type = res_usm_type , sycl_queue = exec_q
709+ )
688710
689- x_dt = x .dtype
690- test_dt = _get_dtype (test_elements , sycl_dev )
691- if not _validate_dtype (test_dt ):
692- raise ValueError ("`test_elements` has unsupported dtype" )
711+ dt1 , dt2 = _resolve_weak_types_all_py_ints (x_dt , test_dt , sycl_dev )
712+ dt = _to_device_supported_dtype (dpt .result_type (dt1 , dt2 ), sycl_dev )
713+
714+ if not isinstance (x , dpt .usm_ndarray ):
715+ x_arr = dpt .asarray (
716+ x , dtype = dt1 , usm_type = res_usm_type , sycl_queue = exec_q
717+ )
718+ else :
719+ x_arr = x
720+
721+ if not isinstance (test_elements , dpt .usm_ndarray ):
722+ test_arr = dpt .asarray (
723+ test_elements , dtype = dt2 , usm_type = res_usm_type , sycl_queue = exec_q
724+ )
725+ else :
726+ test_arr = test_elements
693727
694728 _manager = du .SequentialOrderManager [exec_q ]
695729 dep_evs = _manager .submitted_events
696730
697- dt1 , dt2 = _resolve_weak_types_all_py_ints (x_dt , test_dt , sycl_dev )
698- dt = dpt .result_type (dt1 , dt2 )
699-
700731 if x_dt != dt :
701- x_buf = _empty_like_orderK (x , dt )
732+ x_buf = _empty_like_orderK (x_arr , dt , res_usm_type , sycl_dev )
702733 ht_ev , ev = _copy_usm_ndarray_into_usm_ndarray (
703- src = x , dst = x_buf , sycl_queue = exec_q , depends = dep_evs
734+ src = x_arr , dst = x_buf , sycl_queue = exec_q , depends = dep_evs
704735 )
705736 _manager .add_event_pair (ht_ev , ev )
706737 else :
707- x_buf = x
738+ x_buf = x_arr
708739
709- if not isinstance (test_elements , dpt .usm_ndarray ):
710- test_buf = dpt .asarray (
711- test_elements , dtype = dt , usm_type = res_usm_type , sycl_queue = exec_q
712- )
713- elif test_dt != dt :
740+ if test_dt != dt :
714741 # copy into C-contiguous memory, because the array will be flattened
715- test_buf = dpt .empty_like (test_elements , dtype = dt , order = "C" )
742+ test_buf = dpt .empty_like (test_arr , dtype = dt , order = "C" )
716743 ht_ev , ev = _copy_usm_ndarray_into_usm_ndarray (
717- src = test_elements , dst = test_buf , sycl_queue = exec_q , depends = dep_evs
744+ src = test_arr , dst = test_buf , sycl_queue = exec_q , depends = dep_evs
718745 )
719746 _manager .add_event_pair (ht_ev , ev )
720747 else :
721- test_buf = test_elements
748+ test_buf = test_arr
722749
723750 test_buf = dpt .reshape (test_buf , - 1 )
724751 test_buf = dpt .sort (test_buf )
0 commit comments