1+ // ===================
2+ // Author: Peize Lin
3+ // date: 2022.12.25
4+ // ===================
5+
6+ #pragma once
7+
8+ #include " Lapack-Fortran.h"
9+
10+ #include < string>
11+ #include < stdexcept>
12+
13+
14+ #ifdef __MKL_RI
15+ #include < mkl_trans.h>
16+ #endif
17+
18+ #define LAPACK_INFO_CHECK (x ) if (const int info=(x)) throw std::runtime_error (" info=" +std::to_string(info)+".\n"+std::string(__FILE__)+" line "+std::to_string(__LINE__));
19+
20+ namespace RI
21+ {
22+
23+ namespace Lapack_Interface
24+ {
25+ // potrf computes the Cholesky factorization of a real symmetric positive definite matrix
26+ inline int potrf ( const char &uplo, const int &n, float *const A, const int &lda )
27+ {
28+ int info;
29+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
30+ spotrf_ ( &uplo_changed, &n, A, &lda, &info );
31+ return info;
32+ }
33+ inline int potrf ( const char &uplo, const int &n, double *const A, const int &lda )
34+ {
35+ int info;
36+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
37+ dpotrf_ ( &uplo_changed, &n, A, &lda, &info );
38+ return info;
39+ }
40+ inline int potrf ( const char &uplo, const int &n, std::complex <float >*const A, const int &lda )
41+ {
42+ int info;
43+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
44+ cpotrf_ ( &uplo_changed, &n, A, &lda, &info );
45+ return info;
46+ }
47+ inline int potrf ( const char &uplo, const int &n, std::complex <double >*const A, const int &lda )
48+ {
49+ int info;
50+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
51+ zpotrf_ ( &uplo_changed, &n, A, &lda, &info );
52+ return info;
53+ }
54+
55+ // potri takes potrf's output to perform matrix inversion
56+ inline int potri ( const char &uplo, const int &n, float *const A, const int &lda )
57+ {
58+ int info;
59+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
60+ spotri_ ( &uplo_changed, &n, A, &lda, &info);
61+ return info;
62+ }
63+ inline int potri ( const char &uplo, const int &n, double *const A, const int &lda )
64+ {
65+ int info;
66+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
67+ dpotri_ ( &uplo_changed, &n, A, &lda, &info);
68+ return info;
69+ }
70+ inline int potri ( const char &uplo, const int &n, std::complex <float >*const A, const int &lda )
71+ {
72+ int info;
73+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
74+ cpotri_ ( &uplo_changed, &n, A, &lda, &info);
75+ return info;
76+ }
77+ inline int potri ( const char &uplo, const int &n, std::complex <double >*const A, const int &lda )
78+ {
79+ int info;
80+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
81+ zpotri_ ( &uplo_changed, &n, A, &lda, &info);
82+ return info;
83+ }
84+
85+ // solve the eigenproblem Ax=ex, where A is Symmetric
86+ inline int syev (const char &jobz, const char &uplo,
87+ const int &n, float *const A, const int &lda, float *const W,
88+ float *const WORK, const int &lwork)
89+ {
90+ int info;
91+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
92+ ssyev_ (&jobz, &uplo_changed, &n, A, &lda, W, WORK, &lwork, &info);
93+ return info;
94+ }
95+ inline int syev (const char &jobz, const char &uplo,
96+ const int &n, double *const A, const int &lda, double *const W,
97+ double *const WORK, const int &lwork)
98+ {
99+ int info;
100+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
101+ dsyev_ (&jobz, &uplo_changed, &n, A, &lda, W, WORK, &lwork, &info);
102+ return info;
103+ }
104+ // solve the eigenproblem Ax=ex, where A is Hermitian
105+ inline int heev (const char &jobz, const char &uplo,
106+ const int &n, std::complex <float >*const A, const int &lda, float *const W,
107+ std::complex <float >*const WORK, const int &lwork, float *const RWORK)
108+ {
109+ int info;
110+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
111+ cheev_ (&jobz, &uplo_changed, &n, A, &lda, W, WORK, &lwork, RWORK, &info);
112+ return info;
113+ }
114+ inline int heev (const char &jobz, const char &uplo,
115+ const int &n, std::complex <double >*const A, const int &lda, double *const W,
116+ std::complex <double >*const WORK, const int &lwork, double *const RWORK)
117+ {
118+ int info;
119+ const char uplo_changed = Blas_Interface::change_uplo (uplo);
120+ zheev_ (&jobz, &uplo_changed, &n, A, &lda, W, WORK, &lwork, RWORK, &info);
121+ return info;
122+ }
123+
124+ // solve the eigenproblem Ax=ex, where A is Hermitian
125+ template <typename T,
126+ typename std::enable_if< std::is_arithmetic<T>::value,int >::type =0 >
127+ inline int heev (const char &jobz, const char &uplo,
128+ const int &n, T*const A, const int &lda, T*const W)
129+ {
130+ T work_tmp=100 ;
131+ constexpr int minus_one = -1 ;
132+ LAPACK_INFO_CHECK (syev (jobz, uplo, n, A, lda, W, &work_tmp, minus_one)); // get best lwork
133+
134+ const int lwork = work_tmp;
135+ std::vector<T> WORK (std::max (1 ,lwork));
136+ return syev (jobz, uplo, n, A, lda, W, WORK.data (), lwork);
137+ }
138+ template <typename T,
139+ typename std::enable_if< std::is_arithmetic<T>::value,int >::type =0 >
140+ inline int heev (const char &jobz, const char &uplo,
141+ const int &n, std::complex <T>*const A, const int &lda, T*const W)
142+ {
143+ std::vector<T> RWORK (std::max (1 ,3 *n-2 ));
144+
145+ std::complex <T> work_tmp;
146+ constexpr int minus_one = -1 ;
147+ LAPACK_INFO_CHECK (heev (jobz, uplo, n, A, lda, W, &work_tmp, minus_one, RWORK.data ())); // get best lwork
148+
149+ const int lwork = std::real (work_tmp);
150+ std::vector<std::complex <T>> WORK (std::max (1 ,lwork));
151+ return heev (jobz, uplo, n, A, lda, W, WORK.data (), lwork, RWORK.data ());
152+ }
153+ }
154+
155+ }
156+
157+ #undef LAPACK_INFO_CHECK
0 commit comments