Skip to content

Commit 1a548f6

Browse files
committed
Refactor: modified Operator::hPsi(), hpsi memory arranged outside
1 parent aa8dde3 commit 1a548f6

File tree

7 files changed

+101
-60
lines changed

7 files changed

+101
-60
lines changed

source/module_esolver/esolver_sdft_pw_tool.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -246,22 +246,18 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
246246
// this->phami->hPsi(j1psi.get_pointer(), j2psi.get_pointer(), ndim*totbands_per*npwx);
247247
// this->phami->hPsi(j1sfpsi.get_pointer(), j2sfpsi.get_pointer(), ndim*totbands_per*npwx);
248248
psi::Range allbands(1,0,0,totbands_per-1);
249-
hamilt::Operator<std::complex<double>>::hpsi_info info_psi0(&psi0, allbands);
250-
const std::complex<double>* hpsi_out = std::get<0>(this->phami->ops->hPsi(info_psi0))->get_pointer();
251-
ModuleBase::GlobalFunc::COPYARRAY(hpsi_out, hpsi0.get_pointer(), totbands_per*npwx);
249+
hamilt::Operator<std::complex<double>>::hpsi_info info_psi0(&psi0, allbands, hpsi0.get_pointer());
250+
this->phami->ops->hPsi(info_psi0);
252251

253-
hamilt::Operator<std::complex<double>>::hpsi_info info_sfpsi0(&sfpsi0, allbands);
254-
const std::complex<double>* hsfpsi_out = std::get<0>(this->phami->ops->hPsi(info_sfpsi0))->get_pointer();
255-
ModuleBase::GlobalFunc::COPYARRAY(hsfpsi_out, hsfpsi0.get_pointer(), totbands_per*npwx);
252+
hamilt::Operator<std::complex<double>>::hpsi_info info_sfpsi0(&sfpsi0, allbands, hsfpsi0.get_pointer());
253+
this->phami->ops->hPsi(info_sfpsi0);
256254

257255
psi::Range allndimbands(1,0,0,ndim*totbands_per-1);
258-
hamilt::Operator<std::complex<double>>::hpsi_info info_j1psi(&j1psi, allndimbands);
259-
const std::complex<double>* hj1psi_out = std::get<0>(this->phami->ops->hPsi(info_j1psi))->get_pointer();
260-
ModuleBase::GlobalFunc::COPYARRAY(hj1psi_out, j2psi.get_pointer(), ndim*totbands_per*npwx);
256+
hamilt::Operator<std::complex<double>>::hpsi_info info_j1psi(&j1psi, allndimbands, j2psi.get_pointer());
257+
this->phami->ops->hPsi(info_j1psi);
261258

262-
hamilt::Operator<std::complex<double>>::hpsi_info info_j1sfpsi(&j1sfpsi, allndimbands);
263-
const std::complex<double>* hj1sfpsi_out = std::get<0>(this->phami->ops->hPsi(info_j1sfpsi))->get_pointer();
264-
ModuleBase::GlobalFunc::COPYARRAY(hj1sfpsi_out, j2sfpsi.get_pointer(), ndim*totbands_per*npwx);
259+
hamilt::Operator<std::complex<double>>::hpsi_info info_j1sfpsi(&j1sfpsi, allndimbands, j2sfpsi.get_pointer());
260+
this->phami->ops->hPsi(info_j1sfpsi);
265261

