Skip to content

Commit c290f04

Browse files
authored
Perf: Optimize cal_psir_ylm (#1875)
* add psir timer * Perf: optimize cal_psir_ylm, less instruction in core loop * remove comments
1 parent 50faf36 commit c290f04

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

source/module_gint/gint_tools.cpp

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "module_hamilt_pw/hamilt_pwdft/global.h"
66
#include "module_hamilt_lcao/hamilt_lcaodft/global_fp.h"
77
#include "module_base/ylm.h"
8+
#include "module_base/timer.h"
89
#include <cmath>
910

1011
namespace Gint_Tools
@@ -159,6 +160,7 @@ namespace Gint_Tools
159160
const bool*const*const cal_flag,
160161
double*const*const psir_ylm) // cal_flag[GlobalC::bigpw->bxyz][na_grid], whether the atom-grid distance is larger than cutoff
161162
{
163+
ModuleBase::timer::tick("Gint_Tools", "cal_psir_ylm");
162164
std::vector<double> ylma;
163165
for (int id=0; id<na_grid; id++)
164166
{
@@ -170,6 +172,19 @@ namespace Gint_Tools
170172
const int iat=GlobalC::GridT.which_atom[mcell_index]; // index of atom
171173
const int it=GlobalC::ucell.iat2it[iat]; // index of atom type
172174
const Atom*const atom=&GlobalC::ucell.atoms[it];
175+
auto &OrbPhi = GlobalC::ORB.Phi[it];
176+
std::vector<const double*> it_psi_uniform(atom->nw);
177+
std::vector<const double*> it_dpsi_uniform(atom->nw);
178+
// preprocess index
179+
for (int iw=0; iw< atom->nw; ++iw)
180+
{
181+
if ( atom->iw2_new[iw] )
182+
{
183+
auto philn = &OrbPhi.PhiLN(atom->iw2l[iw], atom->iw2n[iw]);
184+
it_psi_uniform[iw] = &philn->psi_uniform[0];
185+
it_dpsi_uniform[iw] = &philn->dpsi_uniform[0];
186+
}
187+
}
173188

174189
// meshball_positions should be the bigcell position in meshball
175190
// to the center of meshball.
@@ -227,21 +242,21 @@ namespace Gint_Tools
227242
const double c4 = (dx3-dx2)*delta_r;
228243

229244
double phi=0;
230-
for (int iw=0; iw< atom->nw; ++iw, ++p)
245+
for (int iw=0; iw< atom->nw; ++iw)
231246
{
232247
if ( atom->iw2_new[iw] )
233248
{
234-
const Numerical_Orbital_Lm &philn = GlobalC::ORB.Phi[it].PhiLN(
235-
atom->iw2l[iw],
236-
atom->iw2n[iw]);
237-
phi = c1*philn.psi_uniform[ip] + c2*philn.dpsi_uniform[ip] // radial wave functions
238-
+ c3*philn.psi_uniform[ip+1] + c4*philn.dpsi_uniform[ip+1];
249+
auto psi_uniform = it_psi_uniform[iw];
250+
auto dpsi_uniform = it_dpsi_uniform[iw];
251+
phi = c1*psi_uniform[ip] + c2*dpsi_uniform[ip] // radial wave functions
252+
+ c3*psi_uniform[ip+1] + c4*dpsi_uniform[ip+1];
239253
}
240-
*p=phi * ylma[atom->iw2_ylm[iw]];
254+
p[iw]=phi * ylma[atom->iw2_ylm[iw]];
241255
} // end iw
242256
}// end distance<=(GlobalC::ORB.Phi[it].getRcut()-1.0e-15)
243257
}// end ib
244258
}// end id
259+
ModuleBase::timer::tick("Gint_Tools", "cal_psir_ylm");
245260
return;
246261
}
247262

@@ -257,6 +272,7 @@ namespace Gint_Tools
257272
double*const*const dpsir_ylm_y,
258273
double*const*const dpsir_ylm_z)
259274
{
275+
ModuleBase::timer::tick("Gint_Tools", "cal_dpsir_ylm");
260276
for (int id=0; id<na_grid; id++)
261277
{
262278
const int mcell_index = GlobalC::GridT.bcell_start[grid_index] + id;
@@ -362,7 +378,7 @@ namespace Gint_Tools
362378
}//else
363379
}
364380
}
365-
381+
ModuleBase::timer::tick("Gint_Tools", "cal_dpsir_ylm");
366382
return;
367383
}
368384

@@ -380,6 +396,7 @@ namespace Gint_Tools
380396
double*const*const ddpsir_ylm_yz,
381397
double*const*const ddpsir_ylm_zz)
382398
{
399+
ModuleBase::timer::tick("Gint_Tools", "cal_ddpsir_ylm");
383400
for (int id=0; id<na_grid; id++)
384401
{
385402
const int mcell_index = GlobalC::GridT.bcell_start[grid_index] + id;
@@ -651,7 +668,7 @@ namespace Gint_Tools
651668
}//else
652669
}//end ib
653670
}//end id(atom)
654-
671+
ModuleBase::timer::tick("Gint_Tools", "cal_ddpsir_ylm");
655672
return;
656673
}
657674

@@ -671,6 +688,7 @@ namespace Gint_Tools
671688
double*const*const dpsir_ylm_yz,
672689
double*const*const dpsir_ylm_zz)
673690
{
691+
ModuleBase::timer::tick("Gint_Tools", "cal_dpsirr_ylm");
674692
for (int id=0; id<na_grid; id++)
675693
{
676694
const int mcell_index = GlobalC::GridT.bcell_start[grid_index] + id;
@@ -725,7 +743,7 @@ namespace Gint_Tools
725743
}//else
726744
}
727745
}
728-
746+
ModuleBase::timer::tick("Gint_Tools", "cal_dpsirr_ylm");
729747
return;
730748
}
731749

0 commit comments

Comments
 (0)