@@ -1123,7 +1123,7 @@ <h1>Source code for dpctl.tensor._manipulation_functions</h1><div class="highlig
11231123 < span class ="k "> return</ span > < span class ="n "> X</ span > < span class ="p "> [</ span > < span class ="n "> indexer</ span > < span class ="p "> ]</ span > </ div >
11241124
11251125
1126- < div class ="viewcode-block " id ="roll "> < a class ="viewcode-back " href ="../../../api_reference/dpctl/generated/dpctl.tensor.roll.html#dpctl.tensor.roll "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> roll</ span > < span class ="p "> (</ span > < span class ="n "> X </ span > < span class ="p "> ,</ span > < span class ="o "> /</ span > < span class ="p "> ,</ span > < span class ="n "> shift</ span > < span class ="p "> ,</ span > < span class ="o "> *</ span > < span class ="p "> ,</ span > < span class ="n "> axis</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ):</ span >
1126+ < div class ="viewcode-block " id ="roll "> < a class ="viewcode-back " href ="../../../api_reference/dpctl/generated/dpctl.tensor.roll.html#dpctl.tensor.roll "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> roll</ span > < span class ="p "> (</ span > < span class ="n "> x </ span > < span class ="p "> ,</ span > < span class ="o "> /</ span > < span class ="p "> ,</ span > < span class ="n "> shift</ span > < span class ="p "> ,</ span > < span class ="o "> *</ span > < span class ="p "> ,</ span > < span class ="n "> axis</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ):</ span >
11271127< span class ="w "> </ span > < span class ="sd "> """</ span >
11281128< span class ="sd "> roll(x, shift, axis)</ span >
11291129
@@ -1155,41 +1155,45 @@ <h1>Source code for dpctl.tensor._manipulation_functions</h1><div class="highlig
11551155< span class ="sd "> `device` attributes as `x` and whose elements are shifted relative</ span >
11561156< span class ="sd "> to `x`.</ span >
11571157< span class ="sd "> """</ span >
1158- < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> X </ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> ):</ span >
1159- < span class ="k "> raise</ span > < span class ="ne "> TypeError</ span > < span class ="p "> (</ span > < span class ="sa "> f</ span > < span class ="s2 "> "Expected usm_ndarray type, got </ span > < span class ="si "> {</ span > < span class ="nb "> type</ span > < span class ="p "> (</ span > < span class ="n "> X </ span > < span class ="p "> )</ span > < span class ="si "> }</ span > < span class ="s2 "> ."</ span > < span class ="p "> )</ span >
1160- < span class ="n "> exec_q</ span > < span class ="o "> =</ span > < span class ="n "> X </ span > < span class ="o "> .</ span > < span class ="n "> sycl_queue</ span >
1158+ < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> x </ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> usm_ndarray</ span > < span class ="p "> ):</ span >
1159+ < span class ="k "> raise</ span > < span class ="ne "> TypeError</ span > < span class ="p "> (</ span > < span class ="sa "> f</ span > < span class ="s2 "> "Expected usm_ndarray type, got </ span > < span class ="si "> {</ span > < span class ="nb "> type</ span > < span class ="p "> (</ span > < span class ="n "> x </ span > < span class ="p "> )</ span > < span class ="si "> }</ span > < span class ="s2 "> ."</ span > < span class ="p "> )</ span >
1160+ < span class ="n "> exec_q</ span > < span class ="o "> =</ span > < span class ="n "> x </ span > < span class ="o "> .</ span > < span class ="n "> sycl_queue</ span >
11611161 < span class ="n "> _manager</ span > < span class ="o "> =</ span > < span class ="n "> dputils</ span > < span class ="o "> .</ span > < span class ="n "> SequentialOrderManager</ span > < span class ="p "> [</ span > < span class ="n "> exec_q</ span > < span class ="p "> ]</ span >
11621162 < span class ="k "> if</ span > < span class ="n "> axis</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
11631163 < span class ="n "> shift</ span > < span class ="o "> =</ span > < span class ="n "> operator</ span > < span class ="o "> .</ span > < span class ="n "> index</ span > < span class ="p "> (</ span > < span class ="n "> shift</ span > < span class ="p "> )</ span >
1164- < span class ="n "> dep_evs</ span > < span class ="o "> =</ span > < span class ="n "> _manager</ span > < span class ="o "> .</ span > < span class ="n "> submitted_events</ span >
11651164 < span class ="n "> res</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> empty</ span > < span class ="p "> (</ span >
1166- < span class ="n "> X </ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> X </ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="p "> ,</ span > < span class ="n "> usm_type</ span > < span class ="o "> =</ span > < span class ="n "> X </ span > < span class ="o "> .</ span > < span class ="n "> usm_type</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span >
1165+ < span class ="n "> x </ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> x </ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="p "> ,</ span > < span class ="n "> usm_type</ span > < span class ="o "> =</ span > < span class ="n "> x </ span > < span class ="o "> .</ span > < span class ="n "> usm_type</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span >
11671166 < span class ="p "> )</ span >
1167+ < span class ="n "> sz</ span > < span class ="o "> =</ span > < span class ="n "> operator</ span > < span class ="o "> .</ span > < span class ="n "> index</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> size</ span > < span class ="p "> )</ span >
1168+ < span class ="n "> shift</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> shift</ span > < span class ="o "> %</ span > < span class ="n "> sz</ span > < span class ="p "> )</ span > < span class ="k "> if</ span > < span class ="n "> sz</ span > < span class ="o "> ></ span > < span class ="mi "> 0</ span > < span class ="k "> else</ span > < span class ="mi "> 0</ span >
1169+ < span class ="n "> dep_evs</ span > < span class ="o "> =</ span > < span class ="n "> _manager</ span > < span class ="o "> .</ span > < span class ="n "> submitted_events</ span >
11681170 < span class ="n "> hev</ span > < span class ="p "> ,</ span > < span class ="n "> roll_ev</ span > < span class ="o "> =</ span > < span class ="n "> ti</ span > < span class ="o "> .</ span > < span class ="n "> _copy_usm_ndarray_for_roll_1d</ span > < span class ="p "> (</ span >
1169- < span class ="n "> src</ span > < span class ="o "> =</ span > < span class ="n "> X </ span > < span class ="p "> ,</ span >
1171+ < span class ="n "> src</ span > < span class ="o "> =</ span > < span class ="n "> x </ span > < span class ="p "> ,</ span >
11701172 < span class ="n "> dst</ span > < span class ="o "> =</ span > < span class ="n "> res</ span > < span class ="p "> ,</ span >
11711173 < span class ="n "> shift</ span > < span class ="o "> =</ span > < span class ="n "> shift</ span > < span class ="p "> ,</ span >
11721174 < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span > < span class ="p "> ,</ span >
11731175 < span class ="n "> depends</ span > < span class ="o "> =</ span > < span class ="n "> dep_evs</ span > < span class ="p "> ,</ span >
11741176 < span class ="p "> )</ span >
11751177 < span class ="n "> _manager</ span > < span class ="o "> .</ span > < span class ="n "> add_event_pair</ span > < span class ="p "> (</ span > < span class ="n "> hev</ span > < span class ="p "> ,</ span > < span class ="n "> roll_ev</ span > < span class ="p "> )</ span >
11761178 < span class ="k "> return</ span > < span class ="n "> res</ span >
1177- < span class ="n "> axis</ span > < span class ="o "> =</ span > < span class ="n "> normalize_axis_tuple</ span > < span class ="p "> (</ span > < span class ="n "> axis</ span > < span class ="p "> ,</ span > < span class ="n "> X </ span > < span class ="o "> .</ span > < span class ="n "> ndim</ span > < span class ="p "> ,</ span > < span class ="n "> allow_duplicate</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
1179+ < span class ="n "> axis</ span > < span class ="o "> =</ span > < span class ="n "> normalize_axis_tuple</ span > < span class ="p "> (</ span > < span class ="n "> axis</ span > < span class ="p "> ,</ span > < span class ="n "> x </ span > < span class ="o "> .</ span > < span class ="n "> ndim</ span > < span class ="p "> ,</ span > < span class ="n "> allow_duplicate</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
11781180 < span class ="n "> broadcasted</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> broadcast</ span > < span class ="p "> (</ span > < span class ="n "> shift</ span > < span class ="p "> ,</ span > < span class ="n "> axis</ span > < span class ="p "> )</ span >
11791181 < span class ="k "> if</ span > < span class ="n "> broadcasted</ span > < span class ="o "> .</ span > < span class ="n "> ndim</ span > < span class ="o "> ></ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
11801182 < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="s2 "> "'shift' and 'axis' should be scalars or 1D sequences"</ span > < span class ="p "> )</ span >
11811183 < span class ="n "> shifts</ span > < span class ="o "> =</ span > < span class ="p "> [</ span >
11821184 < span class ="mi "> 0</ span > < span class ="p "> ,</ span >
1183- < span class ="p "> ]</ span > < span class ="o "> *</ span > < span class ="n "> X</ span > < span class ="o "> .</ span > < span class ="n "> ndim</ span >
1185+ < span class ="p "> ]</ span > < span class ="o "> *</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> ndim</ span >
1186+ < span class ="n "> shape</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span >
11841187 < span class ="k "> for</ span > < span class ="n "> sh</ span > < span class ="p "> ,</ span > < span class ="n "> ax</ span > < span class ="ow "> in</ span > < span class ="n "> broadcasted</ span > < span class ="p "> :</ span >
1185- < span class ="n "> shifts</ span > < span class ="p "> [</ span > < span class ="n "> ax</ span > < span class ="p "> ]</ span > < span class ="o "> +=</ span > < span class ="n "> sh</ span >
1186-
1188+ < span class ="n "> n_i</ span > < span class ="o "> =</ span > < span class ="n "> operator</ span > < span class ="o "> .</ span > < span class ="n "> index</ span > < span class ="p "> (</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="n "> ax</ span > < span class ="p "> ])</ span >
1189+ < span class ="n "> shifted</ span > < span class ="o "> =</ span > < span class ="n "> shifts</ span > < span class ="p "> [</ span > < span class ="n "> ax</ span > < span class ="p "> ]</ span > < span class ="o "> +</ span > < span class ="n "> operator</ span > < span class ="o "> .</ span > < span class ="n "> index</ span > < span class ="p "> (</ span > < span class ="n "> sh</ span > < span class ="p "> )</ span >
1190+ < span class ="n "> shifts</ span > < span class ="p "> [</ span > < span class ="n "> ax</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> shifted</ span > < span class ="o "> %</ span > < span class ="n "> n_i</ span > < span class ="p "> )</ span > < span class ="k "> if</ span > < span class ="n "> n_i</ span > < span class ="o "> ></ span > < span class ="mi "> 0</ span > < span class ="k "> else</ span > < span class ="mi "> 0</ span >
11871191 < span class ="n "> res</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> empty</ span > < span class ="p "> (</ span >
1188- < span class ="n "> X </ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> X </ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="p "> ,</ span > < span class ="n "> usm_type</ span > < span class ="o "> =</ span > < span class ="n "> X </ span > < span class ="o "> .</ span > < span class ="n "> usm_type</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span >
1192+ < span class ="n "> x </ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> x </ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="p "> ,</ span > < span class ="n "> usm_type</ span > < span class ="o "> =</ span > < span class ="n "> x </ span > < span class ="o "> .</ span > < span class ="n "> usm_type</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span >
11891193 < span class ="p "> )</ span >
11901194 < span class ="n "> dep_evs</ span > < span class ="o "> =</ span > < span class ="n "> _manager</ span > < span class ="o "> .</ span > < span class ="n "> submitted_events</ span >
11911195 < span class ="n "> ht_e</ span > < span class ="p "> ,</ span > < span class ="n "> roll_ev</ span > < span class ="o "> =</ span > < span class ="n "> ti</ span > < span class ="o "> .</ span > < span class ="n "> _copy_usm_ndarray_for_roll_nd</ span > < span class ="p "> (</ span >
1192- < span class ="n "> src</ span > < span class ="o "> =</ span > < span class ="n "> X </ span > < span class ="p "> ,</ span > < span class ="n "> dst</ span > < span class ="o "> =</ span > < span class ="n "> res</ span > < span class ="p "> ,</ span > < span class ="n "> shifts</ span > < span class ="o "> =</ span > < span class ="n "> shifts</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span > < span class ="p "> ,</ span > < span class ="n "> depends</ span > < span class ="o "> =</ span > < span class ="n "> dep_evs</ span >
1196+ < span class ="n "> src</ span > < span class ="o "> =</ span > < span class ="n "> x </ span > < span class ="p "> ,</ span > < span class ="n "> dst</ span > < span class ="o "> =</ span > < span class ="n "> res</ span > < span class ="p "> ,</ span > < span class ="n "> shifts</ span > < span class ="o "> =</ span > < span class ="n "> shifts</ span > < span class ="p "> ,</ span > < span class ="n "> sycl_queue</ span > < span class ="o "> =</ span > < span class ="n "> exec_q</ span > < span class ="p "> ,</ span > < span class ="n "> depends</ span > < span class ="o "> =</ span > < span class ="n "> dep_evs</ span >
11931197 < span class ="p "> )</ span >
11941198 < span class ="n "> _manager</ span > < span class ="o "> .</ span > < span class ="n "> add_event_pair</ span > < span class ="p "> (</ span > < span class ="n "> ht_e</ span > < span class ="p "> ,</ span > < span class ="n "> roll_ev</ span > < span class ="p "> )</ span >
11951199 < span class ="k "> return</ span > < span class ="n "> res</ span > </ div >
0 commit comments