@@ -22,6 +22,24 @@ oneapi::mkl::transpose convert(onemklTranspose val) {
22
22
}
23
23
}
24
24
25
+ oneapi::mkl::transpose* convert (const onemklTranspose* vals, int64_t size) {
26
+ oneapi::mkl::transpose* result = new oneapi::mkl::transpose[size];
27
+ for (int64_t i = 0 ; i < size; ++i) {
28
+ switch (vals[i]) {
29
+ case ONEMKL_TRANSPOSE_NONTRANS:
30
+ result[i] = oneapi::mkl::transpose::nontrans;
31
+ break ;
32
+ case ONEMKL_TRANSPOSE_TRANS:
33
+ result[i] = oneapi::mkl::transpose::trans;
34
+ break ;
35
+ case ONEMLK_TRANSPOSE_CONJTRANS:
36
+ result[i] = oneapi::mkl::transpose::conjtrans;
37
+ break ;
38
+ }
39
+ }
40
+ return result;
41
+ }
42
+
25
43
oneapi::mkl::uplo convert (onemklUplo val) {
26
44
switch (val) {
27
45
case ONEMKL_UPLO_UPPER:
@@ -31,6 +49,21 @@ oneapi::mkl::uplo convert(onemklUplo val) {
31
49
}
32
50
}
33
51
52
+ oneapi::mkl::uplo* convert (const onemklUplo* vals, int64_t size) {
53
+ oneapi::mkl::uplo* result = new oneapi::mkl::uplo[size];
54
+ for (int64_t i = 0 ; i < size; ++i) {
55
+ switch (vals[i]) {
56
+ case ONEMKL_UPLO_UPPER:
57
+ result[i] = oneapi::mkl::uplo::upper;
58
+ break ;
59
+ case ONEMKL_UPLO_LOWER:
60
+ result[i] = oneapi::mkl::uplo::lower;
61
+ break ;
62
+ }
63
+ }
64
+ return result;
65
+ }
66
+
34
67
oneapi::mkl::diag convert (onemklDiag val) {
35
68
switch (val) {
36
69
case ONEMKL_DIAG_NONUNIT:
@@ -40,6 +73,21 @@ oneapi::mkl::diag convert(onemklDiag val) {
40
73
}
41
74
}
42
75
76
+ oneapi::mkl::diag* convert (const onemklDiag* vals, int64_t size) {
77
+ oneapi::mkl::diag* result = new oneapi::mkl::diag[size];
78
+ for (int64_t i = 0 ; i < size; ++i) {
79
+ switch (vals[i]) {
80
+ case ONEMKL_DIAG_NONUNIT:
81
+ result[i] = oneapi::mkl::diag::nonunit;
82
+ break ;
83
+ case ONEMKL_DIAG_UNIT:
84
+ result[i] = oneapi::mkl::diag::unit;
85
+ break ;
86
+ }
87
+ }
88
+ return result;
89
+ }
90
+
43
91
oneapi::mkl::side convert (onemklSide val) {
44
92
switch (val) {
45
93
case ONEMKL_SIDE_LEFT:
@@ -49,6 +97,21 @@ oneapi::mkl::side convert(onemklSide val) {
49
97
}
50
98
}
51
99
100
+ oneapi::mkl::side* convert (const onemklSide* vals, int64_t size) {
101
+ oneapi::mkl::side* result = new oneapi::mkl::side[size];
102
+ for (int64_t i = 0 ; i < size; ++i) {
103
+ switch (vals[i]) {
104
+ case ONEMKL_SIDE_LEFT:
105
+ result[i] = oneapi::mkl::side::left;
106
+ break ;
107
+ case ONEMKL_SIDE_RIGHT:
108
+ result[i] = oneapi::mkl::side::right;
109
+ break ;
110
+ }
111
+ }
112
+ return result;
113
+ }
114
+
52
115
oneapi::mkl::offset convert (onemklOffset val) {
53
116
switch (val) {
54
117
case ONEMKL_OFFSET_ROW:
@@ -3416,6 +3479,54 @@ extern "C" int64_t onemklZunmtr_scratchpad_size(syclQueue_t device_queue, onemkl
3416
3479
return scratchpad_size;
3417
3480
}
3418
3481
3482
+ extern " C" int onemklSpotrf_batch (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, float **a, int64_t *lda, int64_t group_count, int64_t *group_sizes, float *scratchpad, int64_t scratchpad_size) {
3483
+ auto status = oneapi::mkl::lapack::potrf_batch (device_queue->val , convert (uplo, group_count), n, a, lda, group_count, group_sizes, scratchpad, scratchpad_size, {});
3484
+ __FORCE_MKL_FLUSH__ (status);
3485
+ return 0 ;
3486
+ }
3487
+
3488
+ extern " C" int onemklDpotrf_batch (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, double **a, int64_t *lda, int64_t group_count, int64_t *group_sizes, double *scratchpad, int64_t scratchpad_size) {
3489
+ auto status = oneapi::mkl::lapack::potrf_batch (device_queue->val , convert (uplo, group_count), n, a, lda, group_count, group_sizes, scratchpad, scratchpad_size, {});
3490
+ __FORCE_MKL_FLUSH__ (status);
3491
+ return 0 ;
3492
+ }
3493
+
3494
+ extern " C" int onemklCpotrf_batch (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, float _Complex **a, int64_t *lda, int64_t group_count, int64_t *group_sizes, float _Complex *scratchpad, int64_t scratchpad_size) {
3495
+ auto status = oneapi::mkl::lapack::potrf_batch (device_queue->val , convert (uplo, group_count), n, reinterpret_cast <std::complex<float >**>(a), lda, group_count, group_sizes, reinterpret_cast <std::complex<float >*>(scratchpad), scratchpad_size, {});
3496
+ __FORCE_MKL_FLUSH__ (status);
3497
+ return 0 ;
3498
+ }
3499
+
3500
+ extern " C" int onemklZpotrf_batch (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, double _Complex **a, int64_t *lda, int64_t group_count, int64_t *group_sizes, double _Complex *scratchpad, int64_t scratchpad_size) {
3501
+ auto status = oneapi::mkl::lapack::potrf_batch (device_queue->val , convert (uplo, group_count), n, reinterpret_cast <std::complex<double >**>(a), lda, group_count, group_sizes, reinterpret_cast <std::complex<double >*>(scratchpad), scratchpad_size, {});
3502
+ __FORCE_MKL_FLUSH__ (status);
3503
+ return 0 ;
3504
+ }
3505
+
3506
+ extern " C" int onemklSpotrs_batch (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *nrhs, float **a, int64_t *lda, float **b, int64_t *ldb, int64_t group_count, int64_t *group_sizes, float *scratchpad, int64_t scratchpad_size) {
3507
+ auto status = oneapi::mkl::lapack::potrs_batch (device_queue->val , convert (uplo, group_count), n, nrhs, a, lda, b, ldb, group_count, group_sizes, scratchpad, scratchpad_size, {});
3508
+ __FORCE_MKL_FLUSH__ (status);
3509
+ return 0 ;
3510
+ }
3511
+
3512
+ extern " C" int onemklDpotrs_batch (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *nrhs, double **a, int64_t *lda, double **b, int64_t *ldb, int64_t group_count, int64_t *group_sizes, double *scratchpad, int64_t scratchpad_size) {
3513
+ auto status = oneapi::mkl::lapack::potrs_batch (device_queue->val , convert (uplo, group_count), n, nrhs, a, lda, b, ldb, group_count, group_sizes, scratchpad, scratchpad_size, {});
3514
+ __FORCE_MKL_FLUSH__ (status);
3515
+ return 0 ;
3516
+ }
3517
+
3518
+ extern " C" int onemklCpotrs_batch (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *nrhs, float _Complex **a, int64_t *lda, float _Complex **b, int64_t *ldb, int64_t group_count, int64_t *group_sizes, float _Complex *scratchpad, int64_t scratchpad_size) {
3519
+ auto status = oneapi::mkl::lapack::potrs_batch (device_queue->val , convert (uplo, group_count), n, nrhs, reinterpret_cast <std::complex<float >**>(a), lda, reinterpret_cast <std::complex<float >**>(b), ldb, group_count, group_sizes, reinterpret_cast <std::complex<float >*>(scratchpad), scratchpad_size, {});
3520
+ __FORCE_MKL_FLUSH__ (status);
3521
+ return 0 ;
3522
+ }
3523
+
3524
+ extern " C" int onemklZpotrs_batch (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *nrhs, double _Complex **a, int64_t *lda, double _Complex **b, int64_t *ldb, int64_t group_count, int64_t *group_sizes, double _Complex *scratchpad, int64_t scratchpad_size) {
3525
+ auto status = oneapi::mkl::lapack::potrs_batch (device_queue->val , convert (uplo, group_count), n, nrhs, reinterpret_cast <std::complex<double >**>(a), lda, reinterpret_cast <std::complex<double >**>(b), ldb, group_count, group_sizes, reinterpret_cast <std::complex<double >*>(scratchpad), scratchpad_size, {});
3526
+ __FORCE_MKL_FLUSH__ (status);
3527
+ return 0 ;
3528
+ }
3529
+
3419
3530
extern " C" int onemklSgeinv_batch (syclQueue_t device_queue, int64_t *n, float **a, int64_t *lda, int64_t group_count, int64_t *group_sizes, float *scratchpad, int64_t scratchpad_size) {
3420
3531
auto status = oneapi::mkl::lapack::geinv_batch (device_queue->val , n, a, lda, group_count, group_sizes, scratchpad, scratchpad_size, {});
3421
3532
__FORCE_MKL_FLUSH__ (status);
@@ -3440,6 +3551,30 @@ extern "C" int onemklZgeinv_batch(syclQueue_t device_queue, int64_t *n, double _
3440
3551
return 0 ;
3441
3552
}
3442
3553
3554
+ extern " C" int onemklSgetrs_batch (syclQueue_t device_queue, onemklTranspose *trans, int64_t *n, int64_t *nrhs, float **a, int64_t *lda, int64_t **ipiv, float **b, int64_t *ldb, int64_t group_count, int64_t *group_sizes, float *scratchpad, int64_t scratchpad_size) {
3555
+ auto status = oneapi::mkl::lapack::getrs_batch (device_queue->val , convert (trans, group_count), n, nrhs, a, lda, ipiv, b, ldb, group_count, group_sizes, scratchpad, scratchpad_size, {});
3556
+ __FORCE_MKL_FLUSH__ (status);
3557
+ return 0 ;
3558
+ }
3559
+
3560
+ extern " C" int onemklDgetrs_batch (syclQueue_t device_queue, onemklTranspose *trans, int64_t *n, int64_t *nrhs, double **a, int64_t *lda, int64_t **ipiv, double **b, int64_t *ldb, int64_t group_count, int64_t *group_sizes, double *scratchpad, int64_t scratchpad_size) {
3561
+ auto status = oneapi::mkl::lapack::getrs_batch (device_queue->val , convert (trans, group_count), n, nrhs, a, lda, ipiv, b, ldb, group_count, group_sizes, scratchpad, scratchpad_size, {});
3562
+ __FORCE_MKL_FLUSH__ (status);
3563
+ return 0 ;
3564
+ }
3565
+
3566
+ extern " C" int onemklCgetrs_batch (syclQueue_t device_queue, onemklTranspose *trans, int64_t *n, int64_t *nrhs, float _Complex **a, int64_t *lda, int64_t **ipiv, float _Complex **b, int64_t *ldb, int64_t group_count, int64_t *group_sizes, float _Complex *scratchpad, int64_t scratchpad_size) {
3567
+ auto status = oneapi::mkl::lapack::getrs_batch (device_queue->val , convert (trans, group_count), n, nrhs, reinterpret_cast <std::complex<float >**>(a), lda, ipiv, reinterpret_cast <std::complex<float >**>(b), ldb, group_count, group_sizes, reinterpret_cast <std::complex<float >*>(scratchpad), scratchpad_size, {});
3568
+ __FORCE_MKL_FLUSH__ (status);
3569
+ return 0 ;
3570
+ }
3571
+
3572
+ extern " C" int onemklZgetrs_batch (syclQueue_t device_queue, onemklTranspose *trans, int64_t *n, int64_t *nrhs, double _Complex **a, int64_t *lda, int64_t **ipiv, double _Complex **b, int64_t *ldb, int64_t group_count, int64_t *group_sizes, double _Complex *scratchpad, int64_t scratchpad_size) {
3573
+ auto status = oneapi::mkl::lapack::getrs_batch (device_queue->val , convert (trans, group_count), n, nrhs, reinterpret_cast <std::complex<double >**>(a), lda, ipiv, reinterpret_cast <std::complex<double >**>(b), ldb, group_count, group_sizes, reinterpret_cast <std::complex<double >*>(scratchpad), scratchpad_size, {});
3574
+ __FORCE_MKL_FLUSH__ (status);
3575
+ return 0 ;
3576
+ }
3577
+
3443
3578
extern " C" int onemklSgetri_batch (syclQueue_t device_queue, int64_t *n, float **a, int64_t *lda, int64_t **ipiv, int64_t group_count, int64_t *group_sizes, float *scratchpad, int64_t scratchpad_size) {
3444
3579
auto status = oneapi::mkl::lapack::getri_batch (device_queue->val , n, a, lda, ipiv, group_count, group_sizes, scratchpad, scratchpad_size, {});
3445
3580
__FORCE_MKL_FLUSH__ (status);
@@ -3512,6 +3647,46 @@ extern "C" int onemklZungqr_batch(syclQueue_t device_queue, int64_t *m, int64_t
3512
3647
return 0 ;
3513
3648
}
3514
3649
3650
+ extern " C" int64_t onemklSpotrf_batch_scratchpad_size (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *lda, int64_t group_count, int64_t *group_sizes) {
3651
+ int64_t scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size<float >(device_queue->val , convert (uplo, group_count), n, lda, group_count, group_sizes);
3652
+ return scratchpad_size;
3653
+ }
3654
+
3655
+ extern " C" int64_t onemklDpotrf_batch_scratchpad_size (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *lda, int64_t group_count, int64_t *group_sizes) {
3656
+ int64_t scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size<double >(device_queue->val , convert (uplo, group_count), n, lda, group_count, group_sizes);
3657
+ return scratchpad_size;
3658
+ }
3659
+
3660
+ extern " C" int64_t onemklCpotrf_batch_scratchpad_size (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *lda, int64_t group_count, int64_t *group_sizes) {
3661
+ int64_t scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size<std::complex<float >>(device_queue->val , convert (uplo, group_count), n, lda, group_count, group_sizes);
3662
+ return scratchpad_size;
3663
+ }
3664
+
3665
+ extern " C" int64_t onemklZpotrf_batch_scratchpad_size (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *lda, int64_t group_count, int64_t *group_sizes) {
3666
+ int64_t scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size<std::complex<double >>(device_queue->val , convert (uplo, group_count), n, lda, group_count, group_sizes);
3667
+ return scratchpad_size;
3668
+ }
3669
+
3670
+ extern " C" int64_t onemklSpotrs_batch_scratchpad_size (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *nrhs, int64_t *lda, int64_t *ldb, int64_t group_count, int64_t *group_sizes) {
3671
+ int64_t scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size<float >(device_queue->val , convert (uplo, group_count), n, nrhs, lda, ldb, group_count, group_sizes);
3672
+ return scratchpad_size;
3673
+ }
3674
+
3675
+ extern " C" int64_t onemklDpotrs_batch_scratchpad_size (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *nrhs, int64_t *lda, int64_t *ldb, int64_t group_count, int64_t *group_sizes) {
3676
+ int64_t scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size<double >(device_queue->val , convert (uplo, group_count), n, nrhs, lda, ldb, group_count, group_sizes);
3677
+ return scratchpad_size;
3678
+ }
3679
+
3680
+ extern " C" int64_t onemklCpotrs_batch_scratchpad_size (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *nrhs, int64_t *lda, int64_t *ldb, int64_t group_count, int64_t *group_sizes) {
3681
+ int64_t scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size<std::complex<float >>(device_queue->val , convert (uplo, group_count), n, nrhs, lda, ldb, group_count, group_sizes);
3682
+ return scratchpad_size;
3683
+ }
3684
+
3685
+ extern " C" int64_t onemklZpotrs_batch_scratchpad_size (syclQueue_t device_queue, onemklUplo *uplo, int64_t *n, int64_t *nrhs, int64_t *lda, int64_t *ldb, int64_t group_count, int64_t *group_sizes) {
3686
+ int64_t scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size<std::complex<double >>(device_queue->val , convert (uplo, group_count), n, nrhs, lda, ldb, group_count, group_sizes);
3687
+ return scratchpad_size;
3688
+ }
3689
+
3515
3690
extern " C" int64_t onemklSgeinv_batch_scratchpad_size (syclQueue_t device_queue, int64_t *n, int64_t *lda, int64_t group_count, int64_t *group_sizes) {
3516
3691
int64_t scratchpad_size = oneapi::mkl::lapack::geinv_batch_scratchpad_size<float >(device_queue->val , n, lda, group_count, group_sizes);
3517
3692
return scratchpad_size;
@@ -3532,6 +3707,26 @@ extern "C" int64_t onemklZgeinv_batch_scratchpad_size(syclQueue_t device_queue,
3532
3707
return scratchpad_size;
3533
3708
}
3534
3709
3710
+ extern " C" int64_t onemklSgetrs_batch_scratchpad_size (syclQueue_t device_queue, onemklTranspose *trans, int64_t *n, int64_t *nrhs, int64_t *lda, int64_t *ldb, int64_t group_count, int64_t *group_sizes) {
3711
+ int64_t scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size<float >(device_queue->val , convert (trans, group_count), n, nrhs, lda, ldb, group_count, group_sizes);
3712
+ return scratchpad_size;
3713
+ }
3714
+
3715
+ extern " C" int64_t onemklDgetrs_batch_scratchpad_size (syclQueue_t device_queue, onemklTranspose *trans, int64_t *n, int64_t *nrhs, int64_t *lda, int64_t *ldb, int64_t group_count, int64_t *group_sizes) {
3716
+ int64_t scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size<double >(device_queue->val , convert (trans, group_count), n, nrhs, lda, ldb, group_count, group_sizes);
3717
+ return scratchpad_size;
3718
+ }
3719
+
3720
+ extern " C" int64_t onemklCgetrs_batch_scratchpad_size (syclQueue_t device_queue, onemklTranspose *trans, int64_t *n, int64_t *nrhs, int64_t *lda, int64_t *ldb, int64_t group_count, int64_t *group_sizes) {
3721
+ int64_t scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size<std::complex<float >>(device_queue->val , convert (trans, group_count), n, nrhs, lda, ldb, group_count, group_sizes);
3722
+ return scratchpad_size;
3723
+ }
3724
+
3725
+ extern " C" int64_t onemklZgetrs_batch_scratchpad_size (syclQueue_t device_queue, onemklTranspose *trans, int64_t *n, int64_t *nrhs, int64_t *lda, int64_t *ldb, int64_t group_count, int64_t *group_sizes) {
3726
+ int64_t scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size<std::complex<double >>(device_queue->val , convert (trans, group_count), n, nrhs, lda, ldb, group_count, group_sizes);
3727
+ return scratchpad_size;
3728
+ }
3729
+
3535
3730
extern " C" int64_t onemklSgetri_batch_scratchpad_size (syclQueue_t device_queue, int64_t *n, int64_t *lda, int64_t group_count, int64_t *group_sizes) {
3536
3731
int64_t scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size<float >(device_queue->val , n, lda, group_count, group_sizes);
3537
3732
return scratchpad_size;
0 commit comments