Commit 5e7cf0d

Update lapack_hegvd interface to support lda different from n
1 parent 4c72902 commit 5e7cf0d

File tree

7 files changed: +452 additions, -85 deletions


source/source_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 25 additions & 6 deletions
@@ -105,18 +105,37 @@ template <typename T>
 struct lapack_hegvd<T, DEVICE_GPU> {
     using Real = typename GetTypeReal<T>::type;
     void operator()(
-        const int& itype,
-        const char& jobz,
-        const char& uplo,
+        const int dim,
+        const int lda,
         T* Mat_A,
         T* Mat_B,
-        const int& dim,
-        Real* eigen_val)
+        Real* eigen_val,
+        T *eigen_vec)
     {
-        cuSolverConnector::hegvd(cusolver_handle, itype, jobz, uplo, dim, Mat_A, dim, Mat_B, dim, eigen_val);
+        const int itype = 1;
+        const char jobz = 'V';
+        const char uplo = 'U';
+        cudaErrcheck(cudaMemcpy(eigen_vec, Mat_A, sizeof(T) * dim * lda, cudaMemcpyDeviceToDevice));
+        cuSolverConnector::hegvd(cusolver_handle, itype, jobz, uplo, dim, eigen_vec, lda, Mat_B, lda, eigen_val);
     }
 };
 
+// template <typename T>
+// struct lapack_hegvx<T, DEVICE_GPU> {
+//     using Real = typename GetTypeReal<T>::type;
+//     void operator()(
+//         const int n,
+//         const int lda,
+//         T *A,
+//         T *B,
+//         const int m,
+//         Real *eigen_val,
+//         T *eigen_vec)
+//     {
+//         cuSolverConnector::hegvx(cusolver_handle, n, lda, A, B, m, eigen_val, eigen_vec);
+//     }
+// };
+//
 template <typename T>
 struct lapack_getrf<T, DEVICE_GPU> {
     void operator()(
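
For reference, the sketch below shows how the updated GPU op might be driven from host code after this change: matrices sit in column-major storage with a padded leading dimension (lda > dim), and the eigenvectors come back in a separate device buffer while Mat_A is left untouched. The functor call follows the new declaration in lapack.h; the include path, the container::DEVICE_GPU tag, and createGpuSolverHandle() as the counterpart of the destroyGpuSolverHandle() shown in the header are assumptions about the surrounding ATen code, not something this commit spells out. Error checking of the CUDA runtime calls is omitted for brevity.

#include <complex>
#include <vector>
#include <cuda_runtime.h>
#include <ATen/kernels/lapack.h>   // include path assumed

int main()
{
    using T = std::complex<double>;
    const int dim = 2, lda = 4;                        // lda > dim: each stored column is padded
    const size_t bytes = sizeof(T) * dim * lda;

    // Host matrices, column-major: element (i, j) lives at index i + j * lda.
    std::vector<T> h_A(dim * lda), h_B(dim * lda);
    h_A[0 + 0 * lda] = {2.0, 0.0};  h_A[0 + 1 * lda] = {0.0, 1.0};
    h_A[1 + 0 * lda] = {0.0, -1.0}; h_A[1 + 1 * lda] = {3.0, 0.0};   // Hermitian A
    h_B[0 + 0 * lda] = {1.0, 0.0};  h_B[1 + 1 * lda] = {1.0, 0.0};   // B = identity (positive definite)

    T *d_A = nullptr, *d_B = nullptr, *d_vec = nullptr;
    double* d_val = nullptr;
    cudaMalloc(reinterpret_cast<void**>(&d_A), bytes);
    cudaMalloc(reinterpret_cast<void**>(&d_B), bytes);
    cudaMalloc(reinterpret_cast<void**>(&d_vec), bytes);
    cudaMalloc(reinterpret_cast<void**>(&d_val), sizeof(double) * dim);
    cudaMemcpy(d_A, h_A.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, h_B.data(), bytes, cudaMemcpyHostToDevice);

    container::kernels::createGpuSolverHandle();       // assumed counterpart of destroyGpuSolverHandle()
    // Eigenvectors are written to d_vec; d_A itself is not modified by the new interface.
    container::kernels::lapack_hegvd<T, container::DEVICE_GPU>()(dim, lda, d_A, d_B, d_val, d_vec);
    container::kernels::destroyGpuSolverHandle();

    std::vector<double> eigen_val(dim);
    cudaMemcpy(eigen_val.data(), d_val, sizeof(double) * dim, cudaMemcpyDeviceToHost);

    cudaFree(d_A); cudaFree(d_B); cudaFree(d_vec); cudaFree(d_val);
    return 0;
}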

source/source_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 105 additions & 13 deletions
@@ -2,6 +2,9 @@
 
 #include <base/third_party/lapack.h>
 
+// #include <cstring> // std::memcpy
+#include <algorithm> // std::copy
+
 namespace container {
 namespace kernels {
 
@@ -36,7 +39,7 @@ struct lapack_trtri<T, DEVICE_CPU> {
         const char& diag,
         const int& dim,
         T* Mat,
-        const int& lda)
+        const int& lda)
     {
         int info = 0;
         lapackConnector::trtri(uplo, diag, dim, Mat, lda, info);
@@ -51,8 +54,8 @@ struct lapack_potrf<T, DEVICE_CPU> {
     void operator()(
         const char& uplo,
         const int& dim,
-        T* Mat,
-        const int& lda)
+        T* Mat,
+        const int& lda)
     {
         int info = 0;
         lapackConnector::potrf(uplo, dim, Mat, dim, info);
@@ -85,7 +88,7 @@ struct lapack_heevd<T, DEVICE_CPU> {
         Tensor iwork(DataTypeToEnum<int>::value, DeviceType::CpuDevice, {liwork});
         iwork.zero();
 
-        lapackConnector::heevd(jobz, uplo, dim, Mat, dim, eigen_val, work.data<T>(), lwork, rwork.data<Real>(), lrwork, iwork.data<int>(), liwork, info);
+        lapackConnector::heevd(jobz, uplo, dim, Mat, dim, eigen_val, work.data<T>(), lwork, rwork.data<Real>(), lrwork, iwork.data<int>(), liwork, info);
         if (info != 0) {
             throw std::runtime_error("heevd failed with info = " + std::to_string(info));
         }
@@ -96,14 +99,26 @@ template <typename T>
 struct lapack_hegvd<T, DEVICE_CPU> {
     using Real = typename GetTypeReal<T>::type;
     void operator()(
-        const int& itype,
-        const char& jobz,
-        const char& uplo,
-        T* Mat_A,
-        T* Mat_B,
-        const int& dim,
-        Real* eigen_val)
+        const int dim,
+        const int lda,
+        T *Mat_A,
+        T *Mat_B,
+        Real *eigen_val,
+        T *eigen_vec)
     {
+        // first copy Mat_A to eigen_vec
+        // then pass as argument "A" in lapack hegvd
+        // and this block of memory will be overwritten by eigenvectors
+        // for (int i = 0; i < dim * lda; ++i){
+        //     eigen_vec[i] = Mat_A[i];
+        // }
+        // std::memcpy(eigen_vec, Mat_A, sizeof(T) * dim * lda);
+        // eigen_vec = Mat_A
+        std::copy(Mat_A, Mat_A + dim*lda, eigen_vec);
+
+        const int itype = 1;
+        const char jobz = 'V';
+        const char uplo = 'U';
         int info = 0;
         int lwork = std::max(2 * dim + dim * dim, 1 + 6 * dim + 2 * dim * dim);
         Tensor work(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {lwork});
@@ -117,13 +132,90 @@ struct lapack_hegvd<T, DEVICE_CPU> {
         Tensor iwork(DataType::DT_INT, DeviceType::CpuDevice, {liwork});
         iwork.zero();
 
-        lapackConnector::hegvd(itype, jobz, uplo, dim, Mat_A, dim, Mat_B, dim, eigen_val, work.data<T>(), lwork, rwork.data<Real>(), lrwork, iwork.data<int>(), liwork, info);
+        // After this, eigen_vec will contain the matrix Z of eigenvectors
+        lapackConnector::hegvd(itype, jobz, uplo, dim, eigen_vec, lda, Mat_B, lda, eigen_val, work.data<T>(), lwork, rwork.data<Real>(), lrwork, iwork.data<int>(), liwork, info);
         if (info != 0) {
             throw std::runtime_error("hegvd failed with info = " + std::to_string(info));
         }
     }
 };
 
+
+// template <typename T>
+// struct lapack_hegvx<T, DEVICE_CPU> {
+//     using Real = typename GetTypeReal<T>::type;
+//     void operator()(
+//         const int n,
+//         const int lda,
+//         T *A,
+//         T *B,
+//         const int m,
+//         Real *eigen_val,
+//         T *eigen_vec)
+//     {
+//         int info = 0;
+
+//         int mm = m;
+
+//         int lwork = -1;
+
+//         T *work = new T[1];
+//         Real *rwork = new Real[7 * n];
+//         int *iwork = new int[5 * n];
+//         int *ifail = new int[n];
+
+//         // set lwork = -1 to query optimal work size
+//         lapackConnector::hegvx(1, // ITYPE = 1: A*x = (lambda)*B*x
+//                                'V', 'I', 'U',
+//                                n, A, lda, B, lda,
+//                                0.0, 0.0,
+//                                1, m, 0.0, mm,
+//                                eigen_val, eigen_vec, lda,
+//                                work,
+//                                lwork, // lwork = 1, query optimal size.
+//                                rwork, iwork, ifail,
+//                                info);
+
+//         // !> If LWORK = -1, then a workspace query is assumed; the routine
+//         // !> only calculates the optimal size of the WORK array, returns
+//         // !> this value as the first entry of the WORK array.
+//         lwork = int(get_real(work[0]));
+//         delete[] work;
+//         work = new T[lwork];
+
+//         lapackConnector::hegvx(
+//             1,          // ITYPE = 1: A*x = (lambda)*B*x
+//             'V',        // JOBZ = 'V': Compute eigenvalues and eigenvectors.
+//             'I',        // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
+//             'U',        // UPLO = 'U': Upper triangles of A and B are stored.
+//             n,          // order of the matrices A and B.
+//             A,          // A is COMPLEX*16 array dimension (LDA, N)
+//             lda,        // leading dimension of the array A.
+//             B,          // B is COMPLEX*16 array, dimension (LDB, N)
+//             lda,        // assume that leading dimension of B is the same as A.
+//             0.0,        // VL, Not referenced if RANGE = 'A' or 'I'.
+//             0.0,        // VU, Not referenced if RANGE = 'A' or 'I'.
+//             1,          // IL: If RANGE='I', the index of the smallest eigenvalue to be returned. 1 <= IL <= IU <= N,
+//             m,          // IU: If RANGE='I', the index of the largest eigenvalue to be returned. 1 <= IL <= IU <= N,
+//             0.0,        // ABSTOL
+//             mm,         // M: The total number of eigenvalues found. 0 <= M <= N. if RANGE = 'I', M = IU-IL+1.
+//             eigen_val,  // W store eigenvalues
+//             eigen_vec,  // Z store eigenvector
+//             lda,        // LDZ: The leading dimension of the array Z.
+//             work,
+//             lwork,
+//             rwork,
+//             iwork,
+//             ifail,
+//             info);
+
+//         delete[] work;
+//         delete[] rwork;
+//         delete[] iwork;
+//         delete[] ifail;
+//     }
+// };
+
 template <typename T>
 struct lapack_getrf<T, DEVICE_CPU> {
     void operator()(
@@ -220,4 +312,4 @@ template struct lapack_getrs<std::complex<float>, DEVICE_CPU>;
 template struct lapack_getrs<std::complex<double>, DEVICE_CPU>;
 
 } // namespace kernels
-} // namespace container
+} // namespace container
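
A note on the std::copy in the new CPU operator: in column-major (LAPACK) storage with leading dimension lda, a dim x dim matrix occupies dim columns of lda entries each, so element (i, j) sits at index i + j * lda and the whole stored block is dim * lda values, padding included. A minimal, self-contained illustration of that layout (the identity matrix here is only an example):

#include <complex>
#include <cstdio>
#include <vector>

int main()
{
    const int dim = 3, lda = 5;                          // lda > dim: two padding rows per column
    std::vector<std::complex<double>> A(dim * lda);      // dim columns, lda rows actually stored
    for (int j = 0; j < dim; ++j)
        for (int i = 0; i < dim; ++i)
            A[i + j * lda] = {i == j ? 1.0 : 0.0, 0.0};  // identity in the leading dim x dim block
    // Copying dim * lda elements, as the new hegvd op does, carries the whole stored block
    // (useful entries and padding alike) into the eigenvector buffer before the LAPACK call.
    std::printf("stored elements: %zu\n", A.size());     // prints 15, i.e. dim * lda
    return 0;
}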

source/source_base/module_container/ATen/kernels/lapack.h

Lines changed: 57 additions & 9 deletions
@@ -35,7 +35,7 @@ struct lapack_potrf {
     void operator()(
         const char& uplo,
         const int& dim,
-        T* Mat,
+        T* Mat,
         const int& lda);
 };
 
@@ -55,16 +55,64 @@ struct lapack_heevd {
 template <typename T, typename Device>
 struct lapack_hegvd {
     using Real = typename GetTypeReal<T>::type;
+    /**
+     * @brief Computes all the eigenvalues and, optionally, the eigenvectors of a complex generalized Hermitian-definite eigenproblem.
+     *
+     * This function solves the problem A*x = lambda*B*x, where A and B are Hermitian matrices, and B is also positive definite.
+     *
+     * @param dim The order of the matrices Mat_A and Mat_B. dim >= 0.
+     * @param lda The leading dimension of the arrays Mat_A and Mat_B. lda >= max(1, dim).
+     * @param Mat_A On entry, the Hermitian matrix A. On exit, it may be overwritten.
+     * @param Mat_B On entry, the Hermitian positive definite matrix B. On exit, it may be overwritten.
+     * @param eigen_val Array to store the computed eigenvalues in ascending order.
+     * @param eigen_vec If not nullptr, array to store the computed eigenvectors.
+     *
+     * @note
+     * See LAPACK ZHEGVD or CHEGVD documentation for more details.
+     * This function assumes that A and B have the same leading dimensions, lda.
+     */
     void operator()(
-        const int& itype,
-        const char& jobz,
-        const char& uplo,
-        T* Mat_A,
-        T* Mat_B,
-        const int& dim,
-        Real* eigen_val);
+        const int dim,
+        const int lda,
+        T *Mat_A,
+        T *Mat_B,
+        Real *eigen_val,
+        T *eigen_vec);
 };
 
+// template <typename T, typename Device>
+// struct lapack_hegvx {
+//     using Real = typename GetTypeReal<T>::type;
+//     /**
+//      * @ brief hegvx computes the first m eigenvalues and their corresponding eigenvectors of
+//      * a complex generalized Hermitian-definite eigenproblem.
+//      *
+//      * In this op, the CPU version is implemented through the `hegvx` interface, and the CUDA version
+//      * is implemented through the `evd` interface and acquires the first m eigenpairs
+//      *
+//      * hegvx 'V' 'I' 'U' is used to compute the first m eigenpairs of the problem
+//      *
+//      * @param n The order of the matrices A and B. n >= 0.
+//      * @param lda The leading dimension of the array A and B. lda >= max(1, n).
+//      * @param A On entry, the Hermitian matrix A. On exit, if info = 0, A contains the matrix Z of eigenvectors.
+//      * @param B On entry, the Hermitian positive definite matrix B. On exit, the triangular factor from the Cholesky factorization of B.
+//      * @param m The number of eigenvalues and eigenvectors to be found. 0 < m <= n.
+//      * @param eigen_val The first m eigenvalues in ascending order.
+//      * @param eigen_vec The first m columns contain the orthonormal eigenvectors of the matrix A corresponding to the selected eigenvalues.
+//      *
+//      * @note
+//      * See LAPACK ZHEGVX doc for more details.
+//      */
+//     void operator()(
+//         const int n,
+//         const int lda,
+//         T *A,
+//         T *B,
+//         const int m,
+//         Real *eigen_val,
+//         T *eigen_vec);
+// };
+
 
 template <typename T, typename Device>
 struct lapack_getrf {
@@ -110,4 +158,4 @@ void destroyGpuSolverHandle(); // destroy cusolver handle
 } // namespace container
 } // namespace kernels
 
-#endif // ATEN_KERNELS_LAPACK_H_
+#endif // ATEN_KERNELS_LAPACK_H_
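
To make the documented contract concrete, here is a hedged usage sketch of the CPU specialization with lda larger than dim. The functor call mirrors the declaration above; the include path and the container::DEVICE_CPU tag are assumed from the surrounding ATen container code rather than shown by this commit.

#include <complex>
#include <cstdio>
#include <vector>
#include <ATen/kernels/lapack.h>   // include path assumed

int main()
{
    using T = std::complex<double>;
    const int dim = 2, lda = 4;                     // leading dimension larger than the matrix order
    std::vector<T> A(dim * lda), B(dim * lda), evec(dim * lda);
    std::vector<double> eval(dim);

    // Column-major, element (i, j) at i + j * lda; only the upper triangle is referenced (uplo = 'U').
    A[0 + 0 * lda] = {2.0, 0.0};  A[0 + 1 * lda] = {0.0, 1.0};  A[1 + 1 * lda] = {3.0, 0.0};
    B[0 + 0 * lda] = {1.0, 0.0};  B[1 + 1 * lda] = {1.0, 0.0};  // B = identity, positive definite

    // Mat_A is copied internally, so A is preserved; eigenvectors land in evec, eigenvalues in eval.
    container::kernels::lapack_hegvd<T, container::DEVICE_CPU>()(   // DEVICE_CPU tag assumed in namespace container
        dim, lda, A.data(), B.data(), eval.data(), evec.data());

    std::printf("lambda_0 = %f, lambda_1 = %f\n", eval[0], eval[1]);
    return 0;
}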

source/source_base/module_container/ATen/kernels/rocm/lapack.hip.cu

Lines changed: 15 additions & 8 deletions
@@ -28,8 +28,8 @@ void destroyGpuSolverHandle() {
 template <typename T>
 __global__ void set_matrix_kernel(
     const char uplo,
-    T* A,
-    const int dim)
+    T* A,
+    const int dim)
 {
     int bid = blockIdx.x;
     int tid = threadIdx.x;
@@ -64,7 +64,7 @@ struct lapack_trtri<T, DEVICE_GPU> {
         const char& diag,
         const int& dim,
         T* Mat,
-        const int& lda)
+        const int& lda)
     {
         // TODO: trtri is not implemented in this method yet
         // Cause the trtri in cuSolver is not stable for ABACUS!
@@ -82,8 +82,8 @@ struct lapack_potrf<T, DEVICE_GPU> {
     void operator()(
         const char& uplo,
         const int& dim,
-        T* Mat,
-        const int& lda)
+        T* Mat,
+        const int& lda)
     {
         // hipSolverConnector::potrf(hipsolver_handle, uplo, dim, Mat, dim);
         std::vector<T> H_Mat(dim * dim, static_cast<T>(0.0));
@@ -118,15 +118,22 @@ template <typename T>
 struct lapack_hegvd<T, DEVICE_GPU> {
     using Real = typename GetTypeReal<T>::type;
     void operator()(
+        const int dim,
+        const int lda,
         const int& itype,
         const char& jobz,
         const char& uplo,
         T* Mat_A,
         T* Mat_B,
         const int& dim,
-        Real* eigen_val)
+        Real* eigen_val,
+        T *eigen_vec)
     {
-        hipSolverConnector::hegvd(hipsolver_handle, itype, jobz, uplo, dim, Mat_A, dim, Mat_B, dim, eigen_val);
+        const int itype = 1;
+        const char jobz = 'V';
+        const char uplo = 'U';
+        hipErrcheck(hipMemcpy(eigen_vec, Mat_A, sizeof(T) * dim * lda, hipMemcpyDeviceToDevice));
+        hipSolverConnector::hegvd(hipsolver_handle, itype, jobz, uplo, dim, Mat_A, lda, Mat_B, lda, eigen_val);
     }
 };
 
@@ -156,4 +163,4 @@ template struct lapack_hegvd<std::complex<float>, DEVICE_GPU>;
 template struct lapack_hegvd<std::complex<double>, DEVICE_GPU>;
 
 } // namespace kernels
-} // namespace container
+} // namespace container
