@@ -350,10 +350,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
350
350
int inner_nd = inner_dims;
351
351
const py::ssize_t *inner_shape_ptr = x1_shape_ptr + batch_dims;
352
352
using shT = std::vector<py::ssize_t >;
353
- shT inner_x1_strides (std::begin (x1_strides_vec) + batch_dims,
354
- std::end (x1_strides_vec));
355
- shT inner_x2_strides (std::begin (x2_strides_vec) + batch_dims,
356
- std::end (x2_strides_vec));
353
+ const shT inner_x1_strides (std::begin (x1_strides_vec) + batch_dims,
354
+ std::end (x1_strides_vec));
355
+ const shT inner_x2_strides (std::begin (x2_strides_vec) + batch_dims,
356
+ std::end (x2_strides_vec));
357
357
358
358
shT simplified_inner_shape;
359
359
shT simplified_inner_x1_strides;
@@ -369,10 +369,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
369
369
370
370
const py::ssize_t *batch_shape_ptr = x1_shape_ptr;
371
371
372
- shT batch_x1_strides (std::begin (x1_strides_vec),
373
- std::begin (x1_strides_vec) + batch_dims);
374
- shT batch_x2_strides (std::begin (x2_strides_vec),
375
- std::begin (x2_strides_vec) + batch_dims);
372
+ const shT batch_x1_strides (std::begin (x1_strides_vec),
373
+ std::begin (x1_strides_vec) + batch_dims);
374
+ const shT batch_x2_strides (std::begin (x2_strides_vec),
375
+ std::begin (x2_strides_vec) + batch_dims);
376
376
shT const &batch_dst_strides = dst_strides_vec;
377
377
378
378
shT simplified_batch_shape;
@@ -551,9 +551,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
551
551
}
552
552
sycl::event copy_shapes_strides_ev =
553
553
std::get<2 >(ptr_size_event_tuple1);
554
- py::ssize_t *x1_shape_strides = packed_shapes_strides;
555
- py::ssize_t *x2_shape_strides = packed_shapes_strides + 2 * (x1_nd);
556
- py::ssize_t *dst_shape_strides =
554
+ const py::ssize_t *x1_shape_strides = packed_shapes_strides;
555
+ const py::ssize_t *x2_shape_strides =
556
+ packed_shapes_strides + 2 * (x1_nd);
557
+ const py::ssize_t *dst_shape_strides =
557
558
packed_shapes_strides + 2 * (x1_nd + x2_nd);
558
559
559
560
std::vector<sycl::event> all_deps;
@@ -619,29 +620,32 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
619
620
shT outer_inner_x1_strides;
620
621
dpctl::tensor::py_internal::split_iteration_space (
621
622
x1_shape_vec, x1_strides_vec, batch_dims,
622
- batch_dims + x1_outer_inner_dims, batch_x1_shape,
623
- outer_inner_x1_shape, // 4 vectors modified
624
- batch_x1_strides, outer_inner_x1_strides);
623
+ batch_dims + x1_outer_inner_dims,
624
+ // 4 vectors modified
625
+ batch_x1_shape, outer_inner_x1_shape, batch_x1_strides,
626
+ outer_inner_x1_strides);
625
627
626
628
shT batch_x2_shape;
627
629
shT outer_inner_x2_shape;
628
630
shT batch_x2_strides;
629
631
shT outer_inner_x2_strides;
630
632
dpctl::tensor::py_internal::split_iteration_space (
631
633
x2_shape_vec, x2_strides_vec, batch_dims,
632
- batch_dims + x2_outer_inner_dims, batch_x2_shape,
633
- outer_inner_x2_shape, // 4 vectors modified
634
- batch_x2_strides, outer_inner_x2_strides);
634
+ batch_dims + x2_outer_inner_dims,
635
+ // 4 vectors modified
636
+ batch_x2_shape, outer_inner_x2_shape, batch_x2_strides,
637
+ outer_inner_x2_strides);
635
638
636
639
shT batch_dst_shape;
637
640
shT outer_inner_dst_shape;
638
641
shT batch_dst_strides;
639
642
shT outer_inner_dst_strides;
640
643
dpctl::tensor::py_internal::split_iteration_space (
641
644
dst_shape_vec, dst_strides_vec, batch_dims,
642
- batch_dims + dst_outer_inner_dims, batch_dst_shape,
643
- outer_inner_dst_shape, // 4 vectors modified
644
- batch_dst_strides, outer_inner_dst_strides);
645
+ batch_dims + dst_outer_inner_dims,
646
+ // 4 vectors modified
647
+ batch_dst_shape, outer_inner_dst_shape, batch_dst_strides,
648
+ outer_inner_dst_strides);
645
649
646
650
using shT = std::vector<py::ssize_t >;
647
651
shT simplified_batch_shape;
@@ -746,16 +750,16 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
746
750
sycl::event copy_shapes_strides_ev =
747
751
std::get<2 >(ptr_size_event_tuple1);
748
752
749
- auto batch_shape_strides = packed_shapes_strides;
750
- auto x1_outer_inner_shapes_strides =
753
+ const auto batch_shape_strides = packed_shapes_strides;
754
+ const auto x1_outer_inner_shapes_strides =
751
755
packed_shapes_strides + 4 * batch_dims;
752
- auto x2_outer_inner_shapes_strides = packed_shapes_strides +
753
- 4 * batch_dims +
754
- 2 * (x1_outer_inner_dims);
755
- auto dst_outer_shapes_strides =
756
+ const auto x2_outer_inner_shapes_strides =
757
+ packed_shapes_strides + 4 * batch_dims +
758
+ 2 * (x1_outer_inner_dims);
759
+ const auto dst_outer_shapes_strides =
756
760
packed_shapes_strides + 4 * batch_dims +
757
761
2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims);
758
- auto dst_full_shape_strides =
762
+ const auto dst_full_shape_strides =
759
763
packed_shapes_strides + 4 * batch_dims +
760
764
2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims) +
761
765
2 * (dst_outer_inner_dims);
0 commit comments