|
6 | 6 | #include <cuda_runtime.h> |
7 | 7 | #include <thrust/complex.h> |
8 | 8 |
|
| 9 | +#include <cassert> |
| 10 | + |
| 11 | + |
9 | 12 | namespace container { |
10 | 13 | namespace kernels { |
11 | 14 |
|
@@ -101,22 +104,100 @@ struct lapack_heevd<T, DEVICE_GPU> { |
101 | 104 | } |
102 | 105 | }; |
103 | 106 |
|
| 107 | +template <typename T> |
| 108 | +struct lapack_heevx<T, DEVICE_GPU> { |
| 109 | + using Real = typename GetTypeReal<T>::type; |
| 110 | + void operator()( |
| 111 | + const int n, |
| 112 | + const int lda, |
| 113 | + const T *d_Mat, |
| 114 | + const int neig, |
| 115 | + Real *d_eigen_val, |
| 116 | + T *d_eigen_vec) |
| 117 | + { |
| 118 | + assert(n <= lda); |
| 119 | + // copy d_Mat to d_eigen_vec, and results will be overwritten into d_eigen_vec |
| 120 | + // by cuSolver |
| 121 | + cudaErrcheck(cudaMemcpy(d_eigen_vec, d_Mat, sizeof(T) * n * lda, cudaMemcpyDeviceToDevice)); |
| 122 | + |
| 123 | + int meig = 0; |
| 124 | + |
| 125 | + cuSolverConnector::heevdx( |
| 126 | + cusolver_handle, |
| 127 | + n, |
| 128 | + lda, |
| 129 | + d_eigen_vec, |
| 130 | + 'V', // jobz: compute vectors |
| 131 | + 'L', // uplo: lower triangle |
| 132 | + 'I', // range: by index |
| 133 | + 1, neig, // il, iu |
| 134 | + Real(0), Real(0), // vl, vu (unused) |
| 135 | + d_eigen_val, |
| 136 | + &meig |
| 137 | + ); |
| 138 | + |
| 139 | + } |
| 140 | +}; |
104 | 141 | template <typename T> |
105 | 142 | struct lapack_hegvd<T, DEVICE_GPU> { |
106 | 143 | using Real = typename GetTypeReal<T>::type; |
107 | 144 | void operator()( |
108 | | - const int& itype, |
109 | | - const char& jobz, |
110 | | - const char& uplo, |
| 145 | + const int dim, |
| 146 | + const int lda, |
111 | 147 | T* Mat_A, |
112 | 148 | T* Mat_B, |
113 | | - const int& dim, |
114 | | - Real* eigen_val) |
| 149 | + Real* eigen_val, |
| 150 | + T *eigen_vec) |
115 | 151 | { |
116 | | - cuSolverConnector::hegvd(cusolver_handle, itype, jobz, uplo, dim, Mat_A, dim, Mat_B, dim, eigen_val); |
| 152 | + const int itype = 1; |
| 153 | + const char jobz = 'V'; |
| 154 | + const char uplo = 'L'; |
| 155 | + cudaErrcheck(cudaMemcpy(eigen_vec, Mat_A, sizeof(T) * dim * lda, cudaMemcpyDeviceToDevice)); |
| 156 | + |
| 157 | + // prevent B from being overwritten by Cholesky |
| 158 | + T *d_B_backup = nullptr; |
| 159 | + cudaErrcheck(cudaMalloc(&d_B_backup, sizeof(T) * dim * lda)); |
| 160 | + cudaErrcheck(cudaMemcpy(d_B_backup, Mat_B, sizeof(T) * dim * lda, cudaMemcpyDeviceToDevice)); |
| 161 | + |
| 162 | + cuSolverConnector::hegvd(cusolver_handle, itype, jobz, uplo, dim, |
| 163 | + eigen_vec, lda, |
| 164 | + d_B_backup, lda, |
| 165 | + eigen_val); |
| 166 | + cudaErrcheck(cudaFree(d_B_backup)); |
| 167 | + } |
| 168 | +}; |
| 169 | + |
| 170 | +template <typename T> |
| 171 | +struct lapack_hegvx<T, DEVICE_GPU> { |
| 172 | + using Real = typename GetTypeReal<T>::type; |
| 173 | + void operator()( |
| 174 | + const int n, |
| 175 | + const int lda, |
| 176 | + T *A, |
| 177 | + T *B, |
| 178 | + const int m, |
| 179 | + Real *eigen_val, |
| 180 | + T *eigen_vec) |
| 181 | + { |
| 182 | + const int itype = 1; |
| 183 | + const char jobz = 'V'; |
| 184 | + const char range = 'I'; |
| 185 | + const char uplo = 'U'; |
| 186 | + int meig = 0; |
| 187 | + |
| 188 | + // this hegvdx will protect the input A, B from being overwritten |
| 189 | + // and write the eigenvectors into eigen_vec. |
| 190 | + cuSolverConnector::hegvdx(cusolver_handle, |
| 191 | + itype, jobz, range, uplo, |
| 192 | + n, lda, A, B, |
| 193 | + Real(0), Real(0), |
| 194 | + 1, m, &meig, |
| 195 | + eigen_val, eigen_vec); |
117 | 196 | } |
118 | 197 | }; |
119 | 198 |
|
| 199 | + |
| 200 | + |
120 | 201 | template <typename T> |
121 | 202 | struct lapack_getrf<T, DEVICE_GPU> { |
122 | 203 | void operator()( |
@@ -180,11 +261,21 @@ template struct lapack_heevd<double, DEVICE_GPU>; |
180 | 261 | template struct lapack_heevd<std::complex<float>, DEVICE_GPU>; |
181 | 262 | template struct lapack_heevd<std::complex<double>, DEVICE_GPU>; |
182 | 263 |
|
| 264 | +template struct lapack_heevx<float, DEVICE_GPU>; |
| 265 | +template struct lapack_heevx<double, DEVICE_GPU>; |
| 266 | +template struct lapack_heevx<std::complex<float>, DEVICE_GPU>; |
| 267 | +template struct lapack_heevx<std::complex<double>, DEVICE_GPU>; |
| 268 | + |
183 | 269 | template struct lapack_hegvd<float, DEVICE_GPU>; |
184 | 270 | template struct lapack_hegvd<double, DEVICE_GPU>; |
185 | 271 | template struct lapack_hegvd<std::complex<float>, DEVICE_GPU>; |
186 | 272 | template struct lapack_hegvd<std::complex<double>, DEVICE_GPU>; |
187 | 273 |
|
| 274 | +template struct lapack_hegvx<float, DEVICE_GPU>; |
| 275 | +template struct lapack_hegvx<double, DEVICE_GPU>; |
| 276 | +template struct lapack_hegvx<std::complex<float>, DEVICE_GPU>; |
| 277 | +template struct lapack_hegvx<std::complex<double>, DEVICE_GPU>; |
| 278 | + |
188 | 279 | template struct lapack_getrf<float, DEVICE_GPU>; |
189 | 280 | template struct lapack_getrf<double, DEVICE_GPU>; |
190 | 281 | template struct lapack_getrf<std::complex<float>, DEVICE_GPU>; |
|
0 commit comments