Skip to content

Commit 0342cc3

Browse files
committed
Rename assume_orthogonal to assume_S_orthogonal
1 parent 270b070 commit 0342cc3

File tree

4 files changed

+19
-15
lines changed

4 files changed

+19
-15
lines changed

source/source_hsolver/diago_cg.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -592,18 +592,22 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
592592
ct::Tensor psi_temp = psi.slice({0, 0}, {int(psi.shape().dim_size(0)), int(prec.shape().dim_size(0))});
593593
do
594594
{
595-
if (ntry > 0) // not the first try, then psi is orthogonalized
595+
// subspace diagonalization to get a better starting guess
596+
// for cg diagonalization, restart from current psi approximation
597+
// Note: if not the first try, then psi is already S-orthogonalized by CG iterations!
598+
// Otherwise, if the first try, then psi is not assumed to be S-orthogonalized
599+
if (ntry > 0)
596600
{
597601
ct::TensorMap psi_map = ct::TensorMap(psi.data(), psi_temp);
598-
const bool assume_orthogonal = true;
599-
this->subspace_func_(psi_temp, psi_map, assume_orthogonal);
602+
const bool assume_S_orthogonal = true;
603+
this->subspace_func_(psi_temp, psi_map, assume_S_orthogonal);
600604
psi_temp.sync(psi_map);
601605
}
602606
else if (need_subspace_)
603607
{
604608
ct::TensorMap psi_map = ct::TensorMap(psi.data(), psi_temp);
605-
const bool assume_orthogonal = false;
606-
this->subspace_func_(psi_temp, psi_map, assume_orthogonal);
609+
const bool assume_S_orthogonal = false;
610+
this->subspace_func_(psi_temp, psi_map, assume_S_orthogonal);
607611
psi_temp.sync(psi_map);
608612
}
609613

source/source_hsolver/diago_iter_assist.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
2424
Real* en, // [out] eigenvalues
2525
int n_band, // [in] number of bands to be calculated, also number of rows
2626
// of evc, if set to 0, n_band = nstart, default 0
27-
const bool is_orthogonal // [in] if true, psi is already orthogonalized
27+
const bool is_S_orthogonal // [in] if true, psi is already orthogonalized
2828
)
2929
{
3030
ModuleBase::TITLE("DiagoIterAssist", "diag_subspace");
@@ -47,7 +47,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
4747
setmem_complex_op()(hcc, 0, nstart * nstart);
4848

4949
// scc is overlap matrix, only needed when psi is not orthogonal
50-
if(!is_orthogonal){
50+
if(!is_S_orthogonal){
5151
resmem_complex_op()(scc, nstart * nstart, "DiagSub::scc");
5252
setmem_complex_op()(scc, 0, nstart * nstart);
5353
}
@@ -96,7 +96,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
9696
hcc,
9797
nstart);
9898

99-
if(!is_orthogonal){
99+
if(!is_S_orthogonal){
100100
// Only calculate S_sub if not orthogonal
101101
T *spsi = temp;
102102
// do sPsi for all bands
@@ -121,13 +121,13 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
121121
if (GlobalV::NPROC_IN_POOL > 1)
122122
{
123123
Parallel_Reduce::reduce_pool(hcc, nstart * nstart);
124-
if(!is_orthogonal){
124+
if(!is_S_orthogonal){
125125
Parallel_Reduce::reduce_pool(scc, nstart * nstart);
126126
}
127127
}
128128

129129
// after generation of H and (optionally) S matrix, diag them
130-
if (is_orthogonal) {
130+
if (is_S_orthogonal) {
131131
// Solve standard eigenproblem: H_sub * y = lambda * y
132132
DiagoIterAssist::diagH_LAPACK_standard(nstart, n_band, hcc, nstart, en, vcc);
133133
} else {
@@ -160,7 +160,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
160160
delmem_complex_op()(temp);
161161
}
162162
delmem_complex_op()(hcc);
163-
if(!is_orthogonal){
163+
if(!is_S_orthogonal){
164164
delmem_complex_op()(scc);
165165
}
166166
delmem_complex_op()(vcc);

source/source_hsolver/diago_iter_assist.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ class DiagoIterAssist
4242
* @param evc Output container for computed eigenvectors.
4343
* @param en Output array for computed eigenvalues.
4444
* @param n_band Number of bands (eigenvalues/eigenvectors) to compute. Default is 0 (all).
45-
* @param is_orthogonal If true, assumes the input wavefunction is already orthogonalized.
45+
* @param is_S_orthogonal If true, assumes the input wavefunction is already orthogonalized.
4646
*/
4747
static void diagH_subspace(const hamilt::Hamilt<T, Device>* const pHamilt,
4848
const psi::Psi<T, Device>& psi,
4949
psi::Psi<T, Device>& evc,
5050
Real *en,
5151
int n_band = 0,
52-
const bool is_orthogonal = false);
52+
const bool is_S_orthogonal = false);
5353

5454
/// @brief use LAPACK to diagonalize the Hamiltonian matrix
5555
/// @param pHamilt interface to hamiltonian

source/source_hsolver/hsolver_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
249249
// wrap the subspace_func into a lambda function
250250
// if assume_orthogonal is true, then solve standard eigenproblem
251251
// otherwise, solve generalized eigenproblem
252-
auto subspace_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool assume_orthogonal) {
252+
auto subspace_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool assume_S_orthogonal) {
253253
// psi_in should be a 2D tensor:
254254
// psi_in.shape() = [nbands, nbasis]
255255
const auto ndim = psi_in.shape().ndim();
@@ -269,7 +269,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
269269
ct::DeviceType::CpuDevice,
270270
ct::TensorShape({psi_in.shape().dim_size(0)}));
271271

272-
DiagoIterAssist<T, Device>::diagH_subspace(hm, psi_in_wrapper, psi_out_wrapper, eigen.data<Real>(), assume_orthogonal);
272+
DiagoIterAssist<T, Device>::diagH_subspace(hm, psi_in_wrapper, psi_out_wrapper, eigen.data<Real>(), assume_S_orthogonal);
273273
};
274274
DiagoCG<T, Device> cg(this->basis_type,
275275
this->calculation_type,

0 commit comments

Comments
 (0)