Skip to content

Commit c42bbc0

Browse files
committed
Add lapack_hegvx
1 parent 233254a commit c42bbc0

File tree

4 files changed

+540
-324
lines changed

4 files changed

+540
-324
lines changed

source/source_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -149,26 +149,40 @@ struct lapack_hegvd<T, DEVICE_GPU> {
149149
const char jobz = 'V';
150150
const char uplo = 'U';
151151
cudaErrcheck(cudaMemcpy(eigen_vec, Mat_A, sizeof(T) * dim * lda, cudaMemcpyDeviceToDevice));
152-
cuSolverConnector::hegvd(cusolver_handle, itype, jobz, uplo, dim, eigen_vec, lda, Mat_B, lda, eigen_val);
152+
cuSolverConnector::hegvd(cusolver_handle, itype, jobz, uplo, dim,
153+
eigen_vec, lda, Mat_B, lda,
154+
eigen_val);
153155
}
154156
};
155157

156-
// template <typename T>
157-
// struct lapack_hegvx<T, DEVICE_GPU> {
158-
// using Real = typename GetTypeReal<T>::type;
159-
// void operator()(
160-
// const int n,
161-
// const int lda,
162-
// T *A,
163-
// T *B,
164-
// const int m,
165-
// Real *eigen_val,
166-
// T *eigen_vec)
167-
// {
168-
// cuSolverConnector::hegvx(cusolver_handle, n, lda, A, B, m, eigen_val, eigen_vec);
169-
// }
170-
// };
171-
//
158+
template <typename T>
159+
struct lapack_hegvx<T, DEVICE_GPU> {
160+
using Real = typename GetTypeReal<T>::type;
161+
void operator()(
162+
const int n,
163+
const int lda,
164+
T *A,
165+
T *B,
166+
const int m,
167+
Real *eigen_val,
168+
T *eigen_vec)
169+
{
170+
const int itype = 1;
171+
const char jobz = 'V';
172+
const char range = 'I';
173+
const char uplo = 'U';
174+
int meig = 0;
175+
cuSolverConnector::hegvdx(cusolver_handle,
176+
itype, jobz, range, uplo,
177+
n, lda, A, B,
178+
Real(0), Real(0),
179+
1, m, &meig,
180+
eigen_val, eigen_vec);
181+
}
182+
};
183+
184+
185+
172186
template <typename T>
173187
struct lapack_getrf<T, DEVICE_GPU> {
174188
void operator()(
@@ -242,6 +256,11 @@ template struct lapack_hegvd<double, DEVICE_GPU>;
242256
template struct lapack_hegvd<std::complex<float>, DEVICE_GPU>;
243257
template struct lapack_hegvd<std::complex<double>, DEVICE_GPU>;
244258

259+
template struct lapack_hegvx<float, DEVICE_GPU>;
260+
template struct lapack_hegvx<double, DEVICE_GPU>;
261+
template struct lapack_hegvx<std::complex<float>, DEVICE_GPU>;
262+
template struct lapack_hegvx<std::complex<double>, DEVICE_GPU>;
263+
245264
template struct lapack_getrf<float, DEVICE_GPU>;
246265
template struct lapack_getrf<double, DEVICE_GPU>;
247266
template struct lapack_getrf<std::complex<float>, DEVICE_GPU>;

source/source_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 105 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ struct lapack_hegvd<T, DEVICE_CPU> {
217217

218218
const int itype = 1;
219219
const char jobz = 'V';
220-
const char uplo = 'U';
220+
const char uplo = 'L';
221221
int info = 0;
222222
int lwork = std::max(2 * dim + dim * dim, 1 + 6 * dim + 2 * dim * dim);
223223
Tensor work(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {lwork});
@@ -240,80 +240,105 @@ struct lapack_hegvd<T, DEVICE_CPU> {
240240
};
241241

242242

243-
// template <typename T>
244-
// struct lapack_hegvx<T, DEVICE_CPU> {
245-
// using Real = typename GetTypeReal<T>::type;
246-
// void operator()(
247-
// const int n,
248-
// const int lda,
249-
// T *A,
250-
// T *B,
251-
// const int m,
252-
// Real *eigen_val,
253-
// T *eigen_vec)
254-
// {
255-
// int info = 0;
256-
257-
// int mm = m;
258-
259-
// int lwork = -1;
260-
261-
// T *work = new T[1];
262-
// Real *rwork = new Real[7 * n];
263-
// int *iwork = new int[5 * n];
264-
// int *ifail = new int[n];
265-
266-
// // set lwork = -1 to query optimal work size
267-
// lapackConnector::hegvx(1, // ITYPE = 1: A*x = (lambda)*B*x
268-
// 'V', 'I', 'U',
269-
// n, A, lda, B, lda,
270-
// 0.0, 0.0,
271-
// 1, m, 0.0, mm,
272-
// eigen_val, eigen_vec, lda,
273-
// work,
274-
// lwork, // lwork = 1, query optimal size.
275-
// rwork, iwork, ifail,
276-
// info);
277-
278-
// // !> If LWORK = -1, then a workspace query is assumed; the routine
279-
// // !> only calculates the optimal size of the WORK array, returns
280-
// // !> this value as the first entry of the WORK array.
281-
// lwork = int(get_real(work[0]));
282-
// delete[] work;
283-
// work = new T[lwork];
284-
285-
// lapackConnector::hegvx(
286-
// 1, // ITYPE = 1: A*x = (lambda)*B*x
287-
// 'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors.
288-
// 'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
289-
// 'U', // UPLO = 'U': Upper triangles of A and B are stored.
290-
// n, // order of the matrices A and B.
291-
// A, // A is COMPLEX*16 array dimension (LDA, N)
292-
// lda, // leading dimension of the array A.
293-
// B, // B is COMPLEX*16 array, dimension (LDB, N)
294-
// lda, // assume that leading dimension of B is the same as A.
295-
// 0.0, // VL, Not referenced if RANGE = 'A' or 'I'.
296-
// 0.0, // VU, Not referenced if RANGE = 'A' or 'I'.
297-
// 1, // IL: If RANGE='I', the index of the smallest eigenvalue to be returned. 1 <= IL <= IU <= N,
298-
// m, // IU: If RANGE='I', the index of the largest eigenvalue to be returned. 1 <= IL <= IU <= N,
299-
// 0.0, // ABSTOL
300-
// mm, // M: The total number of eigenvalues found. 0 <= M <= N. if RANGE = 'I', M = IU-IL+1.
301-
// eigen_val, // W store eigenvalues
302-
// eigen_vec, // Z store eigenvector
303-
// lda, // LDZ: The leading dimension of the array Z.
304-
// work,
305-
// lwork,
306-
// rwork,
307-
// iwork,
308-
// ifail,
309-
// info);
310-
311-
// delete[] work;
312-
// delete[] rwork;
313-
// delete[] iwork;
314-
// delete[] ifail;
315-
// }
316-
// };
243+
template <typename T>
244+
struct lapack_hegvx<T, DEVICE_CPU> {
245+
using Real = typename GetTypeReal<T>::type;
246+
void operator()(
247+
const int n,
248+
const int lda,
249+
T *Mat_A,
250+
T *Mat_B,
251+
const int m,
252+
Real *eigen_val,
253+
T *eigen_vec)
254+
{
255+
// first copy Mat_A and Mat_B to auxiliary memory
256+
// to avoid the origin block being overwritten by hegvx
257+
Tensor aux_A(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {n * lda});
258+
std::copy(Mat_A, Mat_A + n * lda, aux_A.data<T>());
259+
Tensor aux_B(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {n * lda});
260+
std::copy(Mat_B, Mat_B + n * lda, aux_B.data<T>());
261+
262+
const int itype = 1; // ITYPE = 1: A*x = (lambda)*B*x
263+
const char jobz = 'V';// JOBZ = 'V': Compute eigenvalues and eigenvectors.
264+
const char range = 'I'; // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
265+
const char uplo = 'L'; // UPLO = 'L': Lower triangles of A and B are stored.
266+
267+
const int il = 1;
268+
const int iu = m;
269+
int found = m; // Found, should be iu - il + 1
270+
int info = 0;
271+
272+
int lwork = -1;
273+
274+
T work_query;
275+
Real rwork_query;
276+
277+
// set lwork = -1 to query optimal work size
278+
lapackConnector::hegvx(
279+
itype, jobz, range, uplo,
280+
n,
281+
aux_A.data<T>(), lda, // A (in/out)
282+
aux_B.data<T>(), lda, // B (in/out)
283+
0.0, 0.0, // VL, VU (not used)
284+
il, iu, // IL, IU
285+
Real(0.0), // ABSTOL
286+
found, // M (output)
287+
eigen_val, // W (output)
288+
eigen_vec, lda, // Z (output)
289+
&work_query, // WORK (query)
290+
lwork,
291+
&rwork_query, // RWORK (query)
292+
static_cast<int*>(nullptr), // IWORK (query)
293+
static_cast<int*>(nullptr), // IFAIL (query)
294+
info);
295+
296+
// !> If LWORK = -1, then a workspace query is assumed; the routine
297+
// !> only calculates the optimal size of the WORK array, returns
298+
// !> this value as the first entry of the WORK array.
299+
lwork = static_cast<int>(get_real(work_query));
300+
lwork = std::max(lwork, 1);
301+
302+
// work space
303+
Tensor work(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {lwork});
304+
work.zero();
305+
306+
const int lrwork = 7 * n;
307+
Tensor rwork(DataTypeToEnum<Real>::value, DeviceType::CpuDevice, {lrwork});
308+
rwork.zero();
309+
310+
const int liwork = 5 * n;
311+
Tensor iwork(DataType::DT_INT, DeviceType::CpuDevice, {liwork});
312+
iwork.zero();
313+
314+
std::vector<int> ifail(n);
315+
316+
lapackConnector::hegvx(
317+
itype, jobz, range, uplo,
318+
n,
319+
aux_A.data<T>(), lda, // A
320+
aux_B.data<T>(), lda, // B
321+
0.0, 0.0, // VL, VU
322+
il, iu, // IL, IU
323+
Real(0.0), // ABSTOL
324+
found, // M (output)
325+
eigen_val, // W
326+
eigen_vec, lda, // Z (output)
327+
work.data<T>(), // WORK
328+
lwork,
329+
rwork.data<Real>(), // RWORK
330+
iwork.data<int>(), // IWORK
331+
ifail.data(), // IFAIL
332+
info);
333+
334+
if (info < 0) {
335+
throw std::runtime_error("hegvx failed: illegal argument #" + std::to_string(-info));
336+
}
337+
if (info > 0) {
338+
throw std::runtime_error("hegvx failed to converge. Number of converged eigenvalues: " + std::to_string(info));
339+
}
340+
}
341+
};
317342

318343
template <typename T>
319344
struct lapack_getrf<T, DEVICE_CPU> {
@@ -400,6 +425,11 @@ template struct lapack_hegvd<double, DEVICE_CPU>;
400425
template struct lapack_hegvd<std::complex<float>, DEVICE_CPU>;
401426
template struct lapack_hegvd<std::complex<double>, DEVICE_CPU>;
402427

428+
template struct lapack_hegvx<float, DEVICE_CPU>;
429+
template struct lapack_hegvx<double, DEVICE_CPU>;
430+
template struct lapack_hegvx<std::complex<float>, DEVICE_CPU>;
431+
template struct lapack_hegvx<std::complex<double>, DEVICE_CPU>;
432+
403433
template struct lapack_getrf<float, DEVICE_CPU>;
404434
template struct lapack_getrf<double, DEVICE_CPU>;
405435
template struct lapack_getrf<std::complex<float>, DEVICE_CPU>;

source/source_base/module_container/ATen/kernels/lapack.h

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -115,38 +115,38 @@ struct lapack_hegvd {
115115
T *eigen_vec);
116116
};
117117

118-
// template <typename T, typename Device>
119-
// struct lapack_hegvx {
120-
// using Real = typename GetTypeReal<T>::type;
121-
// /**
122-
// * @ brief hegvx computes the first m eigenvalues and their corresponding eigenvectors of
123-
// * a complex generalized Hermitian-definite eigenproblem.
124-
// *
125-
// * In this op, the CPU version is implemented through the `hegvx` interface, and the CUDA version
126-
// * is implemented through the `evd` interface and acquires the first m eigenpairs
127-
// *
128-
// * hegvx 'V' 'I' 'U' is used to compute the first m eigenpairs of the problem
129-
// *
130-
// * @param n The order of the matrices A and B. n >= 0.
131-
// * @param lda The leading dimension of the array A and B. lda >= max(1, n).
132-
// * @param A On entry, the Hermitian matrix A. On exit, if info = 0, A contains the matrix Z of eigenvectors.
133-
// * @param B On entry, the Hermitian positive definite matrix B. On exit, the triangular factor from the Cholesky factorization of B.
134-
// * @param m The number of eigenvalues and eigenvectors to be found. 0 < m <= n.
135-
// * @param eigen_val The first m eigenvalues in ascending order.
136-
// * @param eigen_vec The first m columns contain the orthonormal eigenvectors of the matrix A corresponding to the selected eigenvalues.
137-
// *
138-
// * @note
139-
// * See LAPACK ZHEGVX doc for more details.
140-
// */
141-
// void operator()(
142-
// const int n,
143-
// const int lda,
144-
// T *A,
145-
// T *B,
146-
// const int m,
147-
// Real *eigen_val,
148-
// T *eigen_vec);
149-
// };
118+
template <typename T, typename Device>
119+
struct lapack_hegvx {
120+
using Real = typename GetTypeReal<T>::type;
121+
/**
122+
* @ brief hegvx computes the first m eigenvalues and their corresponding eigenvectors of
123+
* a complex generalized Hermitian-definite eigenproblem.
124+
*
125+
* In this op, the CPU version is implemented through the `hegvx` interface, and the CUDA version
126+
* is implemented through the `evd` interface and acquires the first m eigenpairs
127+
*
128+
* hegvx 'V' 'I' 'U' is used to compute the first m eigenpairs of the problem
129+
*
130+
* @param n The order of the matrices A and B. n >= 0.
131+
* @param lda The leading dimension of the array A and B. lda >= max(1, n).
132+
* @param A On entry, the Hermitian matrix A. On exit, if info = 0, A contains the matrix Z of eigenvectors.
133+
* @param B On entry, the Hermitian positive definite matrix B. On exit, the triangular factor from the Cholesky factorization of B.
134+
* @param m The number of eigenvalues and eigenvectors to be found. 0 < m <= n.
135+
* @param eigen_val The first m eigenvalues in ascending order.
136+
* @param eigen_vec The first m columns contain the orthonormal eigenvectors of the matrix A corresponding to the selected eigenvalues.
137+
*
138+
* @note
139+
* See LAPACK ZHEGVX doc for more details.
140+
*/
141+
void operator()(
142+
const int n,
143+
const int lda,
144+
T *Mat_A,
145+
T *Mat_B,
146+
const int m,
147+
Real *eigen_val,
148+
T *eigen_vec);
149+
};
150150

151151

152152
template <typename T, typename Device>

0 commit comments

Comments
 (0)