Skip to content

Commit 69f17a3

Browse files
Annotate read-only vectors as const
1 parent 8757289 commit 69f17a3

File tree

1 file changed

+31
-27
lines changed
  • dpctl/tensor/libtensor/source/linalg_functions

1 file changed

+31
-27
lines changed

dpctl/tensor/libtensor/source/linalg_functions/dot.cpp

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
350350
int inner_nd = inner_dims;
351351
const py::ssize_t *inner_shape_ptr = x1_shape_ptr + batch_dims;
352352
using shT = std::vector<py::ssize_t>;
353-
shT inner_x1_strides(std::begin(x1_strides_vec) + batch_dims,
354-
std::end(x1_strides_vec));
355-
shT inner_x2_strides(std::begin(x2_strides_vec) + batch_dims,
356-
std::end(x2_strides_vec));
353+
const shT inner_x1_strides(std::begin(x1_strides_vec) + batch_dims,
354+
std::end(x1_strides_vec));
355+
const shT inner_x2_strides(std::begin(x2_strides_vec) + batch_dims,
356+
std::end(x2_strides_vec));
357357

358358
shT simplified_inner_shape;
359359
shT simplified_inner_x1_strides;
@@ -369,10 +369,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
369369

370370
const py::ssize_t *batch_shape_ptr = x1_shape_ptr;
371371

372-
shT batch_x1_strides(std::begin(x1_strides_vec),
373-
std::begin(x1_strides_vec) + batch_dims);
374-
shT batch_x2_strides(std::begin(x2_strides_vec),
375-
std::begin(x2_strides_vec) + batch_dims);
372+
const shT batch_x1_strides(std::begin(x1_strides_vec),
373+
std::begin(x1_strides_vec) + batch_dims);
374+
const shT batch_x2_strides(std::begin(x2_strides_vec),
375+
std::begin(x2_strides_vec) + batch_dims);
376376
shT const &batch_dst_strides = dst_strides_vec;
377377

378378
shT simplified_batch_shape;
@@ -551,9 +551,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
551551
}
552552
sycl::event copy_shapes_strides_ev =
553553
std::get<2>(ptr_size_event_tuple1);
554-
py::ssize_t *x1_shape_strides = packed_shapes_strides;
555-
py::ssize_t *x2_shape_strides = packed_shapes_strides + 2 * (x1_nd);
556-
py::ssize_t *dst_shape_strides =
554+
const py::ssize_t *x1_shape_strides = packed_shapes_strides;
555+
const py::ssize_t *x2_shape_strides =
556+
packed_shapes_strides + 2 * (x1_nd);
557+
const py::ssize_t *dst_shape_strides =
557558
packed_shapes_strides + 2 * (x1_nd + x2_nd);
558559

559560
std::vector<sycl::event> all_deps;
@@ -619,29 +620,32 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
619620
shT outer_inner_x1_strides;
620621
dpctl::tensor::py_internal::split_iteration_space(
621622
x1_shape_vec, x1_strides_vec, batch_dims,
622-
batch_dims + x1_outer_inner_dims, batch_x1_shape,
623-
outer_inner_x1_shape, // 4 vectors modified
624-
batch_x1_strides, outer_inner_x1_strides);
623+
batch_dims + x1_outer_inner_dims,
624+
// 4 vectors modified
625+
batch_x1_shape, outer_inner_x1_shape, batch_x1_strides,
626+
outer_inner_x1_strides);
625627

626628
shT batch_x2_shape;
627629
shT outer_inner_x2_shape;
628630
shT batch_x2_strides;
629631
shT outer_inner_x2_strides;
630632
dpctl::tensor::py_internal::split_iteration_space(
631633
x2_shape_vec, x2_strides_vec, batch_dims,
632-
batch_dims + x2_outer_inner_dims, batch_x2_shape,
633-
outer_inner_x2_shape, // 4 vectors modified
634-
batch_x2_strides, outer_inner_x2_strides);
634+
batch_dims + x2_outer_inner_dims,
635+
// 4 vectors modified
636+
batch_x2_shape, outer_inner_x2_shape, batch_x2_strides,
637+
outer_inner_x2_strides);
635638

636639
shT batch_dst_shape;
637640
shT outer_inner_dst_shape;
638641
shT batch_dst_strides;
639642
shT outer_inner_dst_strides;
640643
dpctl::tensor::py_internal::split_iteration_space(
641644
dst_shape_vec, dst_strides_vec, batch_dims,
642-
batch_dims + dst_outer_inner_dims, batch_dst_shape,
643-
outer_inner_dst_shape, // 4 vectors modified
644-
batch_dst_strides, outer_inner_dst_strides);
645+
batch_dims + dst_outer_inner_dims,
646+
// 4 vectors modified
647+
batch_dst_shape, outer_inner_dst_shape, batch_dst_strides,
648+
outer_inner_dst_strides);
645649

646650
using shT = std::vector<py::ssize_t>;
647651
shT simplified_batch_shape;
@@ -746,16 +750,16 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
746750
sycl::event copy_shapes_strides_ev =
747751
std::get<2>(ptr_size_event_tuple1);
748752

749-
auto batch_shape_strides = packed_shapes_strides;
750-
auto x1_outer_inner_shapes_strides =
753+
const auto batch_shape_strides = packed_shapes_strides;
754+
const auto x1_outer_inner_shapes_strides =
751755
packed_shapes_strides + 4 * batch_dims;
752-
auto x2_outer_inner_shapes_strides = packed_shapes_strides +
753-
4 * batch_dims +
754-
2 * (x1_outer_inner_dims);
755-
auto dst_outer_shapes_strides =
756+
const auto x2_outer_inner_shapes_strides =
757+
packed_shapes_strides + 4 * batch_dims +
758+
2 * (x1_outer_inner_dims);
759+
const auto dst_outer_shapes_strides =
756760
packed_shapes_strides + 4 * batch_dims +
757761
2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims);
758-
auto dst_full_shape_strides =
762+
const auto dst_full_shape_strides =
759763
packed_shapes_strides + 4 * batch_dims +
760764
2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims) +
761765
2 * (dst_outer_inner_dims);

0 commit comments

Comments
 (0)