Skip to content

Commit 815614a

Browse files
authored
Merge pull request #1168 from dyzheng/develop
Fix: a error of mem_saver=1 with ks_solver=cg has been fixed
2 parents b05a96b + ced7970 commit 815614a

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

source/module_hamilt/ks_pw/operator_pw.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ class OperatorPW : public Operator
2121

2222
std::complex<double> *tmhpsi = this->get_hpsi(input);
2323
const std::complex<double> *tmpsi_in = std::get<0>(psi_info);
24+
//if range in hpsi_info is illegal, the first return of to_range() would be nullptr
25+
if(tmpsi_in == nullptr)
26+
{
27+
ModuleBase::WARNING_QUIT("OperatorPW", "please choose correct range of psi for hPsi()!");
28+
}
2429

2530
this->act(std::get<0>(input), n_npwx, tmpsi_in, tmhpsi);
2631
OperatorPW* node((OperatorPW*)this->next_op);

source/module_hsolver/diago_cg.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ DiagoCG::~DiagoCG()
2424

2525
void DiagoCG::diag_mock(hamilt::Hamilt *phm_in, psi::Psi<std::complex<double>> &phi, double *eigenvalue_in)
2626
{
27-
if (test_cg == 1)
28-
ModuleBase::TITLE("DiagoCG", "diag_once");
27+
ModuleBase::TITLE("DiagoCG", "diag_once");
2928
ModuleBase::timer::tick("DiagoCG", "diag_once");
3029

3130
/// out : record for states of convergence

source/module_psi/psi.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,15 +236,21 @@ class Psi
236236
// solve Range: return(pointer of begin, number of bands or k-points)
237237
std::tuple<const T*, int> to_range(const Range& range)const
238238
{
239-
if(range.k_first != this->k_first || range.index_1<0 || range.range_1<0 || range.range_2<range.range_1
239+
int index_1_in = range.index_1;
240+
//mem_saver=1 case, only k==0 memory space is avaliable
241+
if(index_1_in>0 & this->nk == 1)
242+
{
243+
index_1_in = 0;
244+
}
245+
if(range.k_first != this->k_first || index_1_in<0 || range.range_1<0 || range.range_2<range.range_1
240246
|| (range.k_first && range.range_2>=this->nbands)
241247
|| (!range.k_first && (range.range_2>=this->nk || range.index_1>=this->nbands) ) )
242248
{
243249
return std::tuple<const T*, int>(nullptr, 0);
244250
}
245251
else
246252
{
247-
const T* p = &this->psi[(range.index_1 * this->nbands + range.range_1) * this->nbasis];
253+
const T* p = &this->psi[(index_1_in * this->nbands + range.range_1) * this->nbasis];
248254
int m = (range.range_2 - range.range_1 + 1)* this->npol;
249255
return std::tuple<const T*, int>(p, m);
250256
}

0 commit comments

Comments
 (0)