@@ -46,6 +46,7 @@ namespace type_utils = dpctl::tensor::type_utils;
 typedef sycl::event (*getrf_batch_impl_fn_ptr_t)(
     sycl::queue &,
     std::int64_t,
+    std::int64_t,
     char *,
     std::int64_t,
     std::int64_t,
@@ -61,6 +62,7 @@ static getrf_batch_impl_fn_ptr_t
 
 template <typename T>
 static sycl::event getrf_batch_impl(sycl::queue &exec_q,
+                                    std::int64_t m,
                                     std::int64_t n,
                                     char *in_a,
                                     std::int64_t lda,
@@ -77,7 +79,7 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q,
     T *a = reinterpret_cast<T *>(in_a);
 
     const std::int64_t scratchpad_size =
-        mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, n, n, lda, stride_a,
+        mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, m, n, lda, stride_a,
                                                    stride_ipiv, batch_size);
     T *scratchpad = nullptr;
 
@@ -91,11 +93,11 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q,
 
         getrf_batch_event = mkl_lapack::getrf_batch(
             exec_q,
-            n, // The order of each square matrix in the batch; (0 ≤ n).
+            m, // The number of rows in each matrix in the batch; (0 ≤ m).
               // It must be a non-negative integer.
             n, // The number of columns in each matrix in the batch; (0 ≤ n).
               // It must be a non-negative integer.
-            a, // Pointer to the batch of square matrices, each of size (n x n).
+            a, // Pointer to the batch of input matrices, each of size (m x n).
             lda, // The leading dimension of each matrix in the batch.
             stride_a, // Stride between consecutive matrices in the batch.
             ipiv, // Pointer to the array of pivot indices for each matrix in
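For reference, here is a minimal standalone sketch (not part of this patch) of how the oneMKL strided getrf_batch interface is driven once m and n are passed independently. It assumes double-precision data, column-major m x n matrices packed contiguously with lda == m, USM device allocations, and a hypothetical helper name lu_factor_batch_sketch.

// Standalone sketch, not part of the patch; names and layout assumptions are
// illustrative only.
#include <algorithm>
#include <cstdint>

#include <oneapi/mkl.hpp>
#include <sycl/sycl.hpp>

namespace mkl_lapack = oneapi::mkl::lapack;

sycl::event lu_factor_batch_sketch(sycl::queue &q,
                                   std::int64_t m,
                                   std::int64_t n,
                                   double *a,          // device USM, batch of m x n matrices
                                   std::int64_t *ipiv, // device USM, batch of min(m, n) pivots
                                   std::int64_t batch_size)
{
    const std::int64_t lda = std::max<std::int64_t>(1, m);
    const std::int64_t stride_a = lda * n;           // elements between consecutive matrices
    const std::int64_t stride_ipiv = std::min(m, n); // pivot entries per matrix

    // The scratchpad query takes m and n separately, matching the call below.
    const std::int64_t scratchpad_size =
        mkl_lapack::getrf_batch_scratchpad_size<double>(
            q, m, n, lda, stride_a, stride_ipiv, batch_size);
    double *scratchpad = sycl::malloc_device<double>(scratchpad_size, q);

    // Rectangular factorization: m rows and n columns per matrix in the batch.
    sycl::event ev =
        mkl_lapack::getrf_batch(q, m, n, a, lda, stride_a, ipiv, stride_ipiv,
                                batch_size, scratchpad, scratchpad_size);

    // Release the scratchpad once the factorization event completes.
    sycl::context ctx = q.get_context();
    q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(ev);
        cgh.host_task([scratchpad, ctx]() { sycl::free(scratchpad, ctx); });
    });
    return ev;
}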
@@ -179,6 +181,7 @@ std::pair<sycl::event, sycl::event>
                 const dpctl::tensor::usm_ndarray &a_array,
                 const dpctl::tensor::usm_ndarray &ipiv_array,
                 py::list dev_info,
+                std::int64_t m,
                 std::int64_t n,
                 std::int64_t stride_a,
                 std::int64_t stride_ipiv,
@@ -191,21 +194,21 @@ std::pair<sycl::event, sycl::event>
     if (a_array_nd < 3) {
         throw py::value_error(
             "The input array has ndim=" + std::to_string(a_array_nd) +
-            ", but an array with ndim >= 3 is expected.");
+            ", but an array with ndim >= 3 is expected");
     }
 
     if (ipiv_array_nd != 2) {
         throw py::value_error("The array of pivot indices has ndim=" +
                               std::to_string(ipiv_array_nd) +
-                              ", but a 2-dimensional array is expected.");
+                              ", but a 2-dimensional array is expected");
     }
 
     const int dev_info_size = py::len(dev_info);
     if (dev_info_size != batch_size) {
         throw py::value_error("The size of 'dev_info' (" +
                               std::to_string(dev_info_size) +
                               ") does not match the expected batch size (" +
-                              std::to_string(batch_size) + ").");
+                              std::to_string(batch_size) + ")");
     }
 
     // check compatibility of execution queue and allocation queue
@@ -221,10 +224,11 @@ std::pair<sycl::event, sycl::event>
     }
 
     bool is_a_array_c_contig = a_array.is_c_contiguous();
+    bool is_a_array_f_contig = a_array.is_f_contiguous();
     bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
-    if (!is_a_array_c_contig) {
+    if (!is_a_array_c_contig && !is_a_array_f_contig) {
         throw py::value_error("The input array "
-                              "must be C-contiguous");
+                              "must be contiguous");
     }
     if (!is_ipiv_array_c_contig) {
         throw py::value_error("The array of pivot indices "
@@ -240,27 +244,34 @@ std::pair<sycl::event, sycl::event>
     if (getrf_batch_fn == nullptr) {
         throw py::value_error(
             "No getrf_batch implementation defined for the provided type "
-            "of the input matrix.");
+            "of the input matrix");
     }
 
     auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
     int ipiv_array_type_id =
         ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
 
     if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
-        throw py::value_error("The type of 'ipiv_array' must be int64.");
+        throw py::value_error("The type of 'ipiv_array' must be int64");
+    }
+
+    const py::ssize_t *ipiv_array_shape = ipiv_array.get_shape_raw();
+    if (ipiv_array_shape[0] != batch_size ||
+        ipiv_array_shape[1] != std::min(m, n)) {
+        throw py::value_error(
+            "The shape of 'ipiv_array' must be (batch_size, min(m, n))");
     }
 
     char *a_array_data = a_array.get_data();
-    const std::int64_t lda = std::max<size_t>(1UL, n);
+    const std::int64_t lda = std::max<size_t>(1UL, m);
 
     char *ipiv_array_data = ipiv_array.get_data();
    std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);
 
     std::vector<sycl::event> host_task_events;
     sycl::event getrf_batch_ev = getrf_batch_fn(
-        exec_q, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv, batch_size,
-        dev_info, host_task_events, depends);
+        exec_q, m, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv,
+        batch_size, dev_info, host_task_events, depends);
 
     sycl::event args_ev = dpctl::utils::keep_args_alive(
         exec_q, {a_array, ipiv_array}, host_task_events);
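As a quick illustration of the host-side bookkeeping added above, the following small standalone sketch (not part of the patch) shows how the expected pivot-array shape, leading dimension, and strides relate to batch_size, m, and n; the concrete numbers are hypothetical.

// Standalone sketch, not part of the patch; values are illustrative only.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    const std::int64_t batch_size = 4, m = 5, n = 3; // hypothetical batch of 5 x 3 matrices

    const std::int64_t lda = std::max<std::int64_t>(1, m); // leading dimension follows the row count
    const std::int64_t ipiv_cols = std::min(m, n);         // pivots recorded per matrix
    const std::int64_t stride_a = lda * n;                 // elements between consecutive matrices
    const std::int64_t stride_ipiv = ipiv_cols;            // pivot entries between consecutive matrices

    // The wrapper above requires ipiv_array to have shape (batch_size, min(m, n)).
    std::printf("ipiv shape (%lld, %lld), lda %lld, stride_a %lld, stride_ipiv %lld\n",
                static_cast<long long>(batch_size), static_cast<long long>(ipiv_cols),
                static_cast<long long>(lda), static_cast<long long>(stride_a),
                static_cast<long long>(stride_ipiv));
    return 0;
}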