Skip to content

Commit ebe8939

Browse files
authored
Merge pull request #1188 from deepmodeling/HSolver
Refactor: modified Operator::hPsi(), hpsi memory arranged outside
2 parents 0752886 + e15654f commit ebe8939

File tree

17 files changed

+159
-99
lines changed

17 files changed

+159
-99
lines changed

source/module_esolver/esolver_sdft_pw_tool.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -246,22 +246,18 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
246246
// this->phami->hPsi(j1psi.get_pointer(), j2psi.get_pointer(), ndim*totbands_per*npwx);
247247
// this->phami->hPsi(j1sfpsi.get_pointer(), j2sfpsi.get_pointer(), ndim*totbands_per*npwx);
248248
psi::Range allbands(1,0,0,totbands_per-1);
249-
hamilt::Operator::hpsi_info info_psi0(&psi0, allbands);
250-
const std::complex<double>* hpsi_out = std::get<0>(this->phami->ops->hPsi(info_psi0))->get_pointer();
251-
ModuleBase::GlobalFunc::COPYARRAY(hpsi_out, hpsi0.get_pointer(), totbands_per*npwx);
249+
hamilt::Operator<std::complex<double>>::hpsi_info info_psi0(&psi0, allbands, hpsi0.get_pointer());
250+
this->phami->ops->hPsi(info_psi0);
252251

253-
hamilt::Operator::hpsi_info info_sfpsi0(&sfpsi0, allbands);
254-
const std::complex<double>* hsfpsi_out = std::get<0>(this->phami->ops->hPsi(info_sfpsi0))->get_pointer();
255-
ModuleBase::GlobalFunc::COPYARRAY(hsfpsi_out, hsfpsi0.get_pointer(), totbands_per*npwx);
252+
hamilt::Operator<std::complex<double>>::hpsi_info info_sfpsi0(&sfpsi0, allbands, hsfpsi0.get_pointer());
253+
this->phami->ops->hPsi(info_sfpsi0);
256254

257255
psi::Range allndimbands(1,0,0,ndim*totbands_per-1);
258-
hamilt::Operator::hpsi_info info_j1psi(&j1psi, allndimbands);
259-
const std::complex<double>* hj1psi_out = std::get<0>(this->phami->ops->hPsi(info_j1psi))->get_pointer();
260-
ModuleBase::GlobalFunc::COPYARRAY(hj1psi_out, j2psi.get_pointer(), ndim*totbands_per*npwx);
256+
hamilt::Operator<std::complex<double>>::hpsi_info info_j1psi(&j1psi, allndimbands, j2psi.get_pointer());
257+
this->phami->ops->hPsi(info_j1psi);
261258

262-
hamilt::Operator::hpsi_info info_j1sfpsi(&j1sfpsi, allndimbands);
263-
const std::complex<double>* hj1sfpsi_out = std::get<0>(this->phami->ops->hPsi(info_j1sfpsi))->get_pointer();
264-
ModuleBase::GlobalFunc::COPYARRAY(hj1sfpsi_out, j2sfpsi.get_pointer(), ndim*totbands_per*npwx);
259+
hamilt::Operator<std::complex<double>>::hpsi_info info_j1sfpsi(&j1sfpsi, allndimbands, j2sfpsi.get_pointer());
260+
this->phami->ops->hPsi(info_j1sfpsi);
265261

