Skip to content

Commit cde47d3

Browse files
committed
code optimization
1 parent a252d05 commit cde47d3

File tree

4 files changed

+111
-46
lines changed

4 files changed

+111
-46
lines changed

source/source_esolver/esolver_of_tddft.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,7 @@ void ESolver_OF_TDDFT::runner(UnitCell& ucell, const int istep)
4747

4848
if (istep==0)
4949
{
50-
this->pphi_td.resize(PARAM.inp.nspin);
51-
52-
for (int is = 0; is < PARAM.inp.nspin; ++is)
53-
{
54-
this->pphi_td[is].resize(this->pw_rho->nrxx);
55-
}
50+
this->phi_td.resize(PARAM.inp.nspin*this->pw_rho->nrxx);
5651
}
5752

5853
if ((istep<1) && PARAM.inp.init_chg != "file")
@@ -85,22 +80,28 @@ void ESolver_OF_TDDFT::runner(UnitCell& ucell, const int istep)
8580
ESolver_FP::iter_finish(ucell, istep, this->iter_, conv_esolver);
8681
}
8782

83+
#ifdef _OPENMP
84+
#pragma omp parallel for collapse(2)
85+
#endif
8886
for (int is = 0; is < PARAM.inp.nspin; ++is)
8987
{
9088
for (int ir = 0; ir < this->pw_rho->nrxx; ++ir)
9189
{
92-
pphi_td[is][ir]=pphi_[is][ir];
90+
phi_td[is*this->pw_rho->nrxx+ir]=pphi_[is][ir];
9391
}
9492
}
9593
}
9694
else
9795
{
98-
this->evolve_ofdft->propagate_psi(this->pelec, this->chr, ucell, this->pphi_td, this->pw_rho);
96+
this->evolve_ofdft->propagate_psi(this->pelec, this->chr, ucell, this->phi_td, this->pw_rho);
97+
#ifdef _OPENMP
98+
#pragma omp parallel for collapse(2)
99+
#endif
99100
for (int is = 0; is < PARAM.inp.nspin; ++is)
100101
{
101102
for (int ir = 0; ir < this->pw_rho->nrxx; ++ir)
102103
{
103-
pphi_[is][ir]=std::abs(pphi_td[is][ir]);
104+
pphi_[is][ir]=std::abs(phi_td[is*this->pw_rho->nrxx+ir]);
104105
}
105106
}
106107
conv_esolver=true;

source/source_esolver/esolver_of_tddft.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class ESolver_OF_TDDFT : public ESolver_OF
1515
virtual void runner(UnitCell& ucell, const int istep) override;
1616

1717
protected:
18-
std::vector<std::vector<std::complex<double>>> pphi_td; // pphi[i] = ppsi.get_pointer(i), which will be freed in ~Psi().
18+
std::vector<std::complex<double>> phi_td; // complex phi stored as one flat array of size nspin * nrxx, indexed as phi_td[is * nrxx + ir]
1919
Evolve_OFDFT* evolve_ofdft=nullptr;
2020
};
2121
} // namespace ModuleESolver

source/source_pw/module_ofdft/evolve_ofdft.cpp

Lines changed: 81 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,41 +5,59 @@
55

66
#include "source_base/parallel_reduce.h"
77

8-
void Evolve_OFDFT::get_Hpsi(elecstate::ElecState* pelec, const Charge& chr, UnitCell& ucell, std::vector<std::vector<std::complex<double>>> psi_, ModulePW::PW_Basis* pw_rho, std::vector<std::vector<std::complex<double>>> Hpsi)
8+
void Evolve_OFDFT::cal_Hpsi(elecstate::ElecState* pelec,
9+
const Charge& chr,
10+
UnitCell& ucell,
11+
std::vector<std::complex<double>> psi_,
12+
ModulePW::PW_Basis* pw_rho,
13+
std::vector<std::complex<double>> Hpsi)
914
{
1015
// update rho
16+
#ifdef _OPENMP
17+
#pragma omp parallel for collapse(2)
18+
#endif
1119
for (int is = 0; is < PARAM.inp.nspin; ++is)
1220
{
1321
for (int ir = 0; ir < pw_rho->nrxx; ++ir)
1422
{
15-
chr.rho[is][ir] = abs(psi_[is][ir])*abs(psi_[is][ir]);
23+
chr.rho[is][ir] = abs(psi_[is * pw_rho->nrxx + ir])*abs(psi_[is * pw_rho->nrxx + ir]);
1624
}
1725
}
1826

1927
pelec->pot->update_from_charge(&chr, &ucell); // Hartree + XC + external
20-
this->get_tf_potential(chr.rho,pw_rho ,pelec->pot->get_effective_v()); // TF potential
28+
this->cal_tf_potential(chr.rho,pw_rho ,pelec->pot->get_effective_v()); // TF potential
29+
30+
#ifdef _OPENMP
31+
#pragma omp parallel for
32+
#endif
2133
for (int is = 0; is < PARAM.inp.nspin; ++is)
2234
{
2335
const double* vr_eff = pelec->pot->get_effective_v(is);
2436
for (int ir = 0; ir < pw_rho->nrxx; ++ir)
2537
{
26-
Hpsi[is][ir] = vr_eff[ir]*psi_[is][ir];
38+
Hpsi[is * pw_rho->nrxx + ir] = vr_eff[ir]*psi_[is * pw_rho->nrxx + ir];
2739
}
2840
}
29-
this->get_vw_potential_phi(psi_, pw_rho, Hpsi);
41+
this->cal_vw_potential_phi(psi_, pw_rho, Hpsi);
3042
}
3143

