 #include <ATen/ops/_linalg_check_errors_native.h>
 #include <ATen/ops/empty.h>
 #include <ATen/ops/from_blob.h>
+#include <ATen/ops/zeros_like.h>
 
 #include <comm/SYCLContext.h>
 #include <comm/TensorInfo.h>
@@ -40,6 +41,18 @@ static oneapi::mkl::transpose to_blas_(TransposeType trans) {
 void error_handle(int32_t* infos, const oneapi::mkl::lapack::batch_error& be) {
   auto errs = be.exceptions();
   auto ids = be.ids();
+
+  if (!errs.size()) {
+    std::cout << "Caught LAPACK exception:"
+              << "\nWhat: " << be.what() << "\nInfo: " << be.info()
+              << std::endl;
+    for (auto& i : ids) {
+      std::cout << "Error in matrix #" << i << std::endl;
+      infos[i] = 1;
+    }
+    return;
+  }
+
   for (auto& i : ids) {
     try {
       std::rethrow_exception(errs[i]);
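
For orientation, a hedged sketch of how a handler like error_handle is typically wired up: the batched oneMKL LAPACK entry point throws a single batch_error aggregating per-matrix failures, which the caller catches and forwards. The getrf_batch call, queue, and scratchpad arguments below are illustrative assumptions, not code from this commit.

    // Illustrative only: a strided batched LU factorization whose
    // per-matrix failures surface as oneapi::mkl::lapack::batch_error.
    // `queue`, `a`, `ipiv`, `scratchpad` etc. are assumed names.
    try {
      oneapi::mkl::lapack::getrf_batch(
          queue, m, n, a, lda, stride_a, ipiv, stride_ipiv,
          batch_size, scratchpad, scratchpad_size);
    } catch (const oneapi::mkl::lapack::batch_error& be) {
      error_handle(infos, be); // mark which matrices in the batch failed
    }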
@@ -529,8 +542,8 @@ void lu_factor_mkl(
       "linalg.lu_factor: LU without pivoting is not implemented on the XPU");
 
   // handle the info
-  info.zero_();
-  int32_t* infos_data = info.data_ptr<int32_t>();
+  Tensor info_ = at::zeros_like(info, Device(at::kCPU));
+  int32_t* infos_data = info_.data_ptr<int32_t>();
 
   // oneMKL requires Long for pivots but PyTorch provides Int
   Tensor pivots_ = at::empty(pivots.sizes(), pivots.options().dtype(kLong));
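
A note on the motivation, as far as the diff shows it: error_handle fills the status buffer through a raw host int32_t*, and data_ptr<int32_t>() on an XPU-resident info would hand back a device pointer that host code cannot safely write. Staging a zero-filled CPU tensor (info_) presumably gives the handler a host-writable buffer; the result is copied back into the caller's info in the next hunk.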
@@ -539,7 +552,8 @@ void lu_factor_mkl(
     apply_lu_xpu_<scalar_t>(LU, pivots_, infos_data);
   });
 
-  // Copy to original pivots tensor
+  // Copy to original info and pivots tensors
+  info.copy_(info_);
   pivots.copy_(pivots_);
 }
 
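
Downstream, the filled info tensor is normally what at::_linalg_check_errors inspects (consistent with the _linalg_check_errors_native.h include at the top of this file). A hedged caller-side sketch; the call site is assumed, not shown in this commit:

    // Hypothetical follow-up check: raises a descriptive error if any
    // entry of `info` is nonzero, i.e. some matrix in the batch failed.
    at::_linalg_check_errors(info, "linalg.lu_factor", /*is_matrix=*/false);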