Skip to content

Commit 889c499

Browse files
authored
Perf(LCAO): Various optimizations of detail code (#1901)
* Perf: optimize folding_vl_k * Perf: use libm in cal_dm_k * Perf: optimize folding_fixedH. Optimize find_offset with binary search * Fix: fix mem leak
1 parent 49d543d commit 889c499

File tree

6 files changed

+252
-172
lines changed

6 files changed

+252
-172
lines changed

source/module_gint/gint_k_pvpr.cpp

Lines changed: 120 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "module_base/memory.h"
1111
#include "module_base/timer.h"
1212
#include "module_base/tool_threading.h"
13+
#include "module_base/libm/libm.h"
1314

1415
void Gint_k::allocate_pvpR(void)
1516
{
@@ -174,48 +175,43 @@ void Gint_k::folding_vl_k(const int &ik, LCAO_Matrix *LM)
174175

175176
// calculate the phase factor exp(ikR).
176177
const double arg = (GlobalC::kv.kvec_d[ ik ] * dR) * ModuleBase::TWO_PI;
177-
std::complex<double> phase = std::complex<double>(cos(arg), sin(arg));
178+
double sinp, cosp;
179+
ModuleBase::libm::sincos(arg, &sinp, &cosp);
180+
std::complex<double> phase = std::complex<double>(cosp, sinp);
178181
int ixxx = DM_start + GlobalC::GridT.find_R2st[iat][nad];
179-
for(int iw=0; iw<atom1->nw; iw++)
182+
183+
if(GlobalV::NSPIN!=4)
180184
{
181-
// iw1_lo
182-
if(GlobalV::NSPIN!=4)
185+
for(int iw=0; iw<atom1->nw; iw++)
183186
{
184187
std::complex<double> *vij = pvp[GlobalC::GridT.trace_lo[start1+iw]];
185-
186188
int* iw2_lo = &GlobalC::GridT.trace_lo[start2];
187-
int* iw2_end = iw2_lo + atom2->nw;
188-
189189
// get the <phi | V | phi>(R) Hamiltonian.
190190
double *vijR = &pvpR_reduced[0][ixxx];
191-
for(; iw2_lo<iw2_end; ++iw2_lo, ++vijR)
191+
for(int iw2 = 0; iw2<atom2->nw; ++iw2)
192192
{
193-
vij[iw2_lo[0]] += vijR[0] * phase;
193+
vij[iw2_lo[iw2]] += vijR[iw2] * phase;
194194
}
195+
ixxx += atom2->nw;
195196
}
196-
else
197+
}
198+
else
199+
{
200+
for(int iw=0; iw<atom1->nw; iw++)
197201
{
198-
std::complex<double> *vij[4];
199-
for(int spin=0;spin<4;spin++)
200-
vij[spin] = pvp_nc[spin][GlobalC::GridT.trace_lo[start1]/GlobalV::NPOL + iw];
201-
202202
int iw2_lo = GlobalC::GridT.trace_lo[start2]/GlobalV::NPOL;
203-
int iw2_end = iw2_lo + atom2->nw;
204-
205-
double *vijR[4];
206203
for(int spin = 0;spin<4;spin++)
207204
{
208-
vijR[spin] = &pvpR_reduced[spin][ixxx];
209-
}
210-
for(; iw2_lo<iw2_end; ++iw2_lo, ++vijR[0], ++vijR[1],++vijR[2],++vijR[3])
211-
{
212-
for(int spin =0;spin<4;spin++)
205+
auto vij = pvp_nc[spin][GlobalC::GridT.trace_lo[start1]/GlobalV::NPOL + iw];
206+
auto vijR = &pvpR_reduced[spin][ixxx];
207+
auto vijs = &vij[iw2_lo];
208+
for(int iw2 = 0; iw2<atom2->nw; ++iw2)
213209
{
214-
vij[spin][iw2_lo] += vijR[spin][0] * phase;
210+
vijs[iw2] += vijR[iw2] * phase;
215211
}
216-
}
212+
}
213+
ixxx += atom2->nw;
217214
}
218-
ixxx += atom2->nw;
219215
}
220216
++nad;
221217
}// end distane<rcut
@@ -231,112 +227,144 @@ void Gint_k::folding_vl_k(const int &ik, LCAO_Matrix *LM)
231227
// Distribution of data.
232228
ModuleBase::timer::tick("Gint_k","Distri");
233229
std::complex<double>* tmp = new std::complex<double>[GlobalV::NLOCAL];
230+
const double sign_table[2] = {1.0, -1.0};
234231
#ifdef _OPENMP
235232
#pragma omp parallel
236233
{
237234
#endif
238235
for (int i=0; i<GlobalV::NLOCAL; i++)
239236
{
240-
#ifdef _OPENMP
241-
#pragma omp for schedule(static, 256)
242-
#endif
243-
for (int j=0; j<GlobalV::NLOCAL; j++)
244-
{
245-
tmp[j] = 0;
246-
}
237+
int i_flag = i & 1; // i % 2 == 0
247238
const int mug = GlobalC::GridT.trace_lo[i];
248239
const int mug0 = mug/GlobalV::NPOL;
249240
// if the row element is on this processor.
250241
if (mug >= 0)
251242
{
243+
if(GlobalV::NSPIN!=4)
244+
{
252245
#ifdef _OPENMP
253-
#pragma omp for schedule(static, 256)
246+
#pragma omp for
254247
#endif
255-
for (int j=0; j<GlobalV::NLOCAL; j++)
256-
{
257-
const int nug = GlobalC::GridT.trace_lo[j];
258-
const int nug0 = nug/GlobalV::NPOL;
259-
// if the col element is on this processor.
260-
if (nug >=0)
248+
for (int j=0; j<GlobalV::NLOCAL; j++)
261249
{
262-
if (mug <= nug)
250+
tmp[j] = 0;
251+
const int nug = GlobalC::GridT.trace_lo[j];
252+
const int nug0 = nug/GlobalV::NPOL;
253+
// if the col element is on this processor.
254+
if (nug >=0)
263255
{
264-
if(GlobalV::NSPIN!=4)
256+
if (mug <= nug)
265257
{
266258
// pvp is symmetric, only half is calculated.
267259
tmp[j] = pvp[mug][nug];
268260
}
269261
else
270262
{
271-
if(i%2==0&&j%2==0)
272-
{
273-
//spin = 0;
274-
tmp[j] = pvp_nc[0][mug0][nug0]+pvp_nc[3][mug0][nug0];
275-
}
276-
else if(i%2==1&&j%2==1)
277-
{
278-
//spin = 3;
279-
tmp[j] = pvp_nc[0][mug0][nug0]-pvp_nc[3][mug0][nug0];
280-
}
281-
else if(i%2==0&&j%2==1)
282-
{
283-
// spin = 1;
284-
if(!GlobalV::DOMAG) tmp[j] = 0;
285-
else tmp[j] = pvp_nc[1][mug0][nug0] - std::complex<double>(0.0,1.0) * pvp_nc[2][mug0][nug0];
286-
}
287-
else if(i%2==1&&j%2==0)
263+
// need to get elements from the other half.
264+
// I have question on this! 2011-02-22
265+
tmp[j] = conj(pvp[nug][mug]);
266+
}
267+
}
268+
}
269+
}
270+
else
271+
{
272+
if (GlobalV::DOMAG)
273+
{
274+
#ifdef _OPENMP
275+
#pragma omp for
276+
#endif
277+
for (int j=0; j<GlobalV::NLOCAL; j++)
278+
{
279+
tmp[j] = 0;
280+
int j_flag = j & 1; // j % 2 == 0
281+
int ij_same = i_flag ^ j_flag ? 0 : 1;
282+
const int nug = GlobalC::GridT.trace_lo[j];
283+
const int nug0 = nug/GlobalV::NPOL;
284+
double sign = sign_table[j_flag];
285+
// if the col element is on this processor.
286+
if (nug >=0)
287+
{
288+
if (mug <= nug)
288289
{
289-
//spin = 2;
290-
if(!GlobalV::DOMAG) tmp[j] = 0;
291-
else tmp[j] = pvp_nc[1][mug0][nug0] + std::complex<double>(0.0,1.0) * pvp_nc[2][mug0][nug0];
290+
if (ij_same)
291+
{
292+
//spin = 0;
293+
//spin = 3;
294+
tmp[j] = pvp_nc[0][mug0][nug0]+sign*pvp_nc[3][mug0][nug0];
295+
}
296+
else
297+
{
298+
// spin = 1;
299+
// spin = 2;
300+
tmp[j] = pvp_nc[1][mug0][nug0] + sign*std::complex<double>(0.0,1.0) * pvp_nc[2][mug0][nug0];
301+
}
292302
}
293303
else
294304
{
295-
ModuleBase::WARNING_QUIT("Gint_k::folding_vl_k_nc","index is wrong!");
296-
}
305+
if (ij_same)
306+
{
307+
//spin = 0;
308+
//spin = 3;
309+
tmp[j] = conj(pvp_nc[0][nug0][mug0]+sign*pvp_nc[3][nug0][mug0]);
310+
}
311+
else
312+
{
313+
// spin = 1;
314+
//spin = 2;
315+
tmp[j] = conj(pvp_nc[1][nug0][mug0] + sign*std::complex<double>(0.0,1.0) * pvp_nc[2][nug0][mug0]);
316+
}
317+
}
297318
}
298319
}
299-
else
320+
}
321+
else
322+
{
323+
#ifdef _OPENMP
324+
#pragma omp for
325+
#endif
326+
for (int j=0; j<GlobalV::NLOCAL; j++)
300327
{
301-
// need to get elements from the other half.
302-
// I have question on this! 2011-02-22
303-
if(GlobalV::NSPIN!=4)
304-
{
305-
tmp[j] = conj(pvp[nug][mug]);
306-
}
307-
else
328+
tmp[j] = 0;
329+
int j_flag = j & 1; // j % 2 == 0
330+
int ij_same = i_flag ^ j_flag ? 0 : 1;
331+
332+
if (!ij_same)
333+
continue;
334+
335+
const int nug = GlobalC::GridT.trace_lo[j];
336+
const int nug0 = nug/GlobalV::NPOL;
337+
double sign = sign_table[j_flag];
338+
// if the col element is on this processor.
339+
if (nug >=0)
308340
{
309-
if(i%2==0&&j%2==0)
341+
if (mug <= nug)
310342
{
311343
//spin = 0;
312-
tmp[j] = conj(pvp_nc[0][nug0][mug0]+pvp_nc[3][nug0][mug0]);
313-
}
314-
else if(i%2==1&&j%2==1)
315-
{
316344
//spin = 3;
317-
tmp[j] = conj(pvp_nc[0][nug0][mug0]-pvp_nc[3][nug0][mug0]);
318-
}
319-
else if(i%2==1&&j%2==0)
320-
{
321-
// spin = 1;
322-
if(!GlobalV::DOMAG) tmp[j] = 0;
323-
else tmp[j] = conj(pvp_nc[1][nug0][mug0] - std::complex<double>(0.0,1.0) * pvp_nc[2][nug0][mug0]);
324-
}
325-
else if(i%2==0&&j%2==1)
326-
{
327-
//spin = 2;
328-
if(!GlobalV::DOMAG) tmp[j] = 0;
329-
else tmp[j] = conj(pvp_nc[1][nug0][mug0] + std::complex<double>(0.0,1.0) * pvp_nc[2][nug0][mug0]);
345+
tmp[j] = pvp_nc[0][mug0][nug0]+sign*pvp_nc[3][mug0][nug0];
330346
}
331347
else
332348
{
333-
ModuleBase::WARNING_QUIT("Gint_k::folding_vl_k_nc","index is wrong!");
334-
}
349+
//spin = 0;
350+
//spin = 3;
351+
tmp[j] = conj(pvp_nc[0][nug0][mug0]+sign*pvp_nc[3][nug0][mug0]);
352+
}
335353
}
336354
}
337355
}
338356
}
339357
}
358+
else
359+
{
360+
#ifdef _OPENMP
361+
#pragma omp for
362+
#endif
363+
for (int j=0; j<GlobalV::NLOCAL; j++)
364+
{
365+
tmp[j] = 0;
366+
}
367+
}
340368
#ifdef _OPENMP
341369
#pragma omp single
342370
{
@@ -352,7 +380,7 @@ void Gint_k::folding_vl_k(const int &ik, LCAO_Matrix *LM)
352380
// according to the HPSEPS's 2D distribution methods.
353381
//-----------------------------------------------------
354382
#ifdef _OPENMP
355-
#pragma omp for schedule(static, 256)
383+
#pragma omp for
356384
#endif
357385
for (int j=0; j<GlobalV::NLOCAL; j++)
358386
{

source/module_gint/gint_vl.cpp

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,7 @@ void Gint::cal_meshball_vlocal_gamma(
291291
}
292292
}
293293

294-
inline int find_offset(const int id1, const int id2, const int iat1, const int iat2,
295-
int* find_start, int* find_end)
294+
inline int find_offset(const int id1, const int id2, const int iat1, const int iat2)
296295
{
297296
const int R1x=GlobalC::GridT.ucell_index2x[id1];
298297
const int R2x=GlobalC::GridT.ucell_index2x[id2];
@@ -305,16 +304,8 @@ inline int find_offset(const int id1, const int id2, const int iat1, const int i
305304
const int dRz=R1z-R2z;
306305

307306
const int index=GlobalC::GridT.cal_RindexAtom(dRx, dRy, dRz, iat2);
308-
309-
int offset=-1;
310-
for(int* find=find_start; find < find_end; ++find)
311-
{
312-
if( find[0] == index )
313-
{
314-
offset = find - find_start;
315-
break;
316-
}
317-
}
307+
308+
const int offset = GlobalC::GridT.binary_search_find_R2_offset(index, iat1);
318309

319310
assert(offset < GlobalC::GridT.nad[iat1]);
320311
return offset;
@@ -348,9 +339,6 @@ void Gint::cal_meshball_vlocal_k(
348339
const int T1 = GlobalC::ucell.iat2it[iat1];
349340
const int id1 = GlobalC::GridT.which_unitcell[mcell_index1];
350341
const int DM_start = GlobalC::GridT.nlocstartg[iat1];
351-
// nad : how many adjacent atoms for atom 'iat'
352-
int* find_start = GlobalC::GridT.find_R2[iat1];
353-
int* find_end = GlobalC::GridT.find_R2[iat1] + GlobalC::GridT.nad[iat1];
354342
for(int ia2=0; ia2<na_grid; ++ia2)
355343
{
356344
const int mcell_index2 = GlobalC::GridT.bcell_start[grid_index] + ia2;
@@ -373,8 +361,7 @@ void Gint::cal_meshball_vlocal_k(
373361
const int mcell_index2 = GlobalC::GridT.bcell_start[grid_index] + ia2;
374362
const int id2 = GlobalC::GridT.which_unitcell[mcell_index2];
375363
int offset;
376-
offset=find_offset(id1, id2, iat1, iat2,
377-
find_start, find_end);
364+
offset=find_offset(id1, id2, iat1, iat2);
378365

379366
const int iatw = DM_start + GlobalC::GridT.find_R2st[iat1][offset];
380367

source/module_gint/grid_technique.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,11 @@ Grid_Technique::~Grid_Technique()
5050
{
5151
delete[] find_R2[iat];
5252
delete[] find_R2st[iat];
53+
delete[] find_R2_sorted_index[iat];
5354
}
5455
delete[] find_R2;
5556
delete[] find_R2st;
57+
delete[] find_R2_sorted_index;
5658
}
5759
}
5860

source/module_gint/grid_technique.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@ class Grid_Technique : public Grid_MeshBall
6060

6161
int* nad; // number of adjacent atoms for each atom.
6262
int **find_R2;
63+
int **find_R2_sorted_index;
6364
int **find_R2st;
6465
bool allocate_find_R2;
66+
int binary_search_find_R2_offset(int val, int iat);
6567

6668
//indexes for nnrg -> orbital index + R index
6769
std::vector<gridIntegral::gridIndex> nnrg_index;

source/module_hamilt_lcao/hamilt_lcaodft/DM_k.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "src_parallel/parallel_common.h"
55
#include "module_hamilt_pw/hamilt_pwdft/global.h"
66
#include "local_orbital_charge.h"
7+
#include "module_base/libm/libm.h"
78

89
#ifdef __MKL
910
#include <mkl_service.h>
@@ -92,7 +93,7 @@ inline void cal_DM_ATOM(const Grid_Technique &gt,
9293
const int start2 = GlobalC::ucell.itiaiw2iwt(T2, I2, 0);
9394
const int iw2_lo = gt.trace_lo[start2];
9495
const int nw2 = atom2->nw;
95-
std::complex<double> exp_R = exp(fac
96+
std::complex<double> exp_R = ModuleBase::libm::exp(fac
9697
* (GlobalC::kv.kvec_d[ik].x * RA.info[ia1][ia2][0]
9798
+ GlobalC::kv.kvec_d[ik].y * RA.info[ia1][ia2][1]
9899
+ GlobalC::kv.kvec_d[ik].z * RA.info[ia1][ia2][2]));
@@ -183,7 +184,7 @@ inline void cal_DM_ATOM_nc(const Grid_Technique &gt,
183184
const int start2 = GlobalC::ucell.itiaiw2iwt(T2, I2, 0);
184185
const int iw2_lo = gt.trace_lo[start2] / GlobalV::NPOL + gt.lgd / GlobalV::NPOL * is2;
185186
const int nw2 = atom2->nw;
186-
std::complex<double> exp_R = exp(fac
187+
std::complex<double> exp_R = ModuleBase::libm::exp(fac
187188
* (GlobalC::kv.kvec_d[ik].x * RA.info[ia1][ia2][0]
188189
+ GlobalC::kv.kvec_d[ik].y * RA.info[ia1][ia2][1]
189190
+ GlobalC::kv.kvec_d[ik].z * RA.info[ia1][ia2][2]));

0 commit comments

Comments
 (0)