266262
/*
267263
// stohchi.hchi_norm(psi0.get_pointer(), hpsi0.get_pointer(), totbands_per);

source/module_hamilt/ks_pw/operator_pw.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ class OperatorPW : public Operator<std::complex<double>>
1313

1414
//in PW code, different operators donate hPsi independently
1515
//run this->act function for the first operator and run all act() for other nodes in chain table
16-
virtual hpsi_info hPsi(const hpsi_info& input)const
16+
virtual hpsi_info hPsi(hpsi_info& input)const
1717
{
1818
ModuleBase::timer::tick("OperatorPW", "hPsi");
19-
std::tuple<const std::complex<double>*, int> psi_info = std::get<0>(input)->to_range(std::get<1>(input));
19+
auto psi_input = std::get<0>(input);
20+
std::tuple<const std::complex<double>*, int> psi_info = psi_input->to_range(std::get<1>(input));
2021
int n_npwx = std::get<1>(psi_info);
2122

2223
std::complex<double> *tmhpsi = this->get_hpsi(input);
@@ -27,20 +28,25 @@ class OperatorPW : public Operator<std::complex<double>>
2728
ModuleBase::WARNING_QUIT("OperatorPW", "please choose correct range of psi for hPsi()!");
2829
}
2930

30-
this->act(std::get<0>(input), n_npwx, tmpsi_in, tmhpsi);
31+
this->act(psi_input, n_npwx, tmpsi_in, tmhpsi);
3132
OperatorPW* node((OperatorPW*)this->next_op);
3233
while(node != nullptr)
3334
{
34-
node->act(std::get<0>(input), n_npwx, tmpsi_in, tmhpsi);
35+
node->act(psi_input, n_npwx, tmpsi_in, tmhpsi);
3536
node = (OperatorPW*)(node->next_op);
3637
}
3738

38-
//during recursive call of hPsi, delete the input psi
39-
if(this->recursive) delete std::get<0>(input);
40-
4139
ModuleBase::timer::tick("OperatorPW", "hPsi");
4240

43-
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, n_npwx/std::get<0>(input)->npol));
41+
//if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
42+
std::complex<double>* hpsi_pointer = std::get<2>(input);
43+
if(this->in_place)
44+
{
45+
ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
46+
delete this->hpsi;
47+
this->hpsi = new psi::Psi<std::complex<double>>(hpsi_pointer, *psi_input, 1, n_npwx/psi_input->npol);
48+
}
49+
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, n_npwx/psi_input->npol), hpsi_pointer);
4450
}
4551

4652
//main function which should be modified in Operator for PW base

source/module_hamilt/operator.h

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include<complex>
55
#include "module_psi/psi.h"
66
#include "module_base/global_function.h"
7+
#include "module_base/tool_quit.h"
78

89
namespace hamilt
910
{
@@ -33,8 +34,12 @@ class Operator
3334
//this is the core function for Operator
3435
// do H|psi> from input |psi> ,
3536
// output of hpsi would be first member of the returned tuple
36-
typedef std::tuple<const psi::Psi<T>*, const psi::Range> hpsi_info;
37-
virtual hpsi_info hPsi(const hpsi_info& input)const {return hpsi_info(nullptr, 0);}
37+
typedef std::tuple<const psi::Psi<T>*, const psi::Range, T*> hpsi_info;
38+
virtual hpsi_info hPsi(hpsi_info& input)const
39+
{
40+
ModuleBase::WARNING_QUIT("Operator::hPsi", "hPsi error!");
41+
return hpsi_info(nullptr, 0, nullptr);
42+
}
3843

3944
virtual void init(const int ik_in)
4045
{
@@ -65,7 +70,7 @@ class Operator
6570
protected:
6671
int ik = 0;
6772

68-
mutable bool recursive = false;
73+
mutable bool in_place = false;
6974

7075
//calculation type, only different type can be in main chain table
7176
int cal_type = 0;
@@ -74,30 +79,39 @@ class Operator
7479
//if this Operator is first node in chain table, hpsi would not be empty
7580
mutable psi::Psi<T>* hpsi = nullptr;
7681

82+
/*This function would analyze hpsi_info and choose how to arrange hpsi storage
83+
In hpsi_info, if the third parameter hpsi_pointer is set, which indicates memory of hpsi is arranged by developer;
84+
if hpsi_pointer is not set(nullptr), which indicates memory of hpsi is arranged by Operator, this case is rare.
85+
two cases would occurred:
86+
1. hpsi_pointer != nullptr && psi_pointer == hpsi_pointer , psi would be replaced by hpsi, hpsi need a temporary memory
87+
2. hpsi_pointer != nullptr && psi_pointer != hpsi_pointer , this is the commonly case
88+
*/
7789
T* get_hpsi(const hpsi_info& info)const
7890
{
7991
const int nbands_range = (std::get<1>(info).range_2 - std::get<1>(info).range_1 + 1);
80-
//recursive call of hPsi, hpsi inputs as new psi,
92+
//in_place call of hPsi, hpsi inputs as new psi,
8193
//create a new hpsi and delete old hpsi later
82-
if(this->hpsi != std::get<0>(info) )
94+
T* hpsi_pointer = std::get<2>(info);
95+
const T* psi_pointer = std::get<0>(info)->get_pointer();
96+
if(!hpsi_pointer)
8397
{
84-
this->recursive = false;
85-
if(this->hpsi != nullptr)
86-
{
87-
delete this->hpsi;
88-
}
98+
ModuleBase::WARNING_QUIT("Operator::hPsi", "hpsi_pointer can not be nullptr");
99+
}
100+
else if(hpsi_pointer == psi_pointer)
101+
{
102+
this->in_place = true;
103+
this->hpsi = new psi::Psi<T>(std::get<0>(info)[0], 1, nbands_range);
89104
}
90105
else
91106
{
92-
this->recursive = true;
107+
this->in_place = false;
108+
this->hpsi = new psi::Psi<T>(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range);
93109
}
94-
//create a new hpsi
95-
this->hpsi = new psi::Psi<T>(std::get<0>(info)[0], 1, nbands_range);
96110

97-
T* pointer_hpsi = this->hpsi->get_pointer();
111+
hpsi_pointer = this->hpsi->get_pointer();
98112
size_t total_hpsi_size = nbands_range * this->hpsi->get_nbasis();
99-
ModuleBase::GlobalFunc::ZEROS(pointer_hpsi, total_hpsi_size);
100-
return pointer_hpsi;
113+
ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size);
114+
return hpsi_pointer;
101115
}
102116
};
103117

