Skip to content

Commit 831625e

Browse files
committed
Update heevd interface to add lda
1 parent 7fa2e21 commit 831625e

File tree

5 files changed

+52
-20
lines changed

5 files changed

+52
-20
lines changed

source/source_base/module_container/ATen/kernels/blas.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace kernels {
1111

1212
template <typename T, typename Device>
1313
struct blas_copy {
14+
// DCOPY copies a vector, x, to a vector, y.
1415
void operator()(
1516
const int n,
1617
const T *x,

source/source_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,14 @@ template <typename T>
231231
struct lapack_heevd<T, DEVICE_GPU> {
232232
using Real = typename GetTypeReal<T>::type;
233233
void operator()(
234-
const char& jobz,
235-
const char& uplo,
234+
const int dim,
236235
T* Mat,
237-
const int& dim,
236+
const int lda,
238237
Real* eigen_val)
239238
{
240-
cuSolverConnector::heevd(cusolver_handle, jobz, uplo, dim, Mat, dim, eigen_val);
239+
char jobz = 'V'; // Compute eigenvalues and eigenvectors
240+
char uplo = 'U';
241+
cuSolverConnector::heevd(cusolver_handle, jobz, uplo, dim, Mat, lda, eigen_val);
241242
}
242243
};
243244

source/source_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,13 @@ template <typename T>
196196
struct lapack_heevd<T, DEVICE_CPU> {
197197
using Real = typename GetTypeReal<T>::type;
198198
void operator()(
199-
const char& jobz,
200-
const char& uplo,
199+
const int dim,
201200
T* Mat,
202-
const int& dim,
201+
const int lda,
203202
Real* eigen_val)
204203
{
204+
char jobz = 'V'; // Compute eigenvalues and eigenvectors
205+
char uplo = 'U';
205206
int info = 0;
206207
int lwork = std::max(2 * dim + dim * dim, 1 + 6 * dim + 2 * dim * dim);
207208
Tensor work(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {lwork});
@@ -215,7 +216,7 @@ struct lapack_heevd<T, DEVICE_CPU> {
215216
Tensor iwork(DataTypeToEnum<int>::value, DeviceType::CpuDevice, {liwork});
216217
iwork.zero();
217218

218-
lapackConnector::heevd(jobz, uplo, dim, Mat, dim, eigen_val, work.data<T>(), lwork, rwork.data<Real>(), lrwork, iwork.data<int>(), liwork, info);
219+
lapackConnector::heevd(jobz, uplo, dim, Mat, lda, eigen_val, work.data<T>(), lwork, rwork.data<Real>(), lrwork, iwork.data<int>(), liwork, info);
219220
if (info != 0) {
220221
throw std::runtime_error("heevd failed with info = " + std::to_string(info));
221222
}
@@ -233,6 +234,8 @@ struct lapack_heevx<T, DEVICE_CPU> {
233234
Real *eigen_val,
234235
T *eigen_vec)
235236
{
237+
// copy Mat to aux, solve heevx(aux, eigen_val, eigen_vec)
238+
// input Mat is not referenced in actual heevx LAPACK routines, and aux is destroyed.
236239
Tensor aux(DataTypeToEnum<T>::value, DeviceType::CpuDevice, {n * lda});
237240
// Copy Mat to aux since heevx will destroy it
238241
// aux = Mat

source/source_base/module_container/ATen/kernels/lapack.h

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,37 @@ struct lapack_getrs {
145145
// ============================================================================
146146
template <typename T, typename Device>
147147
struct lapack_heevd {
148+
// !> ZHEEVD computes all eigenvalues and, optionally, eigenvectors of a
149+
// !> complex Hermitian matrix A. If eigenvectors are desired, it uses a
150+
// !> divide and conquer algorithm.
151+
// !> On exit, if JOBZ = 'V', then if INFO = 0, A contains the
152+
// !> orthonormal eigenvectors of the matrix A.
153+
/**
154+
* @brief Computes all eigenvalues and, optionally, eigenvectors of a complex Hermitian matrix.
155+
*
156+
* This function solves the standard Hermitian eigenvalue problem A*x = lambda*x,
157+
* where A is a Hermitian matrix. It computes all eigenvalues and optionally
158+
* the corresponding eigenvectors using a divide and conquer algorithm.
159+
*
160+
* @param[in] dim The order of the matrix A. dim >= 0.
161+
* @param[in,out] Mat On entry, the Hermitian matrix A.
162+
* On exit, if eigenvectors are computed, A contains the
163+
* orthonormal eigenvectors of the matrix A.
164+
* @param[in] lda The leading dimension of the array Mat. lda >= max(1, dim).
165+
* @param[out] eigen_val Array of size at least dim. On normal exit, contains the
166+
* eigenvalues in ascending order.
167+
*
168+
* @note
169+
* See LAPACK ZHEEVD or CHEEVD documentation for more details.
170+
* The matrix is assumed to be stored in upper or lower triangular form
171+
* according to the uplo parameter (not shown here but typically passed
172+
* to the actual implementation).
173+
*/
148174
using Real = typename GetTypeReal<T>::type;
149175
void operator()(
150-
const char& jobz,
151-
const char& uplo,
176+
const int dim,
152177
T* Mat,
153-
const int& dim,
178+
const int lda,
154179
Real* eigen_val);
155180
};
156181

@@ -165,7 +190,8 @@ struct lapack_heevx {
165190
*
166191
* @param dim The order of the matrix A. dim >= 0.
167192
* @param lda The leading dimension of the array Mat. lda >= max(1, dim).
168-
* @param Mat On entry, the Hermitian matrix A. On exit, A is kept.
193+
* @param[in] Mat On entry, the Hermitian matrix A. On exit, A is kept.
194+
* Only used to provide values of matrix.
169195
* @param neig The number of eigenvalues to be found. 0 <= neig <= dim.
170196
* @param eigen_val On normal exit, the first \p neig elements contain the selected
171197
* eigenvalues in ascending order.

source/source_hsolver/diago_bpcg.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,14 @@ void DiagoBPCG<T, Device>::line_minimize(
112112
// Finally, the last two!
113113
template<typename T, typename Device>
114114
void DiagoBPCG<T, Device>::orth_cholesky(
115-
ct::Tensor& workspace_in,
116-
ct::Tensor& psi_out,
117-
ct::Tensor& hpsi_out,
115+
ct::Tensor& workspace_in,
116+
ct::Tensor& psi_out,
117+
ct::Tensor& hpsi_out,
118118
ct::Tensor& hsub_out)
119119
{
120120
// gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band)
121121
this->pmmcn.multiply(1.0, psi_out.data<T>(), psi_out.data<T>(), 0.0, hsub_out.data<T>());
122-
122+
123123
// set hsub matrix to lower format;
124124
ct::kernels::set_matrix<T, ct_Device>()(
125125
'L', hsub_out.data<T>(), this->n_band);
@@ -209,7 +209,8 @@ void DiagoBPCG<T, Device>::diag_hsub(
209209
// gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band)
210210
this->pmmcn.multiply(1.0, hpsi_in.data<T>(), psi_in.data<T>(), 0.0, hsub_out.data<T>());
211211

212-
ct::kernels::lapack_heevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());
212+
// ct::kernels::lapack_heevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());
213+
ct::kernels::lapack_heevd<T, ct_Device>()(this->n_band, hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());
213214

214215
return;
215216
}
@@ -235,15 +236,15 @@ void DiagoBPCG<T, Device>::calc_hsub_with_block(
235236
// hpsi_out[n_basis, n_band] = psi_out[n_basis, n_band] x hsub_out[n_band, n_band]
236237
this->rotate_wf(hsub_out, psi_out, workspace_in);
237238
this->rotate_wf(hsub_out, hpsi_out, workspace_in);
238-
239+
239240
return;
240241
}
241242

242243
template<typename T, typename Device>
243244
void DiagoBPCG<T, Device>::calc_hsub_with_block_exit(
244-
ct::Tensor& psi_out,
245+
ct::Tensor& psi_out,
245246
ct::Tensor& hpsi_out,
246-
ct::Tensor& hsub_out,
247+
ct::Tensor& hsub_out,
247248
ct::Tensor& workspace_in,
248249
ct::Tensor& eigenvalue_out)
249250
{

0 commit comments

Comments
 (0)