3333#include < pybind11/pybind11.h>
3434#include < sycl/sycl.hpp>
3535
36+ // utils extension header
37+ #include " ext/common.hpp"
38+
3639// dpctl tensor headers
3740#include " utils/memory_overlap.hpp"
3841#include " utils/sycl_alloc_utils.hpp"
39- #include " utils/type_dispatch.hpp"
4042#include " utils/type_utils.hpp"
4143
4244#include " getrs.hpp"
@@ -48,7 +50,8 @@ namespace dpnp::extensions::lapack
4850namespace mkl_lapack = oneapi::mkl::lapack;
4951namespace py = pybind11;
5052namespace type_utils = dpctl::tensor::type_utils;
51- namespace td_ns = dpctl::tensor::type_dispatch;
53+
54+ using ext::common::init_dispatch_vector;
5255
5356typedef sycl::event (*getrs_batch_impl_fn_ptr_t )(
5457 sycl::queue &,
@@ -67,7 +70,8 @@ typedef sycl::event (*getrs_batch_impl_fn_ptr_t)(
6770 std::vector<sycl::event> &,
6871 const std::vector<sycl::event> &);
6972
70- static getrs_batch_impl_fn_ptr_t getrs_batch_dispatch_vector[td_ns::num_types];
73+ static getrs_batch_impl_fn_ptr_t
74+ getrs_batch_dispatch_vector[dpctl_td_ns::num_types];
7175
7276template <typename T>
7377static sycl::event getrs_batch_impl (sycl::queue &exec_q,
@@ -225,7 +229,7 @@ std::pair<sycl::event, sycl::event>
225229 std::to_string (b_array_nd) +
226230 " , but an array with ndim >= 2 is expected" );
227231 }
228- if (ipiv_array_nd < 1 ) {
232+ if (ipiv_array_nd < 2 ) {
229233 throw py::value_error (" The array of pivot indices has ndim=" +
230234 std::to_string (ipiv_array_nd) +
231235 " , but an array with ndim >= 2 is expected" );
@@ -280,7 +284,7 @@ std::pair<sycl::event, sycl::event>
280284 " must be contiguous" );
281285 }
282286
283- auto array_types = td_ns ::usm_ndarray_types ();
287+ auto array_types = dpctl_td_ns ::usm_ndarray_types ();
284288 int a_array_type_id =
285289 array_types.typenum_to_lookup_id (a_array.get_typenum ());
286290 int b_array_type_id =
@@ -299,11 +303,11 @@ std::pair<sycl::event, sycl::event>
299303 " of the input matrix" );
300304 }
301305
302- auto ipiv_types = td_ns ::usm_ndarray_types ();
306+ auto ipiv_types = dpctl_td_ns ::usm_ndarray_types ();
303307 int ipiv_array_type_id =
304308 ipiv_types.typenum_to_lookup_id (ipiv_array.get_typenum ());
305309
306- if (ipiv_array_type_id != static_cast <int >(td_ns ::typenum_t ::INT64)) {
310+ if (ipiv_array_type_id != static_cast <int >(dpctl_td_ns ::typenum_t ::INT64)) {
307311 throw py::value_error (" The type of 'ipiv_array' must be int64" );
308312 }
309313
@@ -343,9 +347,7 @@ struct GetrsBatchContigFactory
343347
344348void init_getrs_batch_dispatch_vector (void )
345349{
346- td_ns::DispatchVectorBuilder<getrs_batch_impl_fn_ptr_t ,
347- GetrsBatchContigFactory, td_ns::num_types>
348- contig;
349- contig.populate_dispatch_vector (getrs_batch_dispatch_vector);
350+ init_dispatch_vector<getrs_batch_impl_fn_ptr_t , GetrsBatchContigFactory>(
351+ getrs_batch_dispatch_vector);
350352}
351353} // namespace dpnp::extensions::lapack
0 commit comments