source/module_hsolver/diago_cg.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,18 @@ void DiagoCG::diag_mock(hamilt::Hamilt *phm_in, psi::Psi<std::complex<double>> &
4848
// Works for generalized eigenvalue problem (US pseudopotentials) as well
4949
//-------------------------------------------------------------------
5050
this->phi_m = new psi::Psi<std::complex<double>>(phi, 1, 1);
51-
this->hphi.resize(this->dim, ModuleBase::ZERO);
52-
this->sphi.resize(this->dim, ModuleBase::ZERO);
51+
this->hphi.resize(this->dmx, ModuleBase::ZERO);
52+
this->sphi.resize(this->dmx, ModuleBase::ZERO);
5353

5454
this->cg = new psi::Psi<std::complex<double>>(phi, 1, 1);
55-
this->scg.resize(this->dim, ModuleBase::ZERO);
56-
this->pphi.resize(this->dim, ModuleBase::ZERO);
55+
this->scg.resize(this->dmx, ModuleBase::ZERO);
56+
this->pphi.resize(this->dmx, ModuleBase::ZERO);
5757

5858
//in band_by_band CG method, only the first band in phi_m would be calculated
5959
psi::Range cg_hpsi_range(0);
6060

61-
this->gradient.resize(this->dim, ModuleBase::ZERO);
62-
this->g0.resize(this->dim, ModuleBase::ZERO);
61+
this->gradient.resize(this->dmx, ModuleBase::ZERO);
62+
this->g0.resize(this->dmx, ModuleBase::ZERO);
6363
this->lagrange.resize(this->n_band, ModuleBase::ZERO);
6464

6565
for (int m = 0; m < this->n_band; m++)
@@ -79,9 +79,8 @@ void DiagoCG::diag_mock(hamilt::Hamilt *phm_in, psi::Psi<std::complex<double>> &
7979

8080
//do hPsi, actually the result of hpsi stored in Operator,
8181
//the necessary of copying operation should be checked later
82-
hp_info cg_hpsi_in(this->phi_m, cg_hpsi_range);
83-
const std::complex<double>* hpsi_out = std::get<0>(phm_in->ops->hPsi(cg_hpsi_in))->get_pointer();
84-
ModuleBase::GlobalFunc::COPYARRAY(hpsi_out, this->hphi.data(), this->dim);
82+
hp_info cg_hpsi_in(this->phi_m, cg_hpsi_range, this->hphi.data());
83+
phm_in->ops->hPsi(cg_hpsi_in);
8584

8685
this->eigenvalue[m] = ModuleBase::GlobalFunc::ddot_real(this->dim, this->phi_m->get_pointer(), this->hphi.data());
8786

@@ -96,9 +95,9 @@ void DiagoCG::diag_mock(hamilt::Hamilt *phm_in, psi::Psi<std::complex<double>> &
9695
this->orthogonal_gradient(phm_in, phi, m);
9796
this->calculate_gamma_cg(iter, gg_last, cg_norm, theta);
9897

99-
hp_info cg_hpsi_in(this->cg, cg_hpsi_range);
100-
const std::complex<double>* cg_hpsi = std::get<0>(phm_in->ops->hPsi(cg_hpsi_in))->get_pointer();
101-
ModuleBase::GlobalFunc::COPYARRAY(cg_hpsi, this->pphi.data(), this->dim);
98+
hp_info cg_hpsi_in(this->cg, cg_hpsi_range, this->pphi.data());
99+
phm_in->ops->hPsi(cg_hpsi_in);
100+
102101
phm_in->sPsi(this->cg->get_pointer(), this->scg.data(), (size_t)this->dim);
103102
converged = this->update_psi(cg_norm, theta, this->eigenvalue[m]);
104103

source/module_hsolver/diago_david.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,8 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
105105
}*/
106106
}
107107
//end of SchmitOrth and calculate H|psi>
108-
hp_info dav_hpsi_in(&basis, psi::Range(1, 0, 0, nband-1));
109-
auto hp_psi = std::get<0>(phm_in->ops->hPsi(dav_hpsi_in));
110-
ModuleBase::GlobalFunc::COPYARRAY(hp_psi->get_pointer(), &hp(0, 0), hp_psi->get_nbasis() * nband);
108+
hp_info dav_hpsi_in(&basis, psi::Range(1, 0, 0, nband-1), &hp(0, 0));
109+
phm_in->ops->hPsi(dav_hpsi_in);
111110

112111
hc.zero_out();
113112
sc.zero_out();
@@ -380,10 +379,10 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
380379
phm_in->sPsi(ppsi, spsi, (size_t)npw);
381380

382381
}
383-
hp_info dav_hpsi_in(&basis, psi::Range(1, 0, nbase, nbase + notconv-1));
384-
auto hp_psi = std::get<0>(phm_in->ops->hPsi(dav_hpsi_in));
385-
ModuleBase::GlobalFunc::COPYARRAY(hp_psi->get_pointer(), &hp(nbase, 0), hp_psi->get_nbasis()*notconv);
386-
382+
//calculate H|psi> for not convergence bands
383+
hp_info dav_hpsi_in(&basis, psi::Range(1, 0, nbase, nbase + notconv-1), &hp(nbase, 0));
384+
phm_in->ops->hPsi(dav_hpsi_in);
385+
387386
ModuleBase::timer::tick("DiagoDavid", "cal_grad");
388387
return;
389388
}

