11#include " record_adj.h"
22#include " ../src_pw/global.h"
33#include " ../module_base/timer.h"
4+ #include " ../module_neighbor/sltk_grid_driver.h"
45
// Default constructor as shown by this diff: the '-' line is the previous
// version, the '+' line is its replacement. The replacement starts iat2ca as
// nullptr so that delete_grid() can test the pointer before calling delete[]
// on it.
5- Record_adj::Record_adj (){}
6+ Record_adj::Record_adj () : iat2ca( nullptr ) {}
// Destructor: the body is empty. The arrays info, na_each and iat2ca are
// released by delete_grid(), not here.
// NOTE(review): presumably callers are expected to invoke delete_grid()
// themselves before the object is destroyed -- confirm against the call sites,
// otherwise those arrays are never freed.
67Record_adj::~Record_adj (){}
78
89void Record_adj::delete_grid (void )
@@ -19,6 +20,7 @@ void Record_adj::delete_grid(void)
1920 }
2021 delete[] info;
2122 delete[] na_each;
23+ if (iat2ca) delete[] iat2ca;
2224}
2325
2426
@@ -45,7 +47,7 @@ void Record_adj::for_2d(Parallel_Orbitals &pv, bool gamma_only)
4547 ModuleBase::GlobalFunc::ZEROS (pv.nlocstart , GlobalC::ucell.nat );
4648 pv.nnr = 0 ;
4749 }
48-
50+ {
4951 // (1) find the adjacent atoms of atom[T1,I1];
5052 ModuleBase::Vector3<double > tau1, tau2, dtau;
5153 ModuleBase::Vector3<double > dtau1, dtau2, tau0;
@@ -141,14 +143,26 @@ void Record_adj::for_2d(Parallel_Orbitals &pv, bool gamma_only)
141143 ++iat;
142144 }// end I1
143145 }// end T1
144-
146+ }
145147 // xiaohui add "OUT_LEVEL", 2015-09-16
146148 if (GlobalV::OUT_LEVEL != " m" && !gamma_only) ModuleBase::GlobalFunc::OUT (GlobalV::ofs_running, " ParaV.nnr" , pv.nnr );
147149
148150 // ------------------------------------------------
149151 // info will identify each atom in each unitcell.
150152 // ------------------------------------------------
151153 this ->info = new int **[na_proc];
154+
155+ #ifdef _OPENMP
156+ #pragma omp parallel
157+ {
158+ #endif
159+
160+ ModuleBase::Vector3<double > tau1, tau2, dtau;
161+ ModuleBase::Vector3<double > dtau1, dtau2, tau0;
162+
163+ #ifdef _OPENMP
164+ #pragma omp for schedule(dynamic)
165+ #endif
152166 for (int i=0 ; i<na_proc; i++)
153167 {
154168// GlobalV::ofs_running << " atom" << std::setw(5) << i << std::setw(10) << na_each[i] << std::endl;
@@ -164,23 +178,27 @@ void Record_adj::for_2d(Parallel_Orbitals &pv, bool gamma_only)
164178 }
165179 }
166180
167- iat = 0 ;
168- for (int T1 = 0 ; T1 < GlobalC::ucell.ntype ; ++T1)
181+ #ifdef _OPENMP
182+ #pragma omp for schedule(dynamic)
183+ #endif
184+ for (int iat=0 ; iat<GlobalC::ucell.nat ; ++iat)
169185 {
186+ const int T1 = GlobalC::ucell.iat2it [iat];
170187 Atom* atom1 = &GlobalC::ucell.atoms [T1];
171- for ( int I1 = 0 ; I1 < atom1-> na ; ++I1)
188+ const int I1 = GlobalC::ucell. iat2ia [iat];
172189 {
173190 tau1 = atom1->tau [I1];
174191 // GlobalC::GridD.Find_atom( tau1 );
175- GlobalC::GridD.Find_atom (GlobalC::ucell, tau1 ,T1, I1);
192+ AdjacentAtomInfo adjs;
193+ GlobalC::GridD.Find_atom (GlobalC::ucell, tau1 ,T1, I1, &adjs);
176194
177195 // (2) search among all adjacent atoms.
178196 int cb = 0 ;
179- for (int ad = 0 ; ad < GlobalC::GridD. getAdjacentNum () +1 ; ++ad)
197+ for (int ad = 0 ; ad < adjs. adj_num +1 ; ++ad)
180198 {
181- const int T2 = GlobalC::GridD. getType (ad) ;
182- const int I2 = GlobalC::GridD. getNatom (ad) ;
183- tau2 = GlobalC::GridD. getAdjacentTau (ad) ;
199+ const int T2 = adjs. ntype [ad] ;
200+ const int I2 = adjs. natom [ad] ;
201+ tau2 = adjs. adjacent_tau [ad] ;
184202 dtau = tau2 - tau1;
185203 double distance = dtau.norm () * GlobalC::ucell.lat0 ;
186204 double rcut = GlobalC::ORB.Phi [T1].getRcut () + GlobalC::ORB.Phi [T2].getRcut ();
@@ -190,14 +208,14 @@ void Record_adj::for_2d(Parallel_Orbitals &pv, bool gamma_only)
190208 if (distance < rcut) is_adj = true ;
191209 else if (distance >= rcut)
192210 {
193- for (int ad0 = 0 ; ad0 < GlobalC::GridD. getAdjacentNum () +1 ; ++ad0)
211+ for (int ad0 = 0 ; ad0 < adjs. adj_num +1 ; ++ad0)
194212 {
195- const int T0 = GlobalC::GridD. getType ( ad0) ;
213+ const int T0 = adjs. ntype [ ad0] ;
196214 // const int I0 = GlobalC::GridD.getNatom(ad0);
197215 // const int iat0 = GlobalC::ucell.itia2iat(T0, I0);
198216 // const int start0 = GlobalC::ucell.itiaiw2iwt(T0, I0, 0);
199217
200- tau0 = GlobalC::GridD. getAdjacentTau ( ad0) ;
218+ tau0 = adjs. adjacent_tau [ ad0] ;
201219 dtau1 = tau0 - tau1;
202220 double distance1 = dtau1.norm () * GlobalC::ucell.lat0 ;
203221 double rcut1 = GlobalC::ORB.Phi [T1].getRcut () + GlobalC::ucell.infoNL .Beta [T0].get_rcut_max ();
@@ -216,18 +234,20 @@ void Record_adj::for_2d(Parallel_Orbitals &pv, bool gamma_only)
216234
217235 if (is_adj)
218236 {
219- info[iat][cb][0 ] = GlobalC::GridD. getBox (ad) .x ;
220- info[iat][cb][1 ] = GlobalC::GridD. getBox (ad) .y ;
221- info[iat][cb][2 ] = GlobalC::GridD. getBox (ad) .z ;
237+ info[iat][cb][0 ] = adjs. box [ad] .x ;
238+ info[iat][cb][1 ] = adjs. box [ad] .y ;
239+ info[iat][cb][2 ] = adjs. box [ad] .z ;
222240 info[iat][cb][3 ] = T2;
223241 info[iat][cb][4 ] = I2;
224242 ++cb;
225243 }
226244 }// end ad
227245// GlobalV::ofs_running << " nadj = " << cb << std::endl;
228- ++iat;
229246 }// end I1
230247 }// end T1
248+ #ifdef _OPENMP
249+ }
250+ #endif
231251 ModuleBase::timer::tick (" Record_adj" ," for_2d" );
232252
233253 return ;
@@ -243,48 +263,59 @@ void Record_adj::for_grid(const Grid_Technique >)
243263 ModuleBase::TITLE (" Record_adj" ," for_grid" );
244264 ModuleBase::timer::tick (" Record_adj" ," for_grid" );
245265
246- ModuleBase::Vector3<double > tau1, tau2, dtau;
247- ModuleBase::Vector3<double > tau0, dtau1, dtau2;
248-
249266 this ->na_proc = 0 ;
250- for (int T1=0 ; T1<GlobalC::ucell.ntype ; ++T1)
267+ this ->iat2ca = new int [GlobalC::ucell.nat ];
268+ for (int iat=0 ; iat<GlobalC::ucell.nat ; ++iat)
251269 {
252- for (int I1=0 ; I1<GlobalC::ucell.atoms [T1].na ; ++I1)
253270 {
254- const int iat = GlobalC::ucell.itia2iat (T1,I1);
255271 if (gt.in_this_processor [iat])
256272 {
273+ iat2ca[iat] = na_proc;
257274 ++na_proc;
275+ } else {
276+ iat2ca[iat] = -1 ;
258277 }
259278 }
260279 }
261280
262281 // number of adjacents for each atom.
263282 this ->na_each = new int [na_proc];
264283 ModuleBase::GlobalFunc::ZEROS (na_each, na_proc);
284+ this ->info = new int **[na_proc];
265285
266- int ca = 0 ;
267- for (int T1=0 ; T1<GlobalC::ucell.ntype ; ++T1)
286+ #ifdef _OPENMP
287+ #pragma omp parallel
288+ {
289+ #endif
290+ ModuleBase::Vector3<double > tau1, tau2, dtau;
291+ ModuleBase::Vector3<double > tau0, dtau1, dtau2;
292+
293+ #ifdef _OPENMP
294+ #pragma omp for schedule(dynamic)
295+ #endif
296+ for (int iat=0 ; iat<GlobalC::ucell.nat ; ++iat)
268297 {
298+ const int T1 = GlobalC::ucell.iat2it [iat];
269299 Atom* atom1 = &GlobalC::ucell.atoms [T1];
270- for ( int I1= 0 ; I1<atom1-> na ; ++I1)
300+ const int I1 = GlobalC::ucell. iat2ia [iat];
271301 {
272- const int iat = GlobalC::ucell. itia2iat (T1,I1) ;
302+ const int ca = iat2ca[iat] ;
273303 // key in this function
274304 if (gt.in_this_processor [iat])
275305 {
276306 tau1 = atom1->tau [I1];
277307 // GlobalC::GridD.Find_atom(tau1);
278- GlobalC::GridD.Find_atom (GlobalC::ucell, tau1, T1, I1);
279- for (int ad = 0 ; ad < GlobalC::GridD.getAdjacentNum ()+1 ; ad++)
308+ AdjacentAtomInfo adjs;
309+ GlobalC::GridD.Find_atom (GlobalC::ucell, tau1, T1, I1, &adjs);
310+ for (int ad = 0 ; ad < adjs.adj_num +1 ; ad++)
280311 {
281- const int T2 = GlobalC::GridD. getType (ad) ;
282- const int I2 = GlobalC::GridD. getNatom (ad) ;
312+ const int T2 = adjs. ntype [ad] ;
313+ const int I2 = adjs. natom [ad] ;
283314 const int iat2 = GlobalC::ucell.itia2iat (T2, I2);
284315 if (gt.in_this_processor [iat2])
285316 {
286317 // Atom* atom2 = &GlobalC::ucell.atoms[T2];
287- tau2 = GlobalC::GridD. getAdjacentTau (ad) ;
318+ tau2 = adjs. adjacent_tau [ad] ;
288319 dtau = tau2 - tau1;
289320 double distance = dtau.norm () * GlobalC::ucell.lat0 ;
290321 double rcut = GlobalC::ORB.Phi [T1].getRcut () + GlobalC::ORB.Phi [T2].getRcut ();
@@ -328,13 +359,13 @@ void Record_adj::for_grid(const Grid_Technique >)
328359 }
329360 }// end judge 2
330361 }// end ad
331- ++ca;
332362 }// end judge 1
333363 }// end I1
334364 }// end T1
335365
336-
337- this ->info = new int **[na_proc];
366+ #ifdef _OPENMP
367+ #pragma omp for schedule(dynamic)
368+ #endif
338369 for (int i=0 ; i<na_proc; i++)
339370 {
340371 assert (na_each[i]>0 );
@@ -347,43 +378,47 @@ void Record_adj::for_grid(const Grid_Technique >)
347378 }
348379 }
349380
350- ca = 0 ;
351- for (int T1=0 ; T1<GlobalC::ucell.ntype ; ++T1)
381+ #ifdef _OPENMP
382+ #pragma omp for schedule(dynamic)
383+ #endif
384+ for (int iat=0 ; iat<GlobalC::ucell.nat ; ++iat)
352385 {
386+ const int T1 = GlobalC::ucell.iat2it [iat];
353387 Atom* atom1 = &GlobalC::ucell.atoms [T1];
354- for ( int I1= 0 ; I1 < atom1-> na ; ++I1)
388+ const int I1 = GlobalC::ucell. iat2ia [iat];
355389 {
356- const int iat = GlobalC::ucell. itia2iat (T1,I1) ;
390+ const int ca = iat2ca[iat] ;
357391
358392 // key of this function
359393 if (gt.in_this_processor [iat])
360394 {
361395 tau1 = atom1->tau [I1];
362396 // GlobalC::GridD.Find_atom(tau1);
363- GlobalC::GridD.Find_atom (GlobalC::ucell, tau1, T1, I1);
397+ AdjacentAtomInfo adjs;
398+ GlobalC::GridD.Find_atom (GlobalC::ucell, tau1, T1, I1, &adjs);
364399
365400 int cb = 0 ;
366- for (int ad = 0 ; ad < GlobalC::GridD. getAdjacentNum () +1 ; ad++)
401+ for (int ad = 0 ; ad < adjs. adj_num +1 ; ad++)
367402 {
368- const int T2 = GlobalC::GridD. getType (ad) ;
369- const int I2 = GlobalC::GridD. getNatom (ad) ;
403+ const int T2 = adjs. ntype [ad] ;
404+ const int I2 = adjs. natom [ad] ;
370405 const int iat2 = GlobalC::ucell.itia2iat (T2, I2);
371406
372407 // key of this function
373408 if (gt.in_this_processor [iat2])
374409 {
375410 // Atom* atom2 = &GlobalC::ucell.atoms[T2];
376- tau2 = GlobalC::GridD. getAdjacentTau (ad) ;
411+ tau2 = adjs. adjacent_tau [ad] ;
377412 dtau = tau2 - tau1;
378413 double distance = dtau.norm () * GlobalC::ucell.lat0 ;
379414 double rcut = GlobalC::ORB.Phi [T1].getRcut () + GlobalC::ORB.Phi [T2].getRcut ();
380415
381416 // check the distance
382417 if (distance < rcut)
383418 {
384- info[ca][cb][0 ] = GlobalC::GridD. getBox (ad) .x ;
385- info[ca][cb][1 ] = GlobalC::GridD. getBox (ad) .y ;
386- info[ca][cb][2 ] = GlobalC::GridD. getBox (ad) .z ;
419+ info[ca][cb][0 ] = adjs. box [ad] .x ;
420+ info[ca][cb][1 ] = adjs. box [ad] .y ;
421+ info[ca][cb][2 ] = adjs. box [ad] .z ;
387422 info[ca][cb][3 ] = T2;
388423 info[ca][cb][4 ] = I2;
389424 ++cb;
@@ -425,11 +460,12 @@ void Record_adj::for_grid(const Grid_Technique >)
425460 }// end ad
426461
427462 assert (cb == na_each[ca]);
428- ++ca;
429463 }
430464 }
431465 }
432- assert (ca==na_proc);
466+ #ifdef _OPENMP
467+ }
468+ #endif
433469 ModuleBase::timer::tick (" Record_adj" ," for_grid" );
434470
435471// std::cout << " after for_grid" << std::endl;
0 commit comments