Skip to content

Commit 39bbfe4

Browse files
committed
Replace hegvd_op with lapack_hegvd in diago_dav_subspace
1 parent 4b8e214 commit 39bbfe4

File tree

3 files changed

+36
-24
lines changed

3 files changed

+36
-24
lines changed

source/source_hsolver/diago_dav_subspace.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44

55
#include "source_base/module_device/device.h"
66
#include "source_base/timer.h"
7-
#include "source_hsolver/kernels/hegvd_op.h"
87
#include "source_base/kernels/math_kernel_op.h"
9-
#include "source_hsolver/kernels/bpcg_kernel_op.h" // normalize_op, precondition_op, apply_eigenvalues_op
108
#include "source_base/kernels/dsp/dsp_connector.h"
9+
// #include "source_base/module_container/ATen/kernels/lapack.h"
10+
11+
#include <ATen/kernels/lapack.h>
1112

13+
#include "source_hsolver/kernels/hegvd_op.h"
1214
#include "source_hsolver/diag_hs_para.h"
15+
#include "source_hsolver/kernels/bpcg_kernel_op.h" // normalize_op, precondition_op, apply_eigenvalues_op
1316

1417
#include <vector>
1518

@@ -540,7 +543,7 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
540543
if (this->diag_comm.rank == 0)
541544
{
542545
syncmem_complex_op()(this->d_scc, scc, nbase * this->nbase_x);
543-
hegvd_op<T, Device>()(this->ctx, nbase, this->nbase_x, this->hcc, this->d_scc, this->d_eigenvalue, this->vcc);
546+
ct::kernels::lapack_hegvd<T, ct_Device>()(nbase, this->nbase_x, this->hcc, this->d_scc, this->d_eigenvalue, this->vcc);
544547
syncmem_var_d2h_op()((*eigenvalue_iter).data(), this->d_eigenvalue, this->nbase_x);
545548
}
546549
#endif

source/source_hsolver/diago_dav_subspace.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "source_base/module_device/device.h" // base_device
66
#include "source_base/module_device/memory_op.h"// base_device::memory"
77

8+
#include "source_base/module_container/ATen/kernels/lapack.h"
9+
810
#include "source_hsolver/diag_comm_info.h"
911
#include "source_hsolver/diag_const_nums.h"
1012

@@ -189,10 +191,14 @@ class Diago_DavSubspace
189191
using syncmem_h2d_op = base_device::memory::synchronize_memory_op<T, Device, base_device::DEVICE_CPU>;
190192
using syncmem_d2h_op = base_device::memory::synchronize_memory_op<T, base_device::DEVICE_CPU, Device>;
191193

194+
// Note that ct_Device is different from base_device!
195+
using ct_Device = typename ct::PsiToContainer<Device>::type;
196+
// using hegvd_op = container::kernels::lapack_hegvd<T, ct_Device>;
197+
192198
const T *one = nullptr, *zero = nullptr, *neg_one = nullptr;
193199
const T one_ = static_cast<T>(1.0), zero_ = static_cast<T>(0.0), neg_one_ = static_cast<T>(-1.0);
194200
};
195201

196202
} // namespace hsolver
197203

198-
#endif
204+
#endif

source/source_hsolver/diago_david.h

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
#include "source_base/module_device/device.h" // base_device
66
#include "source_base/module_device/memory_op.h"// base_device::memory
77

8+
// #include "source_base/module_container/ATen/kernels/lapack.h" // container::kernels
9+
810
#include "source_hsolver/diag_comm_info.h"
11+
#include "source_hsolver/kernels/hegvd_op.h"
912

