@@ -44,6 +44,7 @@ namespace py = pybind11;
4444namespace type_utils = dpctl::tensor::type_utils;
4545
4646typedef sycl::event (*getrf_impl_fn_ptr_t )(sycl::queue &,
47+ const std::int64_t ,
4748 const std::int64_t ,
4849 char *,
4950 std::int64_t ,
@@ -56,6 +57,7 @@ static getrf_impl_fn_ptr_t getrf_dispatch_vector[dpctl_td_ns::num_types];
5657
5758template <typename T>
5859static sycl::event getrf_impl (sycl::queue &exec_q,
60+ const std::int64_t m,
5961 const std::int64_t n,
6062 char *in_a,
6163 std::int64_t lda,
@@ -69,7 +71,7 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
6971 T *a = reinterpret_cast <T *>(in_a);
7072
7173 const std::int64_t scratchpad_size =
72- mkl_lapack::getrf_scratchpad_size<T>(exec_q, n , n, lda);
74+ mkl_lapack::getrf_scratchpad_size<T>(exec_q, m , n, lda);
7375 T *scratchpad = nullptr ;
7476
7577 std::stringstream error_msg;
@@ -82,13 +84,13 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
8284
8385 getrf_event = mkl_lapack::getrf (
8486 exec_q,
85- n , // The order of the square matrix A (0 ≤ n ).
87+ m , // The number of rows in the input matrix A (0 ≤ m ).
8688 // It must be a non-negative integer.
87- n, // The number of columns in the square matrix A (0 ≤ n).
89+ n, // The number of columns in the input matrix A (0 ≤ n).
8890 // It must be a non-negative integer.
89- a, // Pointer to the square matrix A (n x n).
91+ a, // Pointer to the input matrix A (m x n).
9092 lda, // The leading dimension of matrix A.
91- // It must be at least max(1, n ).
93+ // It must be at least max(1, m ).
9294 ipiv, // Pointer to the output array of pivot indices.
9395 scratchpad, // Pointer to scratchpad memory to be used by MKL
9496 // routine for storing intermediate results.
@@ -99,7 +101,7 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
99101
100102 if (info < 0 ) {
101103 error_msg << " Parameter number " << -info
102- << " had an illegal value. " ;
104+ << " had an illegal value" ;
103105 }
104106 else if (info == scratchpad_size && e.detail () != 0 ) {
105107 error_msg
@@ -168,13 +170,13 @@ std::pair<sycl::event, sycl::event>
168170 if (a_array_nd != 2 ) {
169171 throw py::value_error (
170172 " The input array has ndim=" + std::to_string (a_array_nd) +
171- " , but a 2-dimensional array is expected. " );
173+ " , but a 2-dimensional array is expected" );
172174 }
173175
174176 if (ipiv_array_nd != 1 ) {
175177 throw py::value_error (" The array of pivot indices has ndim=" +
176178 std::to_string (ipiv_array_nd) +
177- " , but a 1-dimensional array is expected. " );
179+ " , but a 1-dimensional array is expected" );
178180 }
179181
180182 // check compatibility of execution queue and allocation queue
@@ -190,10 +192,11 @@ std::pair<sycl::event, sycl::event>
190192 }
191193
192194 bool is_a_array_c_contig = a_array.is_c_contiguous ();
195+ bool is_a_array_f_contig = a_array.is_f_contiguous ();
193196 bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous ();
194- if (!is_a_array_c_contig) {
197+ if (!is_a_array_c_contig && !is_a_array_f_contig ) {
195198 throw py::value_error (" The input array "
196- " must be C- contiguous" );
199+ " must be contiguous" );
197200 }
198201 if (!is_ipiv_array_c_contig) {
199202 throw py::value_error (" The array of pivot indices "
@@ -208,27 +211,33 @@ std::pair<sycl::event, sycl::event>
208211 if (getrf_fn == nullptr ) {
209212 throw py::value_error (
210213 " No getrf implementation defined for the provided type "
211- " of the input matrix. " );
214+ " of the input matrix" );
212215 }
213216
214217 auto ipiv_types = dpctl_td_ns::usm_ndarray_types ();
215218 int ipiv_array_type_id =
216219 ipiv_types.typenum_to_lookup_id (ipiv_array.get_typenum ());
217220
218221 if (ipiv_array_type_id != static_cast <int >(dpctl_td_ns::typenum_t ::INT64)) {
219- throw py::value_error (" The type of 'ipiv_array' must be int64. " );
222+ throw py::value_error (" The type of 'ipiv_array' must be int64" );
220223 }
221224
222- const std::int64_t n = a_array.get_shape_raw ()[0 ];
225+ const py::ssize_t *a_array_shape = a_array.get_shape_raw ();
226+ const std::int64_t m = a_array_shape[0 ];
227+ const std::int64_t n = a_array_shape[1 ];
228+ const std::int64_t lda = std::max<size_t >(1UL , m);
229+
230+ if (ipiv_array.get_size () != std::min (m, n)) {
231+ throw py::value_error (" The size of 'ipiv_array' must be min(m, n)" );
232+ }
223233
224234 char *a_array_data = a_array.get_data ();
225- const std::int64_t lda = std::max<size_t >(1UL , n);
226235
227236 char *ipiv_array_data = ipiv_array.get_data ();
228237 std::int64_t *d_ipiv = reinterpret_cast <std::int64_t *>(ipiv_array_data);
229238
230239 std::vector<sycl::event> host_task_events;
231- sycl::event getrf_ev = getrf_fn (exec_q, n, a_array_data, lda, d_ipiv,
240+ sycl::event getrf_ev = getrf_fn (exec_q, m, n, a_array_data, lda, d_ipiv,
232241 dev_info, host_task_events, depends);
233242
234243 sycl::event args_ev = dpctl::utils::keep_args_alive (
0 commit comments