@@ -1436,7 +1436,6 @@ <h1>Source code for dpctl.tensor._set_functions</h1><div class="highlight"><pre>
14361436< span class ="sd "> input array.</ span >
14371437< span class ="sd "> test_elements (Union[usm_ndarray, bool, int, float, complex]):</ span >
14381438< span class ="sd "> elements against which to test each value of `x`.</ span >
1439- < span class ="sd "> Default: `None`.</ span >
14401439< span class ="sd "> assume_unique (Optional[bool]):</ span >
14411440< span class ="sd "> if `True`, the input arrays are both assumed to be unique, which</ span >
14421441< span class ="sd "> currently has no effect.</ span >
@@ -1474,20 +1473,25 @@ <h1>Source code for dpctl.tensor._set_functions</h1><div class="highlight"><pre>
14741473 < span class ="n "> dpctl</ span > < span class ="o "> .</ span > < span class ="n "> utils</ span > < span class ="o "> .</ span > < span class ="n "> validate_usm_type</ span > < span class ="p "> (</ span > < span class ="n "> res_usm_type</ span > < span class ="p "> ,</ span > < span class ="n "> allow_none</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> )</ span >
14751474 < span class ="n "> sycl_dev</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span > < span class ="o "> .</ span > < span class ="n "> sycl_device</ span >
14761475
1476+ < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> test_elements</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> )</ span > < span class ="ow "> and</ span > < span class ="n "> test_elements</ span > < span class ="o "> .</ span > < span class ="n "> size</ span > < span class ="o "> ==</ span > < span class ="mi "> 0</ span > < span class ="p "> :</ span >
1477+ < span class ="k "> if</ span > < span class ="n "> invert</ span > < span class ="p "> :</ span >
1478+ < span class ="k "> return</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> ones_like</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> bool</ span > < span class ="p "> ,</ span > < span class ="n "> usm_type</ span > < span class ="o "> =</ span > < span class ="n "> res_usm_type</ span > < span class ="p "> )</ span >
1479+ < span class ="k "> else</ span > < span class ="p "> :</ span >
1480+ < span class ="k "> return</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> zeros_like</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> bool</ span > < span class ="p "> ,</ span > < span class ="n "> usm_type</ span > < span class ="o "> =</ span > < span class ="n "> res_usm_type</ span > < span class ="p "> )</ span >
1481+
14771482 < span class ="n "> x_dt</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span >
14781483 < span class ="n "> test_dt</ span > < span class ="o "> =</ span > < span class ="n "> _get_dtype</ span > < span class ="p "> (</ span > < span class ="n "> test_elements</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_dev</ span > < span class ="p "> )</ span >
14791484 < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> _validate_dtype</ span > < span class ="p "> (</ span > < span class ="n "> test_dt</ span > < span class ="p "> ):</ span >
14801485 < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="s2 "> "`test_elements` has unsupported dtype"</ span > < span class ="p "> )</ span >
14811486
1482- < span class ="n "> dt</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> result_type</ span > < span class ="p "> (</ span >
1483- < span class ="o "> *</ span > < span class ="n "> _resolve_weak_types_all_py_ints</ span > < span class ="p "> (</ span > < span class ="n "> x_dt</ span > < span class ="p "> ,</ span > < span class ="n "> test_dt</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_dev</ span > < span class ="p "> )</ span >
1484- < span class ="p "> )</ span >
1485-
14861487 < span class ="n "> _manager</ span > < span class ="o "> =</ span > < span class ="n "> du</ span > < span class ="o "> .</ span > < span class ="n "> SequentialOrderManager</ span > < span class ="p "> [</ span > < span class ="n "> exec_q</ span > < span class ="p "> ]</ span >
1488+ < span class ="n "> dep_evs</ span > < span class ="o "> =</ span > < span class ="n "> _manager</ span > < span class ="o "> .</ span > < span class ="n "> submitted_events</ span >
1489+
1490+ < span class ="n "> dt1</ span > < span class ="p "> ,</ span > < span class ="n "> dt2</ span > < span class ="o "> =</ span > < span class ="n "> _resolve_weak_types_all_py_ints</ span > < span class ="p "> (</ span > < span class ="n "> x_dt</ span > < span class ="p "> ,</ span > < span class ="n "> test_dt</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_dev</ span > < span class ="p "> )</ span >
1491+ < span class ="n "> dt</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> result_type</ span > < span class ="p "> (</ span > < span class ="n "> dt1</ span > < span class ="p "> ,</ span > < span class ="n "> dt2</ span > < span class ="p "> )</ span >
14871492
14881493 < span class ="k "> if</ span > < span class ="n "> x_dt</ span > < span class ="o "> !=</ span > < span class ="n "> dt</ span > < span class ="p "> :</ span >
14891494 < span class ="n "> x_buf</ span > < span class ="o "> =</ span > < span class ="n "> _empty_like_orderK</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> dt</ span > < span class ="p "> )</ span >
1490- < span class ="n "> dep_evs</ span > < span class ="o "> =</ span > < span class ="n "> _manager</ span > < span class ="o "> .</ span > < span class ="n "> submitted_events</ span >
14911495 < span class ="n "> ht_ev</ span > < span class ="p "> ,</ span > < span class ="n "> ev</ span > < span class ="o "> =</ span > < span class ="n "> _copy_usm_ndarray_into_usm_ndarray</ span > < span class ="p "> (</ span >
14921496 < span class ="n "> src</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> dst</ span > < span class ="o "> =</ span > < span class ="n "> x_buf</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span > < span class ="p "> ,</ span > < span class ="n "> depends</ span > < span class ="o "> =</ span > < span class ="n "> dep_evs</ span >
14931497 < span class ="p "> )</ span >
@@ -1496,11 +1500,12 @@ <h1>Source code for dpctl.tensor._set_functions</h1><div class="highlight"><pre>
14961500 < span class ="n "> x_buf</ span > < span class ="o "> =</ span > < span class ="n "> x</ span >
14971501
14981502 < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> test_elements</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> ):</ span >
1499- < span class ="n "> test_buf</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> asarray</ span > < span class ="p "> (</ span > < span class ="n "> test_elements</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> dt</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span > < span class ="p "> )</ span >
1503+ < span class ="n "> test_buf</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> asarray</ span > < span class ="p "> (</ span >
1504+ < span class ="n "> test_elements</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> dt</ span > < span class ="p "> ,</ span > < span class ="n "> usm_type</ span > < span class ="o "> =</ span > < span class ="n "> res_usm_type</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span >
1505+ < span class ="p "> )</ span >
15001506 < span class ="k "> elif</ span > < span class ="n "> test_dt</ span > < span class ="o "> !=</ span > < span class ="n "> dt</ span > < span class ="p "> :</ span >
15011507 < span class ="c1 "> # copy into C-contiguous memory, because the array will be flattened</ span >
1502- < span class ="n "> test_buf</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> empty_like</ span > < span class ="p "> (</ span > < span class ="n "> test_elements</ span > < span class ="p "> ,</ span > < span class ="n "> dt</ span > < span class ="p "> ,</ span > < span class ="n "> order</ span > < span class ="o "> =</ span > < span class ="s2 "> "C"</ span > < span class ="p "> )</ span >
1503- < span class ="n "> dep_evs</ span > < span class ="o "> =</ span > < span class ="n "> _manager</ span > < span class ="o "> .</ span > < span class ="n "> submitted_events</ span >
1508+ < span class ="n "> test_buf</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> empty_like</ span > < span class ="p "> (</ span > < span class ="n "> test_elements</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> dt</ span > < span class ="p "> ,</ span > < span class ="n "> order</ span > < span class ="o "> =</ span > < span class ="s2 "> "C"</ span > < span class ="p "> )</ span >
15041509 < span class ="n "> ht_ev</ span > < span class ="p "> ,</ span > < span class ="n "> ev</ span > < span class ="o "> =</ span > < span class ="n "> _copy_usm_ndarray_into_usm_ndarray</ span > < span class ="p "> (</ span >
15051510 < span class ="n "> src</ span > < span class ="o "> =</ span > < span class ="n "> test_elements</ span > < span class ="p "> ,</ span > < span class ="n "> dst</ span > < span class ="o "> =</ span > < span class ="n "> test_buf</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span > < span class ="p "> ,</ span > < span class ="n "> depends</ span > < span class ="o "> =</ span > < span class ="n "> dep_evs</ span >
15061511 < span class ="p "> )</ span >
@@ -1511,7 +1516,9 @@ <h1>Source code for dpctl.tensor._set_functions</h1><div class="highlight"><pre>
15111516 < span class ="n "> test_buf</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> reshape</ span > < span class ="p "> (</ span > < span class ="n "> test_buf</ span > < span class ="p "> ,</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span >
15121517 < span class ="n "> test_buf</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> sort</ span > < span class ="p "> (</ span > < span class ="n "> test_buf</ span > < span class ="p "> )</ span >
15131518
1514- < span class ="n "> dst</ span > < span class ="o "> =</ span > < span class ="n "> _empty_like_orderK</ span > < span class ="p "> (</ span > < span class ="n "> x_buf</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> bool</ span > < span class ="p "> ,</ span > < span class ="n "> usm_type</ span > < span class ="o "> =</ span > < span class ="n "> res_usm_type</ span > < span class ="p "> )</ span >
1519+ < span class ="n "> dst</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> empty_like</ span > < span class ="p "> (</ span >
1520+ < span class ="n "> x_buf</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> bool</ span > < span class ="p "> ,</ span > < span class ="n "> usm_type</ span > < span class ="o "> =</ span > < span class ="n "> res_usm_type</ span > < span class ="p "> ,</ span > < span class ="n "> order</ span > < span class ="o "> =</ span > < span class ="s2 "> "C"</ span >
1521+ < span class ="p "> )</ span >
15151522
15161523 < span class ="n "> dep_evs</ span > < span class ="o "> =</ span > < span class ="n "> _manager</ span > < span class ="o "> .</ span > < span class ="n "> submitted_events</ span >
15171524 < span class ="n "> ht_ev</ span > < span class ="p "> ,</ span > < span class ="n "> s_ev</ span > < span class ="o "> =</ span > < span class ="n "> _isin</ span > < span class ="p "> (</ span >
0 commit comments