1013
#include <vector>
1114
#include <functional>
@@ -26,16 +29,16 @@ template <typename T = std::complex<double>, typename Device = base_device::DEVI
2629
class DiagoDavid
2730
{
2831
private:
29-
// Note GetTypeReal<T>::type will
30-
// return T if T is real type(float, double),
32+
// Note GetTypeReal<T>::type will
33+
// return T if T is real type(float, double),
3134
// otherwise return the real type of T(complex<float>, std::complex<double>)
3235
using Real = typename GetTypeReal<T>::type;
33-
36+
3437
public:
3538

3639
/**
3740
* @brief Constructor for the DiagoDavid class.
38-
*
41+
*
3942
* @param[in] precondition_in Pointer to the preconditioning matrix.
4043
* @param[in] nband_in Number of eigenpairs required(i.e. bands).
4144
* @param[in] dim_in Dimension of the matrix.
@@ -44,10 +47,10 @@ class DiagoDavid
4447
* the reduced basis set before \b restart of Davidson.
4548
* @param[in] use_paw_in Flag indicating whether to use PAW.
4649
* @param[in] diag_comm_in Communication information for diagonalization.
47-
*
50+
*
4851
* @tparam T The data type of the matrices and arrays.
4952
* @tparam Device The device type (base_device::DEVICE_CPU or DEVICE_GPU).
50-
*
53+
*
5154
* @note Auxiliary memory is allocated in the constructor and deallocated in the destructor.
5255
*/
5356
DiagoDavid(const Real* precondition_in,
@@ -59,10 +62,10 @@ class DiagoDavid
5962

6063
/**
6164
* @brief Destructor for the DiagoDavid class.
62-
*
65+
*
6366
* This destructor releases the dynamically allocated memory used by the class members.
6467
* It deletes the basis, hpsi, spsi, hcc, vcc, lagrange_matrix, and eigenvalue arrays.
65-
*
68+
*
6669
*/
6770
~DiagoDavid();
6871

@@ -75,7 +78,7 @@ class DiagoDavid
7578
* This function type is used to define a matrix-blockvector operator H.
7679
* For eigenvalue problem HX = λX or generalized eigenvalue problem HX = λSX,
7780
* this function computes the product of the Hamiltonian matrix H and a blockvector X.
78-
*
81+
*
7982
* Called as follows:
8083
* hpsi(X, HX, ld, nvec) where X and HX are (ld, nvec)-shaped blockvectors.
8184
* Result HX = H * X is stored in HX.
@@ -84,15 +87,15 @@ class DiagoDavid
8487
* @param[in] HX Head address of output blockvector of type `T*`.
8588
* @param[in] ld Leading dimension of blockvector.
8689
* @param[in] nvec Number of vectors in a block.
87-
*
90+
*
8891
* @warning X and HX are the exact address to read input X and store output H*X,
8992
* @warning both of size ld * nvec.
9093
*/
9194
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
9295

9396
/**
9497
* @brief A function type representing the SX function.
95-
*
98+
*
9699
* nrow is leading dimension of spsi, npw is leading dimension of psi, nbands is number of vecs
97100
*
98101
* This function type is used to define a matrix-blockvector operator S.
@@ -108,9 +111,9 @@ class DiagoDavid
108111

109112
/**
110113
* @brief Performs iterative diagonalization using the David algorithm.
111-
*
114+
*
112115
* @warning Please see docs of `HPsiFunc` for more information about the hpsi mat-vec interface.
113-
*
116+
*
114117
* @tparam T The type of the elements in the matrix.
115118
* @tparam Device The device type (CPU or GPU).
116119
* @param hpsi_func The function object that computes the matrix-blockvector product H * psi.
@@ -123,13 +126,13 @@ class DiagoDavid
123126
* @param ntry_max The maximum number of attempts for the diagonalization restart.
124127
* @param notconv_max The maximum number of bands unconverged allowed.
125128
* @return The total number of iterations performed during the diagonalization.
126-
*
129+
*
127130
* @note ntry_max is an empirical parameter that should be specified in external routine, default 5
128131
* notconv_max is determined by the accuracy required for the calculation, default 0
129132
*/
130133
int diag(
131-
const HPsiFunc& hpsi_func, // function void hpsi(T*, T*, const int, const int)
132-
const SPsiFunc& spsi_func, // function void spsi(T*, T*, const int, const int, const int)
134+
const HPsiFunc& hpsi_func, // function void hpsi(T*, T*, const int, const int)
135+
const SPsiFunc& spsi_func, // function void spsi(T*, T*, const int, const int, const int)
133136
const int ld_psi, // Leading dimension of the psi input
134137
T *psi_in, // Pointer to eigenvectors
135138
Real* eigenvalue_in, // Pointer to store the resulting eigenvalues
@@ -218,7 +221,7 @@ class DiagoDavid
218221

219222
/**
220223
* Calculates the elements of the diagonalization matrix for the DiagoDavid class.
221-
*
224+
*
222225
* @param dim The dimension of the problem.
223226
* @param nbase The current dimension of the reduced basis.
224227
* @param nbase_x The maximum dimension of the reduced basis set.
@@ -237,7 +240,7 @@ class DiagoDavid
237240

238241
/**
239242
* Refreshes the diagonalization solver by updating the basis and the reduced Hamiltonian.
240-
*
243+
*
241244
* @param dim The dimension of the problem.
242245
* @param nband The number of bands.
243246
* @param nbase The number of basis states.
@@ -249,7 +252,7 @@ class DiagoDavid
249252
* @param spsi Pointer to the output array for the updated basis set (nband-th column).
250253
* @param hcc Pointer to the output array for the updated reduced Hamiltonian.
251254
* @param vcc Pointer to the output array for the updated eigenvector matrix.
252-
*
255+
*
253256
*/
254257
void refresh(const int& dim,
255258
const int& nband,
@@ -286,7 +289,7 @@ class DiagoDavid
286289

287290
/**
288291
* @brief Plans the Schmidt orthogonalization for a given number of bands.
289-
*
292+
*
290293
* @tparam T The type of the elements in the vectors.
291294
* @tparam Device The device on which the computation will be performed.
292295
* @param nband The number of bands.

0 commit comments

Comments
 (0)