@@ -807,9 +807,8 @@ <h1>Source code for dpctl.tensor._set_functions</h1><div class="highlight"><pre>
807807< span class ="c1 "> # See the License for the specific language governing permissions and</ span >
808808< span class ="c1 "> # limitations under the License.</ span >
809809
810- < span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> typing</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> NamedTuple</ span >
810+ < span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> typing</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> NamedTuple</ span > < span class =" p " > , </ span > < span class =" n " > Optional </ span > < span class =" p " > , </ span > < span class =" n " > Union </ span >
811811
812- < span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> dpctl</ span >
813812< span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> dpctl.tensor</ span > < span class ="w "> </ span > < span class ="k "> as</ span > < span class ="w "> </ span > < span class ="nn "> dpt</ span >
814813< span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> dpctl.utils</ span > < span class ="w "> </ span > < span class ="k "> as</ span > < span class ="w "> </ span > < span class ="nn "> du</ span >
815814
@@ -1433,7 +1432,13 @@ <h1>Source code for dpctl.tensor._set_functions</h1><div class="highlight"><pre>
14331432 < span class ="p "> )</ span > </ div >
14341433
14351434
1436- < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="nf "> isin</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> test_elements</ span > < span class ="p "> ,</ span > < span class ="o "> /</ span > < span class ="p "> ,</ span > < span class ="o "> *</ span > < span class ="p "> ,</ span > < span class ="n "> invert</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ):</ span >
1435+ < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="nf "> isin</ span > < span class ="p "> (</ span >
1436+ < span class ="n "> x</ span > < span class ="p "> :</ span > < span class ="n "> Union</ span > < span class ="p "> [</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> ,</ span > < span class ="nb "> int</ span > < span class ="p "> ,</ span > < span class ="nb "> float</ span > < span class ="p "> ,</ span > < span class ="nb "> complex</ span > < span class ="p "> ,</ span > < span class ="nb "> bool</ span > < span class ="p "> ],</ span >
1437+ < span class ="n "> test_elements</ span > < span class ="p "> :</ span > < span class ="n "> Union</ span > < span class ="p "> [</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> ,</ span > < span class ="nb "> int</ span > < span class ="p "> ,</ span > < span class ="nb "> float</ span > < span class ="p "> ,</ span > < span class ="nb "> complex</ span > < span class ="p "> ,</ span > < span class ="nb "> bool</ span > < span class ="p "> ],</ span >
1438+ < span class ="o "> /</ span > < span class ="p "> ,</ span >
1439+ < span class ="o "> *</ span > < span class ="p "> ,</ span >
1440+ < span class ="n "> invert</ span > < span class ="p "> :</ span > < span class ="n "> Optional</ span > < span class ="p "> [</ span > < span class ="nb "> bool</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span >
1441+ < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> :</ span >
14371442< span class ="w "> </ span > < span class ="sd "> """</ span >
14381443< span class ="sd "> Tests `x in test_elements` for each element of `x`. Returns a boolean array</ span >
14391444< span class ="sd "> with the same shape as `x` that is `True` where the element is in</ span >
@@ -1470,19 +1475,19 @@ <h1>Source code for dpctl.tensor._set_functions</h1><div class="highlight"><pre>
14701475 < span class ="n "> exec_q</ span > < span class ="o "> =</ span > < span class ="n "> q1</ span >
14711476 < span class ="n "> res_usm_type</ span > < span class ="o "> =</ span > < span class ="n "> x_usm_type</ span >
14721477 < span class ="k "> else</ span > < span class ="p "> :</ span >
1473- < span class ="n "> exec_q</ span > < span class ="o "> =</ span > < span class ="n "> dpctl </ span > < span class =" o " > . </ span > < span class =" n " > utils </ span > < span class ="o "> .</ span > < span class ="n "> get_execution_queue</ span > < span class ="p "> ((</ span > < span class ="n "> q1</ span > < span class ="p "> ,</ span > < span class ="n "> q2</ span > < span class ="p "> ))</ span >
1478+ < span class ="n "> exec_q</ span > < span class ="o "> =</ span > < span class ="n "> du </ span > < span class ="o "> .</ span > < span class ="n "> get_execution_queue</ span > < span class ="p "> ((</ span > < span class ="n "> q1</ span > < span class ="p "> ,</ span > < span class ="n "> q2</ span > < span class ="p "> ))</ span >
14741479 < span class ="k "> if</ span > < span class ="n "> exec_q</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
14751480 < span class ="k "> raise</ span > < span class ="n "> du</ span > < span class ="o "> .</ span > < span class ="n "> ExecutionPlacementError</ span > < span class ="p "> (</ span >
14761481 < span class ="s2 "> "Execution placement can not be unambiguously inferred "</ span >
14771482 < span class ="s2 "> "from input arguments."</ span >
14781483 < span class ="p "> )</ span >
1479- < span class ="n "> res_usm_type</ span > < span class ="o "> =</ span > < span class ="n "> dpctl </ span > < span class =" o " > . </ span > < span class =" n " > utils </ span > < span class ="o "> .</ span > < span class ="n "> get_coerced_usm_type</ span > < span class ="p "> (</ span >
1484+ < span class ="n "> res_usm_type</ span > < span class ="o "> =</ span > < span class ="n "> du </ span > < span class ="o "> .</ span > < span class ="n "> get_coerced_usm_type</ span > < span class ="p "> (</ span >
14801485 < span class ="p "> (</ span >
14811486 < span class ="n "> x_usm_type</ span > < span class ="p "> ,</ span >
14821487 < span class ="n "> test_usm_type</ span > < span class ="p "> ,</ span >
14831488 < span class ="p "> )</ span >
14841489 < span class ="p "> )</ span >
1485- < span class ="n "> dpctl </ span > < span class =" o " > . </ span > < span class =" n " > utils </ span > < span class ="o "> .</ span > < span class ="n "> validate_usm_type</ span > < span class ="p "> (</ span > < span class ="n "> res_usm_type</ span > < span class ="p "> ,</ span > < span class ="n "> allow_none</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> )</ span >
1490+ < span class ="n "> du </ span > < span class ="o "> .</ span > < span class ="n "> validate_usm_type</ span > < span class ="p "> (</ span > < span class ="n "> res_usm_type</ span > < span class ="p "> ,</ span > < span class ="n "> allow_none</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> )</ span >
14861491 < span class ="n "> sycl_dev</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span > < span class ="o "> .</ span > < span class ="n "> sycl_device</ span >
14871492
14881493 < span class ="n "> x_dt</ span > < span class ="o "> =</ span > < span class ="n "> _get_dtype</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_dev</ span > < span class ="p "> )</ span >
0 commit comments