@@ -44,6 +44,7 @@ namespace py = pybind11;
44
44
namespace type_utils = dpctl::tensor::type_utils;
45
45
46
46
typedef sycl::event (*getrf_impl_fn_ptr_t )(sycl::queue &,
47
+ const std::int64_t ,
47
48
const std::int64_t ,
48
49
char *,
49
50
std::int64_t ,
@@ -56,6 +57,7 @@ static getrf_impl_fn_ptr_t getrf_dispatch_vector[dpctl_td_ns::num_types];
56
57
57
58
template <typename T>
58
59
static sycl::event getrf_impl (sycl::queue &exec_q,
60
+ const std::int64_t m,
59
61
const std::int64_t n,
60
62
char *in_a,
61
63
std::int64_t lda,
@@ -69,7 +71,7 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
69
71
T *a = reinterpret_cast <T *>(in_a);
70
72
71
73
const std::int64_t scratchpad_size =
72
- mkl_lapack::getrf_scratchpad_size<T>(exec_q, n , n, lda);
74
+ mkl_lapack::getrf_scratchpad_size<T>(exec_q, m , n, lda);
73
75
T *scratchpad = nullptr ;
74
76
75
77
std::stringstream error_msg;
@@ -82,13 +84,13 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
82
84
83
85
getrf_event = mkl_lapack::getrf (
84
86
exec_q,
85
- n , // The order of the square matrix A (0 ≤ n ).
87
+ m , // The number of rows in the input matrix A (0 ≤ m ).
86
88
// It must be a non-negative integer.
87
- n, // The number of columns in the square matrix A (0 ≤ n).
89
+ n, // The number of columns in the input matrix A (0 ≤ n).
88
90
// It must be a non-negative integer.
89
- a, // Pointer to the square matrix A (n x n).
91
+ a, // Pointer to the input matrix A (m x n).
90
92
lda, // The leading dimension of matrix A.
91
- // It must be at least max(1, n ).
93
+ // It must be at least max(1, m ).
92
94
ipiv, // Pointer to the output array of pivot indices.
93
95
scratchpad, // Pointer to scratchpad memory to be used by MKL
94
96
// routine for storing intermediate results.
@@ -99,7 +101,7 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
99
101
100
102
if (info < 0 ) {
101
103
error_msg << " Parameter number " << -info
102
- << " had an illegal value. " ;
104
+ << " had an illegal value" ;
103
105
}
104
106
else if (info == scratchpad_size && e.detail () != 0 ) {
105
107
error_msg
@@ -168,13 +170,13 @@ std::pair<sycl::event, sycl::event>
168
170
if (a_array_nd != 2 ) {
169
171
throw py::value_error (
170
172
" The input array has ndim=" + std::to_string (a_array_nd) +
171
- " , but a 2-dimensional array is expected. " );
173
+ " , but a 2-dimensional array is expected" );
172
174
}
173
175
174
176
if (ipiv_array_nd != 1 ) {
175
177
throw py::value_error (" The array of pivot indices has ndim=" +
176
178
std::to_string (ipiv_array_nd) +
177
- " , but a 1-dimensional array is expected. " );
179
+ " , but a 1-dimensional array is expected" );
178
180
}
179
181
180
182
// check compatibility of execution queue and allocation queue
@@ -190,10 +192,11 @@ std::pair<sycl::event, sycl::event>
190
192
}
191
193
192
194
bool is_a_array_c_contig = a_array.is_c_contiguous ();
195
+ bool is_a_array_f_contig = a_array.is_f_contiguous ();
193
196
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous ();
194
- if (!is_a_array_c_contig) {
197
+ if (!is_a_array_c_contig && !is_a_array_f_contig ) {
195
198
throw py::value_error (" The input array "
196
- " must be C- contiguous" );
199
+ " must be contiguous" );
197
200
}
198
201
if (!is_ipiv_array_c_contig) {
199
202
throw py::value_error (" The array of pivot indices "
@@ -208,27 +211,33 @@ std::pair<sycl::event, sycl::event>
208
211
if (getrf_fn == nullptr ) {
209
212
throw py::value_error (
210
213
" No getrf implementation defined for the provided type "
211
- " of the input matrix. " );
214
+ " of the input matrix" );
212
215
}
213
216
214
217
auto ipiv_types = dpctl_td_ns::usm_ndarray_types ();
215
218
int ipiv_array_type_id =
216
219
ipiv_types.typenum_to_lookup_id (ipiv_array.get_typenum ());
217
220
218
221
if (ipiv_array_type_id != static_cast <int >(dpctl_td_ns::typenum_t ::INT64)) {
219
- throw py::value_error (" The type of 'ipiv_array' must be int64. " );
222
+ throw py::value_error (" The type of 'ipiv_array' must be int64" );
220
223
}
221
224
222
- const std::int64_t n = a_array.get_shape_raw ()[0 ];
225
+ const py::ssize_t *a_array_shape = a_array.get_shape_raw ();
226
+ const std::int64_t m = a_array_shape[0 ];
227
+ const std::int64_t n = a_array_shape[1 ];
228
+ const std::int64_t lda = std::max<size_t >(1UL , m);
229
+
230
+ if (ipiv_array.get_size () != std::min (m, n)) {
231
+ throw py::value_error (" The size of 'ipiv_array' must be min(m, n)" );
232
+ }
223
233
224
234
char *a_array_data = a_array.get_data ();
225
- const std::int64_t lda = std::max<size_t >(1UL , n);
226
235
227
236
char *ipiv_array_data = ipiv_array.get_data ();
228
237
std::int64_t *d_ipiv = reinterpret_cast <std::int64_t *>(ipiv_array_data);
229
238
230
239
std::vector<sycl::event> host_task_events;
231
- sycl::event getrf_ev = getrf_fn (exec_q, n, a_array_data, lda, d_ipiv,
240
+ sycl::event getrf_ev = getrf_fn (exec_q, m, n, a_array_data, lda, d_ipiv,
232
241
dev_info, host_task_events, depends);
233
242
234
243
sycl::event args_ev = dpctl::utils::keep_args_alive (
0 commit comments