266262
/*
267263
// stohchi.hchi_norm(psi0.get_pointer(), hpsi0.get_pointer(), totbands_per);

source/module_hamilt/hamilt.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class Hamilt
3434
int non_first_scf=0;
3535

3636
// first node operator, add operations from each operators
37-
Operator* ops = nullptr;
37+
Operator<std::complex<double>>* ops = nullptr;
3838
};
3939

4040
} // namespace hamilt

source/module_hamilt/hamilt_pw.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ HamiltPW::HamiltPW()
2525

2626
if (GlobalV::T_IN_H)
2727
{
28-
Operator* ekinetic = new Ekinetic<OperatorPW>(
28+
Operator<std::complex<double>>* ekinetic = new Ekinetic<OperatorPW>(
2929
tpiba2,
3030
gk2,
3131
GlobalC::wfcpw->npwk_max
@@ -41,7 +41,7 @@ HamiltPW::HamiltPW()
4141
}
4242
if (GlobalV::VL_IN_H)
4343
{
44-
Operator* veff = new Veff<OperatorPW>(
44+
Operator<std::complex<double>>* veff = new Veff<OperatorPW>(
4545
isk,
4646
&(GlobalC::pot.vr_eff),
4747
GlobalC::wfcpw
@@ -57,7 +57,7 @@ HamiltPW::HamiltPW()
5757
}
5858
if (GlobalV::VNL_IN_H)
5959
{
60-
Operator* nonlocal = new Nonlocal<OperatorPW>(
60+
Operator<std::complex<double>>* nonlocal = new Nonlocal<OperatorPW>(
6161
isk,
6262
&GlobalC::ppcell,
6363
&GlobalC::ucell
@@ -71,7 +71,7 @@ HamiltPW::HamiltPW()
7171
this->ops->add(nonlocal);
7272
}
7373
}
74-
Operator* meta = new Meta<OperatorPW>(
74+
Operator<std::complex<double>>* meta = new Meta<OperatorPW>(
7575
tpiba,
7676
isk,
7777
&GlobalC::pot.vofk,

source/module_hamilt/ks_pw/ekinetic_pw.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
namespace hamilt
77
{
88

9-
template class Ekinetic<OperatorPW>;
10-
11-
template<>
129
Ekinetic<OperatorPW>::Ekinetic(
1310
double tpiba2_in,
1411
const double* gk2_in,
@@ -26,7 +23,6 @@ Ekinetic<OperatorPW>::Ekinetic(
2623
}
2724
}
2825

29-
template<>
3026
void Ekinetic<OperatorPW>::act
3127
(
3228
const psi::Psi<std::complex<double>> *psi_in,

source/module_hamilt/ks_pw/ekinetic_pw.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,16 @@
66
namespace hamilt
77
{
88

9-
template<class T>
10-
class Ekinetic : public T
9+
#ifndef __EKINETICTEMPLATE
10+
#define __EKINETICTEMPLATE
11+
12+
template<class T> class Ekinetic : public T
13+
{};
14+
15+
#endif
16+
17+
template<>
18+
class Ekinetic<OperatorPW> : public OperatorPW
1119
{
1220
public:
1321
Ekinetic(

source/module_hamilt/ks_pw/meta_pw.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010
namespace hamilt
1111
{
1212

13-
template class Meta<OperatorPW>;
14-
15-
template<>
1613
Meta<OperatorPW>::Meta(
1714
double tpiba_in,
1815
const int* isk_in,
@@ -31,7 +28,6 @@ Meta<OperatorPW>::Meta(
3128
}
3229
}
3330

34-
template<>
3531
void Meta<OperatorPW>::act
3632
(
3733
const psi::Psi<std::complex<double>> *psi_in,

source/module_hamilt/ks_pw/meta_pw.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,16 @@
88
namespace hamilt
99
{
1010

11-
template<class T>
12-
class Meta : public T
11+
#ifndef __METATEMPLATE
12+
#define __METATEMPLATE
13+
14+
template<class T> class Meta : public T
15+
{};
16+
17+
#endif
18+
19+
template<>
20+
class Meta<OperatorPW> : public OperatorPW
1321
{
1422
public:
1523
Meta(

source/module_hamilt/ks_pw/nonlocal_pw.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010
namespace hamilt
1111
{
1212

13-
template class Nonlocal<OperatorPW>;
14-
15-
template<>
1613
Nonlocal<OperatorPW>::Nonlocal
1714
(
1815
const int* isk_in,
@@ -30,7 +27,6 @@ Nonlocal<OperatorPW>::Nonlocal
3027
}
3128
}
3229

33-
template<>
3430
void Nonlocal<OperatorPW>::init(const int ik_in)
3531
{
3632
this->ik = ik_in;
@@ -49,7 +45,6 @@ void Nonlocal<OperatorPW>::init(const int ik_in)
4945
//--------------------------------------------------------------------------
5046
// this function sum up each non-local pseudopotential located on each atom,
5147
//--------------------------------------------------------------------------
52-
template<>
5348
void Nonlocal<OperatorPW>::add_nonlocal_pp(std::complex<double> *hpsi_in, const std::complex<double> *becp, const int m) const
5449
{
5550
ModuleBase::timer::tick("Nonlocal", "add_nonlocal_pp");
@@ -170,7 +165,6 @@ void Nonlocal<OperatorPW>::add_nonlocal_pp(std::complex<double> *hpsi_in, const
170165
return;
171166
}
172167

173-
template<>
174168
void Nonlocal<OperatorPW>::act
175169
(
176170
const psi::Psi<std::complex<double>> *psi_in,

source/module_hamilt/ks_pw/nonlocal_pw.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,16 @@
1010
namespace hamilt
1111
{
1212

13-
template<class T>
14-
class Nonlocal : public T
13+
#ifndef __NONLOCALTEMPLATE
14+
#define __NONLOCALTEMPLATE
15+
16+
template<class T> class Nonlocal : public T
17+
{};
18+
19+
#endif
20+
21+
template<>
22+
class Nonlocal<OperatorPW> : public OperatorPW
1523
{
1624
public:
1725
Nonlocal(

source/module_hamilt/ks_pw/operator_pw.h

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@
66
namespace hamilt
77
{
88

9-
class OperatorPW : public Operator
9+
class OperatorPW : public Operator<std::complex<double>>
1010
{
1111
public:
1212
virtual ~OperatorPW(){};
1313

1414
//in PW code, different operators donate hPsi independently
1515
//run this->act function for the first operator and run all act() for other nodes in chain table
16-
virtual hpsi_info hPsi(const hpsi_info& input)const
16+
virtual hpsi_info hPsi(hpsi_info& input)const
1717
{
1818
ModuleBase::timer::tick("OperatorPW", "hPsi");
19-
std::tuple<const std::complex<double>*, int> psi_info = std::get<0>(input)->to_range(std::get<1>(input));
19+
auto psi_input = std::get<0>(input);
20+
std::tuple<const std::complex<double>*, int> psi_info = psi_input->to_range(std::get<1>(input));
2021
int n_npwx = std::get<1>(psi_info);
2122

2223
std::complex<double> *tmhpsi = this->get_hpsi(input);
@@ -27,20 +28,25 @@ class OperatorPW : public Operator
2728
ModuleBase::WARNING_QUIT("OperatorPW", "please choose correct range of psi for hPsi()!");
2829
}
2930

30-
this->act(std::get<0>(input), n_npwx, tmpsi_in, tmhpsi);
31+
this->act(psi_input, n_npwx, tmpsi_in, tmhpsi);
3132
OperatorPW* node((OperatorPW*)this->next_op);
3233
while(node != nullptr)
3334
{
34-
node->act(std::get<0>(input), n_npwx, tmpsi_in, tmhpsi);
35+
node->act(psi_input, n_npwx, tmpsi_in, tmhpsi);
3536
node = (OperatorPW*)(node->next_op);
3637
}
3738

38-
//during recursive call of hPsi, delete the input psi
39-
if(this->recursive) delete std::get<0>(input);
40-
4139
ModuleBase::timer::tick("OperatorPW", "hPsi");
4240

43-
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, n_npwx/std::get<0>(input)->npol));
41+
//if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
42+
std::complex<double>* hpsi_pointer = std::get<2>(input);
43+
if(this->in_place)
44+
{
45+
ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
46+
delete this->hpsi;
47+
this->hpsi = new psi::Psi<std::complex<double>>(hpsi_pointer, *psi_input, 1, n_npwx/psi_input->npol);
48+
}
49+
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, n_npwx/psi_input->npol), hpsi_pointer);
4450
}
4551

4652
//main function which should be modified in Operator for PW base

0 commit comments

Comments
 (0)