32-
void Evolve_OFDFT::get_tf_potential(const double* const* prho, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpot)
44+
void Evolve_OFDFT::cal_tf_potential(const double* const* prho, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpot)
3345
{
3446
if (PARAM.inp.nspin == 1)
3547
{
48+
#ifdef _OPENMP
49+
#pragma omp parallel for
50+
#endif
3651
for (int ir = 0; ir < pw_rho->nrxx; ++ir)
3752
{
3853
rpot(0, ir) += 5.0 / 3.0 * this->c_tf_ * std::pow(prho[0][ir], 2. / 3.);
3954
}
4055
}
4156
else if (PARAM.inp.nspin == 2)
4257
{
58+
#ifdef _OPENMP
59+
#pragma omp parallel for collapse(2)
60+
#endif
4361
for (int is = 0; is < PARAM.inp.nspin; ++is)
4462
{
4563
for (int ir = 0; ir < pw_rho->nrxx; ++ir)
@@ -50,17 +68,26 @@ void Evolve_OFDFT::get_tf_potential(const double* const* prho, ModulePW::PW_Basi
5068
}
5169
}
5270

53-
void Evolve_OFDFT::get_vw_potential_phi(std::vector<std::vector<std::complex<double>>> pphi, ModulePW::PW_Basis* pw_rho, std::vector<std::vector<std::complex<double>>> Hpsi)
71+
void Evolve_OFDFT::cal_vw_potential_phi(std::vector<std::complex<double>> pphi,
72+
ModulePW::PW_Basis* pw_rho,
73+
std::vector<std::complex<double>> Hpsi)
5474
{
5575
std::complex<double>** rLapPhi = new std::complex<double>*[PARAM.inp.nspin];
76+
#ifdef _OPENMP
77+
#pragma omp parallel for
78+
#endif
5679
for (int is = 0; is < PARAM.inp.nspin; ++is) {
5780
rLapPhi[is] = new std::complex<double>[pw_rho->nrxx];
5881
for (int ir = 0; ir < pw_rho->nrxx; ++ir)
5982
{
60-
rLapPhi[is][ir]=pphi[is][ir];
83+
rLapPhi[is][ir]=pphi[is * pw_rho->nrxx + ir];
6184
}
6285
}
6386
std::complex<double>** recipPhi = new std::complex<double>*[PARAM.inp.nspin];
87+
88+
#ifdef _OPENMP
89+
#pragma omp parallel for
90+
#endif
6491
for (int is = 0; is < PARAM.inp.nspin; ++is)
6592
{
6693
recipPhi[is] = new std::complex<double>[pw_rho->npw];
@@ -71,12 +98,15 @@ void Evolve_OFDFT::get_vw_potential_phi(std::vector<std::vector<std::complex<dou
7198
recipPhi[is][ik] *= pw_rho->gg[ik] * pw_rho->tpiba2;
7299
}
73100
pw_rho->recip2real(recipPhi[is], rLapPhi[is]);
74-
for (int ik = 0; ik < pw_rho->npw; ++ik)
101+
for (int ir = 0; ir < pw_rho->nrxx; ++ir)
75102
{
76-
Hpsi[is][ik]+=rLapPhi[is][ik];
103+
Hpsi[is * pw_rho->nrxx + ir]+=rLapPhi[is][ir];
77104
}
78105
}
79106

107+
#ifdef _OPENMP
108+
#pragma omp parallel for
109+
#endif
80110
for (int is = 0; is < PARAM.inp.nspin; ++is)
81111
{
82112
delete[] recipPhi[is];
@@ -86,7 +116,9 @@ void Evolve_OFDFT::get_vw_potential_phi(std::vector<std::vector<std::complex<dou
86116
delete[] rLapPhi;
87117
}
88118

89-
void Evolve_OFDFT::get_CD_potential(std::vector<std::vector<std::complex<double>>> psi_, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpot)
119+
void Evolve_OFDFT::cal_CD_potential(std::vector<std::complex<double>> psi_,
120+
ModulePW::PW_Basis* pw_rho,
121+
ModuleBase::matrix& rpot)
90122
{
91123
for (int is = 0; is < PARAM.inp.nspin; ++is)
92124
{
@@ -95,50 +127,68 @@ void Evolve_OFDFT::get_CD_potential(std::vector<std::vector<std::complex<double>
95127
}
96128
}
97129

98-
void Evolve_OFDFT::propagate_psi(elecstate::ElecState* pelec, const Charge& chr, UnitCell& ucell, std::vector<std::vector<std::complex<double>>> pphi_, ModulePW::PW_Basis* pw_rho)
130+
void Evolve_OFDFT::propagate_psi(elecstate::ElecState* pelec,
131+
const Charge& chr, UnitCell& ucell,
132+
std::vector<std::complex<double>> pphi_,
133+
ModulePW::PW_Basis* pw_rho)
99134
{
100135
ModuleBase::timer::tick("ESolver_OF_TDDFT", "propagte_psi");
101136

102137
std::complex<double> imag(0.0,1.0);
103138
double dt=PARAM.inp.mdp.md_dt;
104-
std::vector<std::vector<std::complex<double>>> K1(PARAM.inp.nspin,std::vector<std::complex<double>>(pw_rho->nrxx));
105-
std::vector<std::vector<std::complex<double>>> K2(PARAM.inp.nspin,std::vector<std::complex<double>>(pw_rho->nrxx));
106-
std::vector<std::vector<std::complex<double>>> K3(PARAM.inp.nspin,std::vector<std::complex<double>>(pw_rho->nrxx));
107-
std::vector<std::vector<std::complex<double>>> K4(PARAM.inp.nspin,std::vector<std::complex<double>>(pw_rho->nrxx));
108-
std::vector<std::vector<std::complex<double>>> psi1(PARAM.inp.nspin,std::vector<std::complex<double>>(pw_rho->nrxx));
109-
std::vector<std::vector<std::complex<double>>> psi2(PARAM.inp.nspin,std::vector<std::complex<double>>(pw_rho->nrxx));
110-
std::vector<std::vector<std::complex<double>>> psi3(PARAM.inp.nspin,std::vector<std::complex<double>>(pw_rho->nrxx));
139+
const int nspin = PARAM.inp.nspin;
140+
const int nrxx = pw_rho->nrxx;
141+
const int total_size = nspin * nrxx;
142+
std::vector<std::complex<double>> K1(total_size);
143+
std::vector<std::complex<double>> K2(total_size);
144+
std::vector<std::complex<double>> K3(total_size);
145+
std::vector<std::complex<double>> K4(total_size);
146+
std::vector<std::complex<double>> psi1(total_size);
147+
std::vector<std::complex<double>> psi2(total_size);
148+
std::vector<std::complex<double>> psi3(total_size);
111149

112-
get_Hpsi(pelec,chr,ucell,pphi_,pw_rho,K1);
150+
cal_Hpsi(pelec,chr,ucell,pphi_,pw_rho,K1);
151+
#ifdef _OPENMP
152+
#pragma omp parallel for collapse(2)
153+
#endif
113154
for (int is = 0; is < PARAM.inp.nspin; ++is){
114155
for (int ir = 0; ir < pw_rho->nrxx; ++ir)
115156
{
116-
K1[is][ir]=-1.0*K1[is][ir]*dt*imag;
117-
psi1[is][ir]=pphi_[is][ir]+0.5*K1[is][ir];
157+
K1[is * nrxx + ir]=-1.0*K1[is * nrxx + ir]*dt*imag;
158+
psi1[is * nrxx + ir]=pphi_[is * nrxx + ir]+0.5*K1[is * nrxx + ir];
118159
}
119160
}
120-
get_Hpsi(pelec,chr,ucell,psi1,pw_rho,K2);
161+
cal_Hpsi(pelec,chr,ucell,psi1,pw_rho,K2);
162+
#ifdef _OPENMP
163+
#pragma omp parallel for collapse(2)
164+
#endif
121165
for (int is = 0; is < PARAM.inp.nspin; ++is){
122166
for (int ir = 0; ir < pw_rho->nrxx; ++ir)
123167
{
124-
K2[is][ir]=-1.0*K2[is][ir]*dt*imag;
125-
psi2[is][ir]=pphi_[is][ir]+0.5*K2[is][ir];
168+
K2[is * nrxx + ir]=-1.0*K2[is * nrxx + ir]*dt*imag;
169+
psi2[is * nrxx + ir]=pphi_[is * nrxx + ir]+0.5*K2[is * nrxx + ir];
126170
}
127171
}
128-
get_Hpsi(pelec,chr,ucell,psi2,pw_rho,K3);
172+
cal_Hpsi(pelec,chr,ucell,psi2,pw_rho,K3);
173+
#ifdef _OPENMP
174+
#pragma omp parallel for collapse(2)
175+
#endif
129176
for (int is = 0; is < PARAM.inp.nspin; ++is){
130177
for (int ir = 0; ir < pw_rho->nrxx; ++ir)
131178
{
132-
K3[is][ir]=-1.0*K3[is][ir]*dt*imag;
133-
psi3[is][ir]=pphi_[is][ir]+K3[is][ir];
179+
K3[is * nrxx + ir]=-1.0*K3[is * nrxx + ir]*dt*imag;
180+
psi3[is * nrxx + ir]=pphi_[is * nrxx + ir]+K3[is * nrxx + ir];
134181
}
135182
}
136-
get_Hpsi(pelec,chr,ucell,psi3,pw_rho,K4);
183+
cal_Hpsi(pelec,chr,ucell,psi3,pw_rho,K4);
184+
#ifdef _OPENMP
185+
#pragma omp parallel for collapse(2)
186+
#endif
137187
for (int is = 0; is < PARAM.inp.nspin; ++is){
138188
for (int ir = 0; ir < pw_rho->nrxx; ++ir)
139189
{
140-
K4[is][ir]=-1.0*K4[is][ir]*dt*imag;
141-
pphi_[is][ir]+=1.0/6.0*(K1[is][ir]+2.0*K2[is][ir]+2.0*K3[is][ir]+K4[is][ir]);
190+
K4[is * nrxx + ir]=-1.0*K4[is * nrxx + ir]*dt*imag;
191+
pphi_[is * nrxx + ir]+=1.0/6.0*(K1[is * nrxx + ir]+2.0*K2[is * nrxx + ir]+2.0*K3[is * nrxx + ir]+K4[is * nrxx + ir]);
142192
}
143193
}
144194

source/source_pw/module_ofdft/evolve_ofdft.h

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,31 @@ class Evolve_OFDFT
2424
~Evolve_OFDFT()
2525
{
2626
}
27-
void propagate_psi(elecstate::ElecState* pelec, const Charge& chr, UnitCell& ucell, std::vector<std::vector<std::complex<double>>> pphi_, ModulePW::PW_Basis* pw_rho);
27+
void propagate_psi(elecstate::ElecState* pelec,
28+
const Charge& chr, UnitCell& ucell,
29+
std::vector<std::complex<double>> pphi_,
30+
ModulePW::PW_Basis* pw_rho);
2831

2932
private:
3033
const double c_tf_
3134
= 3.0 / 10.0 * std::pow(3 * std::pow(M_PI, 2.0), 2.0 / 3.0)
3235
* 2; // 3/10*(3*pi^2)^{2/3}, multiplied by 2 to convert the unit from Hartree to Ry, finally in Ry*Bohr^(-2)
3336

34-
void get_Hpsi(elecstate::ElecState* pelec, const Charge& chr, UnitCell& ucell, std::vector<std::vector<std::complex<double>>> psi_, ModulePW::PW_Basis* pw_rho, std::vector<std::vector<std::complex<double>>> Hpsi);
35-
void get_tf_potential(const double* const* prho, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpot);
36-
void get_vw_potential_phi(std::vector<std::vector<std::complex<double>>> pphi, ModulePW::PW_Basis* pw_rho, std::vector<std::vector<std::complex<double>>> Hpsi); // -1/2 \nabla^2 \phi
37-
void get_CD_potential(std::vector<std::vector<std::complex<double>>> psi_, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpot);
37+
void cal_Hpsi(elecstate::ElecState* pelec,
38+
const Charge& chr,
39+
UnitCell& ucell,
40+
std::vector<std::complex<double>> psi_,
41+
ModulePW::PW_Basis* pw_rho,
42+
std::vector<std::complex<double>> Hpsi);
43+
void cal_tf_potential(const double* const* prho,
44+
ModulePW::PW_Basis* pw_rho,
45+
ModuleBase::matrix& rpot);
46+
void cal_vw_potential_phi(std::vector<std::complex<double>> pphi,
47+
ModulePW::PW_Basis* pw_rho,
48+
std::vector<std::complex<double>> Hpsi); // -1/2 \nabla^2 \phi
49+
void cal_CD_potential(std::vector<std::complex<double>> psi_,
50+
ModulePW::PW_Basis* pw_rho,
51+
ModuleBase::matrix& rpot);
3852

3953
};
4054
#endif

0 commit comments

Comments
 (0)