@@ -99,8 +99,6 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
99
99
std::shared_ptr<shT> packed_host_axes_shapes_strides_shp =
100
100
std::make_shared<shT>(2 * k + along_sh_elems, allocator);
101
101
102
- // can be made more efficient by checking if inp_nd > 1, then performing
103
- // same treatment of orthog_sh_elems as for 0D (orthog will not exist)
104
102
if (inp_nd > 0 ) {
105
103
std::copy (inp_shape, inp_shape + axis_start,
106
104
packed_host_shapes_strides_shp->begin ());
@@ -403,6 +401,17 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
403
401
}
404
402
}
405
403
404
+ // destination must be ample enough to accommodate all elements
405
+ {
406
+ size_t range =
407
+ static_cast <size_t >(dst_offsets.second - dst_offsets.first );
408
+ if ((range + 1 ) < (orthog_nelems * ind_nelems)) {
409
+ throw py::value_error (
410
+ " Destination array can not accommodate all the "
411
+ " elements of source array." );
412
+ }
413
+ }
414
+
406
415
auto ind_sh_elems = (ind_nd > 0 ) ? ind_nd : 1 ;
407
416
408
417
std::vector<char *> ind_ptrs;
@@ -580,17 +589,6 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
580
589
const py::ssize_t *src_strides = src.get_strides_raw ();
581
590
const py::ssize_t *dst_strides = dst.get_strides_raw ();
582
591
583
- // destination must be ample enough to accommodate all elements
584
- {
585
- size_t range =
586
- static_cast <size_t >(dst_offsets.second - dst_offsets.first );
587
- if ((range + 1 ) < (orthog_nelems * ind_nelems)) {
588
- throw py::value_error (
589
- " Destination array can not accommodate all the "
590
- " elements of source array." );
591
- }
592
- }
593
-
594
592
// packed_shapes_strides = [src_shape[:axis] + src_shape[axis+k:],
595
593
// src_strides[:axis] + src_strides[axis+k:],
596
594
// dst_strides[:axis] + dst_strides[axis+k:]]
@@ -765,6 +763,17 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
765
763
throw py::value_error (" Arrays index overlapping segments of memory" );
766
764
}
767
765
766
+ // destination must be ample enough to accommodate all possible elements
767
+ {
768
+ size_t range =
769
+ static_cast <size_t >(dst_offsets.second - dst_offsets.first );
770
+ if ((range + 1 ) < dst_nelems) {
771
+ throw py::value_error (
772
+ " Destination array can not accommodate all the "
773
+ " elements of source array." );
774
+ }
775
+ }
776
+
768
777
int dst_typenum = dst.get_typenum ();
769
778
int val_typenum = val.get_typenum ();
770
779
@@ -965,17 +974,6 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
965
974
const py::ssize_t *dst_strides = dst.get_strides_raw ();
966
975
const py::ssize_t *val_strides = val.get_strides_raw ();
967
976
968
- // destination must be ample enough to accommodate all possible elements
969
- {
970
- size_t range =
971
- static_cast <size_t >(dst_offsets.second - dst_offsets.first );
972
- if ((range + 1 ) < dst_nelems) {
973
- throw py::value_error (
974
- " Destination array can not accommodate all the "
975
- " elements of source array." );
976
- }
977
- }
978
-
979
977
// packed_shapes_strides = [dst_shape[:axis] + dst_shape[axis+k:],
980
978
// dst_strides[:axis] + dst_strides[axis+k:],
981
979
// val_strides[:axis] + val_strides[axis+k:]]
0 commit comments