@@ -1303,7 +1303,11 @@ <h1>Source code for dpctl.tensor._indexing_functions</h1><div class="highlight">
13031303 < span class ="s2 "> "from input arguments. "</ span >
13041304 < span class ="p "> )</ span >
13051305 < span class ="n "> mode_i</ span > < span class ="o "> =</ span > < span class ="n "> _get_indexing_mode</ span > < span class ="p "> (</ span > < span class ="n "> mode</ span > < span class ="p "> )</ span >
1306- < span class ="n "> indexes_dt</ span > < span class ="o "> =</ span > < span class ="n "> ti</ span > < span class ="o "> .</ span > < span class ="n "> default_device_index_type</ span > < span class ="p "> (</ span > < span class ="n "> exec_q</ span > < span class ="o "> .</ span > < span class ="n "> sycl_device</ span > < span class ="p "> )</ span >
1306+ < span class ="n "> indexes_dt</ span > < span class ="o "> =</ span > < span class ="p "> (</ span >
1307+ < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> uint64</ span >
1308+ < span class ="k "> if</ span > < span class ="n "> indices</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="o "> ==</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> uint64</ span >
1309+ < span class ="k "> else</ span > < span class ="n "> ti</ span > < span class ="o "> .</ span > < span class ="n "> default_device_index_type</ span > < span class ="p "> (</ span > < span class ="n "> exec_q</ span > < span class ="o "> .</ span > < span class ="n "> sycl_device</ span > < span class ="p "> )</ span >
1310+ < span class ="p "> )</ span >
13071311 < span class ="n "> _ind</ span > < span class ="o "> =</ span > < span class ="nb "> tuple</ span > < span class ="p "> (</ span >
13081312 < span class ="p "> (</ span >
13091313 < span class ="n "> indices</ span >
@@ -1379,7 +1383,11 @@ <h1>Source code for dpctl.tensor._indexing_functions</h1><div class="highlight">
13791383 < span class ="p "> )</ span >
13801384 < span class ="n "> out_usm_type</ span > < span class ="o "> =</ span > < span class ="n "> dpctl</ span > < span class ="o "> .</ span > < span class ="n "> utils</ span > < span class ="o "> .</ span > < span class ="n "> get_coerced_usm_type</ span > < span class ="p "> (</ span > < span class ="n "> usm_types_</ span > < span class ="p "> )</ span >
13811385 < span class ="n "> mode_i</ span > < span class ="o "> =</ span > < span class ="n "> _get_indexing_mode</ span > < span class ="p "> (</ span > < span class ="n "> mode</ span > < span class ="p "> )</ span >
1382- < span class ="n "> indexes_dt</ span > < span class ="o "> =</ span > < span class ="n "> ti</ span > < span class ="o "> .</ span > < span class ="n "> default_device_index_type</ span > < span class ="p "> (</ span > < span class ="n "> exec_q</ span > < span class ="o "> .</ span > < span class ="n "> sycl_device</ span > < span class ="p "> )</ span >
1386+ < span class ="n "> indexes_dt</ span > < span class ="o "> =</ span > < span class ="p "> (</ span >
1387+ < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> uint64</ span >
1388+ < span class ="k "> if</ span > < span class ="n "> indices</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="o "> ==</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> uint64</ span >
1389+ < span class ="k "> else</ span > < span class ="n "> ti</ span > < span class ="o "> .</ span > < span class ="n "> default_device_index_type</ span > < span class ="p "> (</ span > < span class ="n "> exec_q</ span > < span class ="o "> .</ span > < span class ="n "> sycl_device</ span > < span class ="p "> )</ span >
1390+ < span class ="p "> )</ span >
13831391 < span class ="n "> _ind</ span > < span class ="o "> =</ span > < span class ="nb "> tuple</ span > < span class ="p "> (</ span >
13841392 < span class ="p "> (</ span >
13851393 < span class ="n "> indices</ span >
0 commit comments