@@ -798,6 +798,7 @@ <h1>Source code for dpctl.tensor._copy_utils</h1><div class="highlight"><pre>
798798< span class ="c1 "> # limitations under the License.</ span >
799799< span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> builtins</ span >
800800< span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> operator</ span >
801+ < span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> numbers</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> Integral</ span >
801802
802803< span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> numpy</ span > < span class ="w "> </ span > < span class ="k "> as</ span > < span class ="w "> </ span > < span class ="nn "> np</ span >
803804
@@ -1602,15 +1603,26 @@ <h1>Source code for dpctl.tensor._copy_utils</h1><div class="highlight"><pre>
16021603 < span class ="p "> ]</ span >
16031604 < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> inds</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="nb "> list</ span > < span class ="p "> ,</ span > < span class ="nb "> tuple</ span > < span class ="p "> )):</ span >
16041605 < span class ="n "> inds</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> inds</ span > < span class ="p "> ,)</ span >
1606+ < span class ="n "> any_usmarray</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span >
16051607 < span class ="k "> for</ span > < span class ="n "> ind</ span > < span class ="ow "> in</ span > < span class ="n "> inds</ span > < span class ="p "> :</ span >
1606- < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> ):</ span >
1607- < span class ="k "> raise</ span > < span class ="ne "> TypeError</ span > < span class ="p "> (</ span > < span class ="s2 "> "all elements of `ind` expected to be usm_ndarrays"</ span > < span class ="p "> )</ span >
1608- < span class ="n "> queues_</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> sycl_queue</ span > < span class ="p "> )</ span >
1609- < span class ="n "> usm_types_</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> usm_type</ span > < span class ="p "> )</ span >
1610- < span class ="k "> if</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="o "> .</ span > < span class ="n "> kind</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="s2 "> "ui"</ span > < span class ="p "> :</ span >
1611- < span class ="k "> raise</ span > < span class ="ne "> IndexError</ span > < span class ="p "> (</ span >
1612- < span class ="s2 "> "arrays used as indices must be of integer (or boolean) type"</ span >
1608+ < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> ):</ span >
1609+ < span class ="n "> any_usmarray</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span >
1610+ < span class ="k "> if</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="o "> .</ span > < span class ="n "> kind</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="s2 "> "ui"</ span > < span class ="p "> :</ span >
1611+ < span class ="k "> raise</ span > < span class ="ne "> IndexError</ span > < span class ="p "> (</ span >
1612+ < span class ="s2 "> "arrays used as indices must be of integer (or boolean) "</ span >
1613+ < span class ="s2 "> "type"</ span >
1614+ < span class ="p "> )</ span >
1615+ < span class ="n "> queues_</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> sycl_queue</ span > < span class ="p "> )</ span >
1616+ < span class ="n "> usm_types_</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> usm_type</ span > < span class ="p "> )</ span >
1617+ < span class ="k "> elif</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="p "> ,</ span > < span class ="n "> Integral</ span > < span class ="p "> ):</ span >
1618+ < span class ="k "> raise</ span > < span class ="ne "> TypeError</ span > < span class ="p "> (</ span >
1619+ < span class ="s2 "> "all elements of `ind` expected to be usm_ndarrays "</ span >
1620+ < span class ="s2 "> "or integers"</ span >
16131621 < span class ="p "> )</ span >
1622+ < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> any_usmarray</ span > < span class ="p "> :</ span >
1623+ < span class ="k "> raise</ span > < span class ="ne "> TypeError</ span > < span class ="p "> (</ span >
1624+ < span class ="s2 "> "at least one element of `ind` expected to be a usm_ndarray"</ span >
1625+ < span class ="p "> )</ span >
16141626 < span class ="n "> res_usm_type</ span > < span class ="o "> =</ span > < span class ="n "> dpctl</ span > < span class ="o "> .</ span > < span class ="n "> utils</ span > < span class ="o "> .</ span > < span class ="n "> get_coerced_usm_type</ span > < span class ="p "> (</ span > < span class ="n "> usm_types_</ span > < span class ="p "> )</ span >
16151627 < span class ="n "> exec_q</ span > < span class ="o "> =</ span > < span class ="n "> dpctl</ span > < span class ="o "> .</ span > < span class ="n "> utils</ span > < span class ="o "> .</ span > < span class ="n "> get_execution_queue</ span > < span class ="p "> (</ span > < span class ="n "> queues_</ span > < span class ="p "> )</ span >
16161628 < span class ="k "> if</ span > < span class ="n "> exec_q</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
@@ -1621,6 +1633,18 @@ <h1>Source code for dpctl.tensor._copy_utils</h1><div class="highlight"><pre>
16211633 < span class ="s2 "> "be associated with the same queue."</ span >
16221634 < span class ="p "> )</ span >
16231635 < span class ="k "> if</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> inds</ span > < span class ="p "> )</ span > < span class ="o "> ></ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
1636+ < span class ="n "> inds</ span > < span class ="o "> =</ span > < span class ="nb "> tuple</ span > < span class ="p "> (</ span >
1637+ < span class ="nb "> map</ span > < span class ="p "> (</ span >
1638+ < span class ="k "> lambda</ span > < span class ="n "> ind</ span > < span class ="p "> :</ span > < span class ="p "> (</ span >
1639+ < span class ="n "> ind</ span >
1640+ < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> )</ span >
1641+ < span class ="k "> else</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> asarray</ span > < span class ="p "> (</ span >
1642+ < span class ="n "> ind</ span > < span class ="p "> ,</ span > < span class ="n "> usm_type</ span > < span class ="o "> =</ span > < span class ="n "> res_usm_type</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span >
1643+ < span class ="p "> )</ span >
1644+ < span class ="p "> ),</ span >
1645+ < span class ="n "> inds</ span > < span class ="p "> ,</ span >
1646+ < span class ="p "> )</ span >
1647+ < span class ="p "> )</ span >
16241648 < span class ="n "> ind_dt</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> result_type</ span > < span class ="p "> (</ span > < span class ="o "> *</ span > < span class ="n "> inds</ span > < span class ="p "> )</ span >
16251649 < span class ="c1 "> # ind arrays have been checked to be of integer dtype</ span >
16261650 < span class ="k "> if</ span > < span class ="n "> ind_dt</ span > < span class ="o "> .</ span > < span class ="n "> kind</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="s2 "> "ui"</ span > < span class ="p "> :</ span >
@@ -1751,15 +1775,26 @@ <h1>Source code for dpctl.tensor._copy_utils</h1><div class="highlight"><pre>
17511775 < span class ="p "> ]</ span >
17521776 < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> inds</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="nb "> list</ span > < span class ="p "> ,</ span > < span class ="nb "> tuple</ span > < span class ="p "> )):</ span >
17531777 < span class ="n "> inds</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> inds</ span > < span class ="p "> ,)</ span >
1778+ < span class ="n "> any_usmarray</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span >
17541779 < span class ="k "> for</ span > < span class ="n "> ind</ span > < span class ="ow "> in</ span > < span class ="n "> inds</ span > < span class ="p "> :</ span >
1755- < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> ):</ span >
1756- < span class ="k "> raise</ span > < span class ="ne "> TypeError</ span > < span class ="p "> (</ span > < span class ="s2 "> "all elements of `ind` expected to be usm_ndarrays"</ span > < span class ="p "> )</ span >
1757- < span class ="n "> queues_</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> sycl_queue</ span > < span class ="p "> )</ span >
1758- < span class ="n "> usm_types_</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> usm_type</ span > < span class ="p "> )</ span >
1759- < span class ="k "> if</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="o "> .</ span > < span class ="n "> kind</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="s2 "> "ui"</ span > < span class ="p "> :</ span >
1760- < span class ="k "> raise</ span > < span class ="ne "> IndexError</ span > < span class ="p "> (</ span >
1761- < span class ="s2 "> "arrays used as indices must be of integer (or boolean) type"</ span >
1780+ < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> ):</ span >
1781+ < span class ="n "> any_usmarray</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span >
1782+ < span class ="k "> if</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="o "> .</ span > < span class ="n "> kind</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="s2 "> "ui"</ span > < span class ="p "> :</ span >
1783+ < span class ="k "> raise</ span > < span class ="ne "> IndexError</ span > < span class ="p "> (</ span >
1784+ < span class ="s2 "> "arrays used as indices must be of integer (or boolean) "</ span >
1785+ < span class ="s2 "> "type"</ span >
1786+ < span class ="p "> )</ span >
1787+ < span class ="n "> queues_</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> sycl_queue</ span > < span class ="p "> )</ span >
1788+ < span class ="n "> usm_types_</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="o "> .</ span > < span class ="n "> usm_type</ span > < span class ="p "> )</ span >
1789+ < span class ="k "> elif</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="p "> ,</ span > < span class ="n "> Integral</ span > < span class ="p "> ):</ span >
1790+ < span class ="k "> raise</ span > < span class ="ne "> TypeError</ span > < span class ="p "> (</ span >
1791+ < span class ="s2 "> "all elements of `ind` expected to be usm_ndarrays "</ span >
1792+ < span class ="s2 "> "or integers"</ span >
17621793 < span class ="p "> )</ span >
1794+ < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> any_usmarray</ span > < span class ="p "> :</ span >
1795+ < span class ="k "> raise</ span > < span class ="ne "> TypeError</ span > < span class ="p "> (</ span >
1796+ < span class ="s2 "> "at least one element of `ind` expected to be a usm_ndarray"</ span >
1797+ < span class ="p "> )</ span >
17631798 < span class ="n "> vals_usm_type</ span > < span class ="o "> =</ span > < span class ="n "> dpctl</ span > < span class ="o "> .</ span > < span class ="n "> utils</ span > < span class ="o "> .</ span > < span class ="n "> get_coerced_usm_type</ span > < span class ="p "> (</ span > < span class ="n "> usm_types_</ span > < span class ="p "> )</ span >
17641799 < span class ="n "> exec_q</ span > < span class ="o "> =</ span > < span class ="n "> dpctl</ span > < span class ="o "> .</ span > < span class ="n "> utils</ span > < span class ="o "> .</ span > < span class ="n "> get_execution_queue</ span > < span class ="p "> (</ span > < span class ="n "> queues_</ span > < span class ="p "> )</ span >
17651800 < span class ="k "> if</ span > < span class ="n "> exec_q</ span > < span class ="ow "> is</ span > < span class ="ow "> not</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
@@ -1777,6 +1812,18 @@ <h1>Source code for dpctl.tensor._copy_utils</h1><div class="highlight"><pre>
17771812 < span class ="s2 "> "be associated with the same queue."</ span >
17781813 < span class ="p "> )</ span >
17791814 < span class ="k "> if</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> inds</ span > < span class ="p "> )</ span > < span class ="o "> ></ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
1815+ < span class ="n "> inds</ span > < span class ="o "> =</ span > < span class ="nb "> tuple</ span > < span class ="p "> (</ span >
1816+ < span class ="nb "> map</ span > < span class ="p "> (</ span >
1817+ < span class ="k "> lambda</ span > < span class ="n "> ind</ span > < span class ="p "> :</ span > < span class ="p "> (</ span >
1818+ < span class ="n "> ind</ span >
1819+ < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> ind</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> )</ span >
1820+ < span class ="k "> else</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> asarray</ span > < span class ="p "> (</ span >
1821+ < span class ="n "> ind</ span > < span class ="p "> ,</ span > < span class ="n "> usm_type</ span > < span class ="o "> =</ span > < span class ="n "> vals_usm_type</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span >
1822+ < span class ="p "> )</ span >
1823+ < span class ="p "> ),</ span >
1824+ < span class ="n "> inds</ span > < span class ="p "> ,</ span >
1825+ < span class ="p "> )</ span >
1826+ < span class ="p "> )</ span >
17801827 < span class ="n "> ind_dt</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> result_type</ span > < span class ="p "> (</ span > < span class ="o "> *</ span > < span class ="n "> inds</ span > < span class ="p "> )</ span >
17811828 < span class ="c1 "> # ind arrays have been checked to be of integer dtype</ span >
17821829 < span class ="k "> if</ span > < span class ="n "> ind_dt</ span > < span class ="o "> .</ span > < span class ="n "> kind</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="s2 "> "ui"</ span > < span class ="p "> :</ span >
0 commit comments