source/module_hsolver/diago_iter_assist.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,14 @@ void DiagoIterAssist::diagH_subspace(hamilt::Hamilt* pHamilt,
5050
//const std::complex<double> *paux = aux;
5151
const std::complex<double> *ppsi = psi.get_pointer();
5252

53+
//allocated hpsi
54+
std::vector<std::complex<double>> hpsi(psi.get_nbands() * psi.get_nbasis());
55+
//do hPsi for all bands
5356
psi::Range all_bands_range(1, psi.get_current_k(), 0, psi.get_nbands()-1);
54-
hamilt::Operator<std::complex<double>>::hpsi_info hpsi_in(&psi, all_bands_range);
55-
const std::complex<double> *aux = std::get<0>(pHamilt->ops->hPsi(hpsi_in))->get_pointer();
57+
hamilt::Operator<std::complex<double>>::hpsi_info hpsi_in(&psi, all_bands_range, hpsi.data());
58+
pHamilt->ops->hPsi(hpsi_in);
59+
//use aux as a data pointer for hpsi
60+
const std::complex<double> *aux = hpsi.data();
5661

5762
char trans1 = 'C';
5863
char trans2 = 'N';
@@ -187,9 +192,14 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* pHamilt,
187192
ModuleBase::GlobalFunc::COPYARRAY(psi.c, psi_temp.get_pointer(), psi_temp.size());
188193
const std::complex<double> *ppsi = psi_temp.get_pointer();
189194

190-
psi::Range all_bands_range(1, 0, 0, nstart-1);
191-
hamilt::Operator<std::complex<double>>::hpsi_info hpsi_in(&psi_temp, all_bands_range);
192-
const std::complex<double> *aux = std::get<0>(pHamilt->ops->hPsi(hpsi_in))->get_pointer();
195+
//allocated hpsi
196+
std::vector<std::complex<double>> hpsi(psi_temp.get_nbands() * psi_temp.get_nbasis());
197+
//do hPsi for all bands
198+
psi::Range all_bands_range(1, psi_temp.get_current_k(), 0, psi_temp.get_nbands()-1);
199+
hamilt::Operator<std::complex<double>>::hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi.data());
200+
pHamilt->ops->hPsi(hpsi_in);
201+
//use aux as a data pointer for hpsi
202+
const std::complex<double> *aux = hpsi.data();
193203

194204
char trans1 = 'C';
195205
char trans2 = 'N';

source/module_psi/psi.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,23 @@ class Psi
9797
else for(size_t index=0; index<this->size();++index) psi[index] = tmp[index];
9898
}
9999
}
100+
101+
//Constructor 5: a wrapper of a data pointer, used for Operator::hPsi()
102+
//in this case, fix_k can not be used
103+
Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in=0)
104+
{
105+
assert(nk_in<=psi_in.get_nk());
106+
if(nband_in == 0)
107+
{
108+
nband_in = psi_in.get_nbands();
109+
}
110+
this->ngk = psi_in.ngk;
111+
this->npol = psi_in.npol;
112+
this->nk = nk_in;
113+
this->nbands = nband_in;
114+
this->nbasis = psi_in.nbasis;
115+
this->psi_current = psi_pointer;
116+
}
100117
// initialize the wavefunction coefficient
101118
// only resize and construct function now is used
102119

0 commit comments

Comments
 (0)