@@ -11,6 +11,7 @@
 #include <ATen/ops/_linalg_check_errors_native.h>
 #include <ATen/ops/empty.h>
 #include <ATen/ops/from_blob.h>
+#include <ATen/ops/zeros_like.h>
 
 #include <comm/SYCLContext.h>
 #include <comm/TensorInfo.h>
@@ -37,21 +38,37 @@ static oneapi::mkl::transpose to_blas_(TransposeType trans) {
   TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
 }
 
-void error_handle(int32_t* infos, const oneapi::mkl::lapack::batch_error& be) {
+void error_handle(
+    int32_t* info_cpu,
+    const oneapi::mkl::lapack::batch_error& be) {
   auto errs = be.exceptions();
   auto ids = be.ids();
+
+  if (!errs.size()) {
+    TORCH_WARN(
+        "Caught lapack exception:\nWhat: ", be.what(), "\nInfo: ", be.info());
+    for (auto& i : ids) {
+      TORCH_WARN("Error in matrix #", i);
+      info_cpu[i] = 1;
+    }
+    return;
+  }
+
   for (auto& i : ids) {
     try {
       std::rethrow_exception(errs[i]);
     } catch (const oneapi::mkl::lapack::exception& e) {
-      std::cout << "Cathed lapack exception:"
-                << "\nWhat: " << e.what() << "\nInfo: " << e.info()
-                << "\nDetail: " << e.detail() << std::endl;
-      infos[i] = e.info();
+      TORCH_WARN(
+          "Caught lapack exception:\nWhat: ",
+          e.what(),
+          "\nInfo: ",
+          e.info(),
+          "\nDetail: ",
+          e.detail());
+      info_cpu[i] = e.info();
     } catch (const sycl::exception& e) {
-      std::cout << "Catched SYCL exception:"
-                << "\nWhat: " << e.what() << "\nInfo: -1" << std::endl;
-      infos[i] = -1;
+      TORCH_WARN("Caught SYCL exception:\nWhat: ", e.what(), "\nInfo: -1");
+      info_cpu[i] = -1;
     }
   }
 }
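
For reference, the new early-return branch can be read in isolation: when the batch error carries no per-matrix exceptions, every reported matrix id is flagged with a generic nonzero info value; otherwise each stored exception is rethrown and its LAPACK info code is recorded. The sketch below mirrors that flow with a hypothetical BatchError stand-in (oneMKL and ATen headers are not assumed), so it illustrates the pattern rather than reproducing this file's code.

    #include <cstdint>
    #include <exception>
    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for oneapi::mkl::lapack::batch_error.
    struct BatchError {
      std::vector<std::exception_ptr> errs; // per-matrix exceptions (may be empty)
      std::vector<int64_t> idx;             // indices of the failing matrices
      const std::vector<std::exception_ptr>& exceptions() const { return errs; }
      const std::vector<int64_t>& ids() const { return idx; }
    };

    // Mirrors the handler above: write one status code per failing matrix
    // into a host-side buffer.
    void handle(int32_t* info_cpu, const BatchError& be) {
      const auto& errs = be.exceptions();
      const auto& ids = be.ids();
      if (errs.empty()) {
        // No per-matrix exceptions attached: flag every reported matrix generically.
        for (auto i : ids)
          info_cpu[i] = 1;
        return;
      }
      for (auto i : ids) {
        try {
          std::rethrow_exception(errs[i]); // exceptions indexed by matrix id, as above
        } catch (const std::exception& e) {
          std::cerr << "matrix #" << i << ": " << e.what() << '\n';
          info_cpu[i] = -1; // record a nonzero info value for this matrix
        }
      }
    }
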
@@ -372,7 +389,7 @@ template <typename scalar_t>
 static void apply_lu_xpu_(
     const Tensor& self_,
     Tensor& pivots_,
-    int32_t* infos_) {
+    int32_t* info_data) {
   // do nothing if empty input.
   if (self_.numel() == 0)
     return;
@@ -403,7 +420,7 @@ static void apply_lu_xpu_(
         (scalar_t*)(scratchpad_at.data_ptr()),
         scratchpadsize);
   } catch (const oneapi::mkl::lapack::batch_error& be) {
-    error_handle(infos_, be);
+    error_handle(info_data, be);
   }
 }
 
@@ -436,8 +453,8 @@ static void apply_lu_solve_xpu_(
   int64_t* ipiv = pivots.data_ptr<int64_t>();
   scalar_t* b = b_.data_ptr<scalar_t>();
 
-  std::vector<int32_t> infos(batch_size, 0);
-  int32_t* infos_ = infos.data();
+  std::vector<int32_t> info_vec(batch_size, 0);
+  int32_t* info_data = info_vec.data();
 
   auto execute_mkl_getrs =
       [&](scalar_t* a, scalar_t* b, int64_t* ipiv, int64_t batch_size) {
@@ -471,7 +488,7 @@ static void apply_lu_solve_xpu_(
            scratchpad_at.data_ptr<scalar_t>(),
            scratchpad_size);
        } catch (oneapi::mkl::lapack::batch_error be) {
-         error_handle(infos_, be);
+         error_handle(info_data, be);
        }
      };
 
@@ -529,17 +546,18 @@ void lu_factor_mkl(
       "linalg.lu_factor: LU without pivoting is not implemented on the XPU");
 
   // handle the info
-  info.zero_();
-  int32_t* infos_data = info.data_ptr<int32_t>();
+  Tensor info_ = at::zeros_like(info, Device(at::kCPU));
+  int32_t* info_data = info_.data_ptr<int32_t>();
 
   // oneMKL requires Long for pivots but PyTorch provides Int
   Tensor pivots_ = at::empty(pivots.sizes(), pivots.options().dtype(kLong));
 
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_xpu", [&] {
-    apply_lu_xpu_<scalar_t>(LU, pivots_, infos_data);
+    apply_lu_xpu_<scalar_t>(LU, pivots_, info_data);
   });
 
-  // Copy to original pivots tensor
+  // Copy to original info and pivots tensor
+  info.copy_(info_);
   pivots.copy_(pivots_);
 }
 
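
The info handling above follows a host-staging pattern: the status tensor is allocated as zeros on the CPU, the batch-error handler writes per-matrix codes into it through a raw pointer, and the results are copied back into the caller's info tensor (which may live on the XPU) at the end. Below is a minimal sketch of that pattern written against the public LibTorch C++ API rather than this file's internal ATen calls; the function name factor_with_host_info and its body are illustrative only.

    #include <torch/torch.h>

    // Illustrative only: stage `info` on the CPU, let host code fill it, then
    // copy back to wherever the caller's tensor lives. Assumes `info` is an
    // int32 tensor, as in linalg.lu_factor.
    void factor_with_host_info(const torch::Tensor& LU, torch::Tensor& info) {
      torch::Tensor info_cpu =
          torch::zeros_like(info, torch::TensorOptions().device(torch::kCPU));
      int32_t* info_data = info_cpu.data_ptr<int32_t>();

      // ... run the factorization of LU here; on failure, a host-side error
      // handler records per-matrix status codes into info_data[i] ...
      (void)LU;
      (void)info_data;

      // copy_() handles the CPU -> device transfer if `info` is an XPU tensor.
      info.copy_(info_cpu);
    }

Staging on the CPU keeps the raw-pointer writes in the exception handler valid regardless of where the user's info tensor resides, which is presumably why the patch replaces info.zero_() with a CPU-side tensor and adds the zeros_like include at the top of the file.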