@@ -368,10 +368,10 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
368368 // to G space. maybe need fftw with OpenMP
369369 rho_basis->real2recip (aux, aux);
370370
371- std::vector<double > tau_h;
372- std::vector<double > gcar_h;
373371 if (this ->device == base_device::GpuDevice)
374372 {
373+ std::vector<double > tau_h;
374+ std::vector<double > gcar_h;
375375 tau_h.resize (this ->nat * 3 );
376376 for (int iat = 0 ; iat < this ->nat ; ++iat)
377377 {
@@ -389,16 +389,15 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
389389 gcar_h[ig * 3 + 1 ] = rho_basis->gcar [ig].y ;
390390 gcar_h[ig * 3 + 2 ] = rho_basis->gcar [ig].z ;
391391 }
392- }
393- int * iat2it_d = nullptr ;
394- int * ig2gg_d = nullptr ;
395- double * gcar_d = nullptr ;
396- double * tau_d = nullptr ;
397- std::complex <double >* aux_d = nullptr ;
398- double * forcelc_d = nullptr ;
399- double * vloc_d = nullptr ;
400- if (this ->device == base_device::GpuDevice)
401- {
392+
393+ int * iat2it_d = nullptr ;
394+ int * ig2gg_d = nullptr ;
395+ double * gcar_d = nullptr ;
396+ double * tau_d = nullptr ;
397+ std::complex <double >* aux_d = nullptr ;
398+ double * forcelc_d = nullptr ;
399+ double * vloc_d = nullptr ;
400+
402401 resmem_int_op ()(iat2it_d, this ->nat );
403402 resmem_int_op ()(ig2gg_d, rho_basis->npw );
404403 resmem_var_op ()(gcar_d, rho_basis->npw * 3 );
@@ -414,10 +413,7 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
414413 syncmem_complex_h2d_op ()(aux_d, aux, rho_basis->npw );
415414 syncmem_var_h2d_op ()(forcelc_d, forcelc.c , this ->nat * 3 );
416415 syncmem_var_h2d_op ()(vloc_d, vloc.c , vloc.nr * vloc.nc );
417- }
418416
419- if (this ->device == base_device::GpuDevice)
420- {
421417 hamilt::cal_force_loc_op<FPTYPE, Device>()(
422418 this ->nat ,
423419 rho_basis->npw ,
@@ -431,8 +427,16 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
431427 vloc.nc ,
432428 forcelc_d);
433429 syncmem_var_d2h_op ()(forcelc.c , forcelc_d, this ->nat * 3 );
430+
431+ delmem_int_op ()(iat2it_d);
432+ delmem_int_op ()(ig2gg_d);
433+ delmem_var_op ()(gcar_d);
434+ delmem_var_op ()(tau_d);
435+ delmem_complex_op ()(aux_d);
436+ delmem_var_op ()(forcelc_d);
437+ delmem_var_op ()(vloc_d);
434438 }
435- else {
439+ else { // calculate forces on CPU
436440 #ifdef _OPENMP
437441 #pragma omp parallel for
438442 #endif
@@ -457,16 +461,6 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
457461 forcelc (iat, 2 ) *= (ucell.tpiba * ucell.omega );
458462 }
459463 }
460- if (this ->device == base_device::GpuDevice)
461- {
462- delmem_int_op ()(iat2it_d);
463- delmem_int_op ()(ig2gg_d);
464- delmem_var_op ()(gcar_d);
465- delmem_var_op ()(tau_d);
466- delmem_complex_op ()(aux_d);
467- delmem_var_op ()(forcelc_d);
468- delmem_var_op ()(vloc_d);
469- }
470464 // this->print(GlobalV::ofs_running, "local forces", forcelc);
471465 Parallel_Reduce::reduce_pool (forcelc.c , forcelc.nr * forcelc.nc );
472466 delete[] aux;
@@ -482,9 +476,9 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
482476{
483477 ModuleBase::TITLE (" Forces" , " cal_force_ew" );
484478 ModuleBase::timer::tick (" Forces" , " cal_force_ew" );
485-
479+ this -> device = base_device::get_device_type<Device>( this -> ctx );
486480 double fact = 2.0 ;
487- std::complex < double >* aux = new std::complex <double >[ rho_basis->npw ] ;
481+ std::vector< std::complex <double >> aux ( rho_basis->npw ) ;
488482
489483 /*
490484 blocking rho_basis->nrxnpwx for data locality.
@@ -494,9 +488,7 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
494488 performance will be better when number of atom is quite huge
495489 */
496490 const int block_ig = 1024 ;
497- #ifdef _OPENMP
498491#pragma omp parallel for
499- #endif
500492 for (int igb = 0 ; igb < rho_basis->npw ; igb += block_ig)
501493 {
502494 // calculate the actual task length of this block
@@ -548,9 +540,7 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
548540 * erfc (sqrt (ucell.tpiba2 * rho_basis->ggecut / 4.0 / alpha));
549541 } while (upperbound > 1.0e-6 );
550542 const int ig0 = rho_basis->ig_gge0 ;
551- #ifdef _OPENMP
552543#pragma omp parallel for
553- #endif
554544 for (int ig = 0 ; ig < rho_basis->npw ; ig++)
555545 {
556546 if (ig== ig0)
@@ -566,83 +556,99 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
566556 {
567557 aux[rho_basis->ig_gge0 ] = std::complex <double >(0.0 , 0.0 );
568558 }
569-
570- #ifdef _OPENMP
571- #pragma omp parallel
559+ if (this ->device == base_device::GpuDevice)
572560 {
573- int num_threads = omp_get_num_threads ();
574- int thread_id = omp_get_thread_num ();
575- #else
576- int num_threads = 1 ;
577- int thread_id = 0 ;
578- #endif
579-
580- /* Here is task distribution for multi-thread,
581- 0. atom will be iterated both in main nat loop and the loop in `if (rho_basis->ig_gge0 >= 0)`.
582- To avoid syncing, we must calculate work range of each thread by our self
583- 1. Calculate the iat range [iat_beg, iat_end) by each thread
584- a. when it is single thread stage, [iat_beg, iat_end) will be [0, nat)
585- 2. each thread iterate atoms form `iat_beg` to `iat_end-1`
586- */
587- int iat_beg, iat_end;
588- int it_beg, ia_beg;
589- ModuleBase::TASK_DIST_1D (num_threads, thread_id, this ->nat , iat_beg, iat_end);
590- iat_end = iat_beg + iat_end;
591- ucell.iat2iait (iat_beg, &ia_beg, &it_beg);
592-
593- int iat = iat_beg;
594- int it = it_beg;
595- int ia = ia_beg;
596-
597- // preprocess ig_gap for skipping the ig point
598- int ig_gap = (rho_basis->ig_gge0 >= 0 && rho_basis->ig_gge0 < rho_basis->npw ) ? rho_basis->ig_gge0 : -1 ;
599-
600- double it_fact = 0 .;
601- int last_it = -1 ;
602-
603- // iterating atoms
604- while (iat < iat_end)
561+ std::vector<double > tau_h (this ->nat * 3 );
562+ std::vector<double > gcar_h (rho_basis->npw * 3 );
563+ for (int iat = 0 ; iat < this ->nat ; ++iat)
605564 {
606- if (it != last_it)
607- { // calculate it_tact when it is changed
608- double zv;
609- {
610- zv = ucell.atoms [it].ncpp .zv ;
611- }
612- it_fact = zv * ModuleBase::e2 * ucell.tpiba * ModuleBase::TWO_PI / ucell.omega * fact;
613- last_it = it;
614- }
565+ int it = ucell.iat2it [iat];
566+ int ia = ucell.iat2ia [iat];
567+ tau_h[iat * 3 ] = ucell.atoms [it].tau [ia].x ;
568+ tau_h[iat * 3 + 1 ] = ucell.atoms [it].tau [ia].y ;
569+ tau_h[iat * 3 + 2 ] = ucell.atoms [it].tau [ia].z ;
570+ }
571+ for (int ig = 0 ; ig < rho_basis->npw ; ++ig)
572+ {
573+ gcar_h[ig * 3 ] = rho_basis->gcar [ig].x ;
574+ gcar_h[ig * 3 + 1 ] = rho_basis->gcar [ig].y ;
575+ gcar_h[ig * 3 + 2 ] = rho_basis->gcar [ig].z ;
576+ }
577+ std::vector<double > it_fact_h (ucell.ntype );
578+ for (int it = 0 ; it < ucell.ntype ; ++it)
579+ {
580+ it_fact_h[it] = ucell.atoms [it].ncpp .zv * ModuleBase::e2 * ucell.tpiba * ModuleBase::TWO_PI / ucell.omega * fact;
581+ }
615582
616- if (ucell.atoms [it].na != 0 )
617- {
618- const auto ig_loop = [&](int ig_beg, int ig_end) {
619- for (int ig = ig_beg; ig < ig_end; ig++)
620- {
621- const ModuleBase::Vector3<double > gcar = rho_basis->gcar [ig];
622- const double arg = ModuleBase::TWO_PI * (gcar * ucell.atoms [it].tau [ia]);
623- double sinp, cosp;
624- ModuleBase::libm::sincos (arg, &sinp, &cosp);
625- double sumnb = -cosp * aux[ig].imag () + sinp * aux[ig].real ();
626- forceion (iat, 0 ) += gcar[0 ] * sumnb;
627- forceion (iat, 1 ) += gcar[1 ] * sumnb;
628- forceion (iat, 2 ) += gcar[2 ] * sumnb;
629- }
630- };
583+ int * iat2it_d = nullptr ;
584+ double * gcar_d = nullptr ;
585+ double * tau_d = nullptr ;
586+ double * it_fact_d = nullptr ;
587+ std::complex <double >* aux_d = nullptr ;
588+ double * forceion_d = nullptr ;
589+ resmem_int_op ()(iat2it_d, this ->nat );
590+ resmem_var_op ()(gcar_d, rho_basis->npw * 3 );
591+ resmem_var_op ()(tau_d, this ->nat * 3 );
592+ resmem_var_op ()(it_fact_d, ucell.ntype );
593+ resmem_complex_op ()(aux_d, rho_basis->npw );
594+ resmem_var_op ()(forceion_d, this ->nat * 3 );
631595
632- // skip ig_gge0 point by separating ig loop into two part
633- ig_loop (0 , ig_gap);
634- ig_loop (ig_gap + 1 , rho_basis->npw );
596+ syncmem_int_h2d_op ()(iat2it_d, ucell.iat2it , this ->nat );
597+ syncmem_var_h2d_op ()(gcar_d, gcar_h.data (), rho_basis->npw * 3 );
598+ syncmem_var_h2d_op ()(tau_d, tau_h.data (), this ->nat * 3 );
599+ syncmem_var_h2d_op ()(it_fact_d, it_fact_h.data (), ucell.ntype );
600+ syncmem_complex_h2d_op ()(aux_d, aux.data (), rho_basis->npw );
601+ syncmem_var_h2d_op ()(forceion_d, forceion.c , this ->nat * 3 );
635602
636- forceion (iat, 0 ) *= it_fact;
637- forceion (iat, 1 ) *= it_fact;
638- forceion (iat, 2 ) *= it_fact;
603+ hamilt::cal_force_ew_op<FPTYPE, Device>()(
604+ this ->nat ,
605+ rho_basis->npw ,
606+ rho_basis->ig_gge0 ,
607+ iat2it_d,
608+ gcar_d,
609+ tau_d,
610+ it_fact_d,
611+ aux_d,
612+ forceion_d);
613+
614+ syncmem_var_d2h_op ()(forceion.c , forceion_d, this ->nat * 3 );
615+ delmem_int_op ()(iat2it_d);
616+ delmem_var_op ()(gcar_d);
617+ delmem_var_op ()(tau_d);
618+ delmem_var_op ()(it_fact_d);
619+ delmem_complex_op ()(aux_d);
620+ delmem_var_op ()(forceion_d);
621+ } else // calculate forces on CPU
622+ {
623+ #pragma omp parallel for
624+ for (int iat = 0 ; iat < this ->nat ; ++iat)
625+ {
626+ const int it = ucell.iat2it [iat];
627+ const int ia = ucell.iat2ia [iat];
628+ double it_fact = ucell.atoms [it].ncpp .zv * ModuleBase::e2 * ucell.tpiba * ModuleBase::TWO_PI / ucell.omega * fact;
639629
640- ++iat;
641- ucell.step_iait (&ia, &it);
630+ for (int ig = 0 ; ig < rho_basis->npw ; ++ig)
631+ {
632+ if (ig != rho_basis->ig_gge0 ) // skip G=0
633+ {
634+ const ModuleBase::Vector3<double > gcar = rho_basis->gcar [ig];
635+ const double arg = ModuleBase::TWO_PI * (gcar * ucell.atoms [it].tau [ia]);
636+ double sinp, cosp;
637+ ModuleBase::libm::sincos (arg, &sinp, &cosp);
638+ double sumnb = -cosp * aux[ig].imag () + sinp * aux[ig].real ();
639+ forceion (iat, 0 ) += gcar[0 ] * sumnb;
640+ forceion (iat, 1 ) += gcar[1 ] * sumnb;
641+ forceion (iat, 2 ) += gcar[2 ] * sumnb;
642+ }
642643 }
644+ forceion (iat, 0 ) *= it_fact;
645+ forceion (iat, 1 ) *= it_fact;
646+ forceion (iat, 2 ) *= it_fact;
643647 }
644-
645- // means that the processor contains G=0 term.
648+ }
649+ // means that the processor contains G=0 term.
650+ #pragma omp parallel
651+ {
646652 if (rho_basis->ig_gge0 >= 0 )
647653 {
648654 double rmax = 5.0 / (sqrt (alpha) * ucell.lat0 );
@@ -651,33 +657,29 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
651657 // output of rgen: the number of vectors in the sphere
652658 const int mxr = 200 ;
653659 // the maximum number of R vectors included in r
654- ModuleBase::Vector3<double >* r = new ModuleBase::Vector3<double >[mxr];
655- double * r2 = new double [mxr];
656- ModuleBase::GlobalFunc::ZEROS (r2, mxr);
657- int * irr = new int [mxr];
658- ModuleBase::GlobalFunc::ZEROS (irr, mxr);
660+ std::vector<ModuleBase::Vector3<double >> r (mxr);
661+ std::vector<double > r2 (mxr);
662+ std::vector<int > irr (mxr);
659663 // the square modulus of R_j-tau_s-tau_s'
660664
661- int iat1 = iat_beg;
662- int T1 = it_beg;
663- int I1 = ia_beg;
664665 const double sqa = sqrt (alpha);
665666 const double sq8a_2pi = sqrt (8.0 * alpha / ModuleBase::TWO_PI);
666667
667668 // iterating atoms.
668- // do not need to sync threads because task range of each thread is isolated
669- while ( iat1 < iat_end )
669+ # pragma omp for
670+ for ( int iat1 = 0 ; iat1 < this -> nat ; iat1++ )
670671 {
671- int iat2 = 0 ; // mohan fix bug 2011-06-07
672- int I2 = 0 ;
673- int T2 = 0 ;
674- while (iat2 < this ->nat )
672+ int T1 = ucell.iat2it [iat1];
673+ int I1 = ucell.iat2ia [iat1];
674+ for (int iat2 = 0 ; iat2 < this ->nat ; iat2++)
675675 {
676- if (iat1 != iat2 && ucell.atoms [T2].na != 0 && ucell.atoms [T1].na != 0 )
676+ int T2 = ucell.iat2it [iat2];
677+ int I2 = ucell.iat2ia [iat2];
678+ if (iat1 != iat2)
677679 {
678680 ModuleBase::Vector3<double > d_tau
679681 = ucell.atoms [T1].tau [I1] - ucell.atoms [T2].tau [I2];
680- H_Ewald_pw::rgen (d_tau, rmax, irr, ucell.latvec , ucell.G , r, r2, nrm);
682+ H_Ewald_pw::rgen (d_tau, rmax, irr. data () , ucell.latvec , ucell.G , r. data () , r2. data () , nrm);
681683
682684 for (int n = 0 ; n < nrm; n++)
683685 {
@@ -686,39 +688,24 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
686688 double factor;
687689 {
688690 factor = ucell.atoms [T1].ncpp .zv * ucell.atoms [T2].ncpp .zv
689- * ModuleBase::e2 / (rr * rr)
690- * (erfc (sqa * rr) / rr + sq8a_2pi * ModuleBase::libm::exp (-alpha * rr * rr))
691- * ucell.lat0 ;
691+ * ModuleBase::e2 / (rr * rr)
692+ * (erfc (sqa * rr) / rr + sq8a_2pi * ModuleBase::libm::exp (-alpha * rr * rr))
693+ * ucell.lat0 ;
692694 }
693-
694695 forceion (iat1, 0 ) -= factor * r[n].x ;
695696 forceion (iat1, 1 ) -= factor * r[n].y ;
696697 forceion (iat1, 2 ) -= factor * r[n].z ;
697698 }
698699 }
699- ++iat2;
700- ucell.step_iait (&I2, &T2);
701700 } // atom b
702- ++iat1;
703- ucell.step_iait (&I1, &T1);
704701 } // atom a
705-
706- delete[] r;
707- delete[] r2;
708- delete[] irr;
709702 }
710- #ifdef _OPENMP
711703 }
712- #endif
713-
714704 Parallel_Reduce::reduce_pool (forceion.c , forceion.nr * forceion.nc );
715-
716705 // this->print(GlobalV::ofs_running, "ewald forces", forceion);
717706
718707 ModuleBase::timer::tick (" Forces" , " cal_force_ew" );
719708
720- delete[] aux;
721-
722709 return ;
723710}
724711
0 commit comments