@@ -329,7 +329,7 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
329329{
330330 ModuleBase::TITLE (" Forces" , " cal_force_loc" );
331331 ModuleBase::timer::tick (" Forces" , " cal_force_loc" );
332-
332+ this -> device = base_device::get_device_type<Device>( this -> ctx );
333333 std::complex <double >* aux = new std::complex <double >[rho_basis->nmaxgr ];
334334 // now, in all pools , the charge are the same,
335335 // so, the force calculated by each pool is equal.
@@ -368,30 +368,105 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
368368 // to G space. maybe need fftw with OpenMP
369369 rho_basis->real2recip (aux, aux);
370370
371- #ifdef _OPENMP
372- #pragma omp parallel for
373- #endif
374- for (int iat = 0 ; iat < this ->nat ; ++iat)
371+ std::vector<double > tau_h;
372+ std::vector<double > gcar_h;
373+ if (this ->device == base_device::GpuDevice)
375374 {
376- // read `it` `ia` from the table
377- int it = ucell.iat2it [iat];
378- int ia = ucell.iat2ia [iat];
379- for (int ig = 0 ; ig < rho_basis->npw ; ig++)
375+ tau_h.resize (this ->nat * 3 );
376+ for (int iat = 0 ; iat < this ->nat ; ++iat)
380377 {
381- const double phase = ModuleBase::TWO_PI * (rho_basis->gcar [ig] * ucell.atoms [it].tau [ia]);
382- double sinp, cosp;
383- ModuleBase::libm::sincos (phase, &sinp, &cosp);
384- const double factor
385- = vloc (it, rho_basis->ig2igg [ig]) * (cosp * aux[ig].imag () + sinp * aux[ig].real ());
386- forcelc (iat, 0 ) += rho_basis->gcar [ig][0 ] * factor;
387- forcelc (iat, 1 ) += rho_basis->gcar [ig][1 ] * factor;
388- forcelc (iat, 2 ) += rho_basis->gcar [ig][2 ] * factor;
378+ int it = ucell.iat2it [iat];
379+ int ia = ucell.iat2ia [iat];
380+ tau_h[iat * 3 ] = ucell.atoms [it].tau [ia].x ;
381+ tau_h[iat * 3 + 1 ] = ucell.atoms [it].tau [ia].y ;
382+ tau_h[iat * 3 + 2 ] = ucell.atoms [it].tau [ia].z ;
389383 }
390- forcelc (iat, 0 ) *= (ucell.tpiba * ucell.omega );
391- forcelc (iat, 1 ) *= (ucell.tpiba * ucell.omega );
392- forcelc (iat, 2 ) *= (ucell.tpiba * ucell.omega );
384+
385+ gcar_h.resize (rho_basis->npw * 3 );
386+ for (int ig = 0 ; ig < rho_basis->npw ; ++ig)
387+ {
388+ gcar_h[ig * 3 ] = rho_basis->gcar [ig].x ;
389+ gcar_h[ig * 3 + 1 ] = rho_basis->gcar [ig].y ;
390+ gcar_h[ig * 3 + 2 ] = rho_basis->gcar [ig].z ;
391+ }
392+ }
393+ int * iat2it_d = nullptr ;
394+ int * ig2gg_d = nullptr ;
395+ double * gcar_d = nullptr ;
396+ double * tau_d = nullptr ;
397+ std::complex <double >* aux_d = nullptr ;
398+ double * forcelc_d = nullptr ;
399+ double * vloc_d = nullptr ;
400+ if (this ->device == base_device::GpuDevice)
401+ {
402+ resmem_int_op ()(iat2it_d, this ->nat );
403+ resmem_int_op ()(ig2gg_d, rho_basis->npw );
404+ resmem_var_op ()(gcar_d, rho_basis->npw * 3 );
405+ resmem_var_op ()(tau_d, this ->nat * 3 );
406+ resmem_complex_op ()(aux_d, rho_basis->npw );
407+ resmem_var_op ()(forcelc_d, this ->nat * 3 );
408+ resmem_var_op ()(vloc_d, vloc.nr * vloc.nc );
409+
410+ syncmem_int_h2d_op ()(iat2it_d, ucell.iat2it , this ->nat );
411+ syncmem_int_h2d_op ()(ig2gg_d, rho_basis->ig2igg , rho_basis->npw );
412+ syncmem_var_h2d_op ()(gcar_d, gcar_h.data (), rho_basis->npw * 3 );
413+ syncmem_var_h2d_op ()(tau_d, tau_h.data (), this ->nat * 3 );
414+ syncmem_complex_h2d_op ()(aux_d, aux, rho_basis->npw );
415+ syncmem_var_h2d_op ()(forcelc_d, forcelc.c , this ->nat * 3 );
416+ syncmem_var_h2d_op ()(vloc_d, vloc.c , vloc.nr * vloc.nc );
393417 }
394418
419+ if (this ->device == base_device::GpuDevice)
420+ {
421+ hamilt::cal_force_loc_op<FPTYPE, Device>()(
422+ this ->nat ,
423+ rho_basis->npw ,
424+ ucell.tpiba * ucell.omega ,
425+ iat2it_d,
426+ ig2gg_d,
427+ gcar_d,
428+ tau_d,
429+ aux_d,
430+ vloc_d,
431+ vloc.nc ,
432+ forcelc_d);
433+ syncmem_var_d2h_op ()(forcelc.c , forcelc_d, this ->nat * 3 );
434+ }
435+ else {
436+ #ifdef _OPENMP
437+ #pragma omp parallel for
438+ #endif
439+ for (int iat = 0 ; iat < this ->nat ; ++iat)
440+ {
441+ // read `it` `ia` from the table
442+ int it = ucell.iat2it [iat];
443+ int ia = ucell.iat2ia [iat];
444+ for (int ig = 0 ; ig < rho_basis->npw ; ig++)
445+ {
446+ const double phase = ModuleBase::TWO_PI * (rho_basis->gcar [ig] * ucell.atoms [it].tau [ia]);
447+ double sinp, cosp;
448+ ModuleBase::libm::sincos (phase, &sinp, &cosp);
449+ const double factor
450+ = vloc (it, rho_basis->ig2igg [ig]) * (cosp * aux[ig].imag () + sinp * aux[ig].real ());
451+ forcelc (iat, 0 ) += rho_basis->gcar [ig][0 ] * factor;
452+ forcelc (iat, 1 ) += rho_basis->gcar [ig][1 ] * factor;
453+ forcelc (iat, 2 ) += rho_basis->gcar [ig][2 ] * factor;
454+ }
455+ forcelc (iat, 0 ) *= (ucell.tpiba * ucell.omega );
456+ forcelc (iat, 1 ) *= (ucell.tpiba * ucell.omega );
457+ forcelc (iat, 2 ) *= (ucell.tpiba * ucell.omega );
458+ }
459+ }
460+ if (this ->device == base_device::GpuDevice)
461+ {
462+ delmem_int_op ()(iat2it_d);
463+ delmem_int_op ()(ig2gg_d);
464+ delmem_var_op ()(gcar_d);
465+ delmem_var_op ()(tau_d);
466+ delmem_complex_op ()(aux_d);
467+ delmem_var_op ()(forcelc_d);
468+ delmem_var_op ()(vloc_d);
469+ }
395470 // this->print(GlobalV::ofs_running, "local forces", forcelc);
396471 Parallel_Reduce::reduce_pool (forcelc.c , forcelc.nr * forcelc.nc );
397472 delete[] aux;
0 commit comments