Skip to content

Commit 83a8c85

Browse files
Apply remarks for getrf_batch.cpp
1 parent 6b3f331 commit 83a8c85

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

dpnp/backend/extensions/lapack/getrs_batch.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,12 @@
3333
#include <pybind11/pybind11.h>
3434
#include <sycl/sycl.hpp>
3535

36+
// utils extension header
37+
#include "ext/common.hpp"
38+
3639
// dpctl tensor headers
3740
#include "utils/memory_overlap.hpp"
3841
#include "utils/sycl_alloc_utils.hpp"
39-
#include "utils/type_dispatch.hpp"
4042
#include "utils/type_utils.hpp"
4143

4244
#include "getrs.hpp"
@@ -48,7 +50,8 @@ namespace dpnp::extensions::lapack
4850
namespace mkl_lapack = oneapi::mkl::lapack;
4951
namespace py = pybind11;
5052
namespace type_utils = dpctl::tensor::type_utils;
51-
namespace td_ns = dpctl::tensor::type_dispatch;
53+
54+
using ext::common::init_dispatch_vector;
5255

5356
typedef sycl::event (*getrs_batch_impl_fn_ptr_t)(
5457
sycl::queue &,
@@ -67,7 +70,8 @@ typedef sycl::event (*getrs_batch_impl_fn_ptr_t)(
6770
std::vector<sycl::event> &,
6871
const std::vector<sycl::event> &);
6972

70-
static getrs_batch_impl_fn_ptr_t getrs_batch_dispatch_vector[td_ns::num_types];
73+
static getrs_batch_impl_fn_ptr_t
74+
getrs_batch_dispatch_vector[dpctl_td_ns::num_types];
7175

7276
template <typename T>
7377
static sycl::event getrs_batch_impl(sycl::queue &exec_q,
@@ -225,7 +229,7 @@ std::pair<sycl::event, sycl::event>
225229
std::to_string(b_array_nd) +
226230
", but an array with ndim >= 2 is expected");
227231
}
228-
if (ipiv_array_nd < 1) {
232+
if (ipiv_array_nd < 2) {
229233
throw py::value_error("The array of pivot indices has ndim=" +
230234
std::to_string(ipiv_array_nd) +
231235
", but an array with ndim >= 2 is expected");
@@ -280,7 +284,7 @@ std::pair<sycl::event, sycl::event>
280284
"must be contiguous");
281285
}
282286

283-
auto array_types = td_ns::usm_ndarray_types();
287+
auto array_types = dpctl_td_ns::usm_ndarray_types();
284288
int a_array_type_id =
285289
array_types.typenum_to_lookup_id(a_array.get_typenum());
286290
int b_array_type_id =
@@ -299,11 +303,11 @@ std::pair<sycl::event, sycl::event>
299303
"of the input matrix");
300304
}
301305

302-
auto ipiv_types = td_ns::usm_ndarray_types();
306+
auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
303307
int ipiv_array_type_id =
304308
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
305309

306-
if (ipiv_array_type_id != static_cast<int>(td_ns::typenum_t::INT64)) {
310+
if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
307311
throw py::value_error("The type of 'ipiv_array' must be int64");
308312
}
309313

@@ -343,9 +347,7 @@ struct GetrsBatchContigFactory
343347

344348
void init_getrs_batch_dispatch_vector(void)
345349
{
346-
td_ns::DispatchVectorBuilder<getrs_batch_impl_fn_ptr_t,
347-
GetrsBatchContigFactory, td_ns::num_types>
348-
contig;
349-
contig.populate_dispatch_vector(getrs_batch_dispatch_vector);
350+
init_dispatch_vector<getrs_batch_impl_fn_ptr_t, GetrsBatchContigFactory>(
351+
getrs_batch_dispatch_vector);
350352
}
351353
} // namespace dpnp::extensions::lapack

0 commit comments

Comments
 (0)