Skip to content

Commit a79b026

Browse files
authored
Perf: add cuda version of cal_force_ew (#6448)
* refactor cal_force_ew for brevity * refactor cal_force_loc * add cuda version of cal_force_ew
1 parent a87d942 commit a79b026

File tree

3 files changed

+267
-135
lines changed

3 files changed

+267
-135
lines changed

source/source_pw/module_pwdft/forces.cpp

Lines changed: 122 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,10 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
368368
// to G space. maybe need fftw with OpenMP
369369
rho_basis->real2recip(aux, aux);
370370

371-
std::vector<double> tau_h;
372-
std::vector<double> gcar_h;
373371
if(this->device == base_device::GpuDevice)
374372
{
373+
std::vector<double> tau_h;
374+
std::vector<double> gcar_h;
375375
tau_h.resize(this->nat * 3);
376376
for(int iat = 0; iat < this->nat; ++iat)
377377
{
@@ -389,16 +389,15 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
389389
gcar_h[ig * 3 + 1] = rho_basis->gcar[ig].y;
390390
gcar_h[ig * 3 + 2] = rho_basis->gcar[ig].z;
391391
}
392-
}
393-
int* iat2it_d = nullptr;
394-
int* ig2gg_d = nullptr;
395-
double* gcar_d = nullptr;
396-
double* tau_d = nullptr;
397-
std::complex<double>* aux_d = nullptr;
398-
double* forcelc_d = nullptr;
399-
double* vloc_d = nullptr;
400-
if(this->device == base_device::GpuDevice)
401-
{
392+
393+
int* iat2it_d = nullptr;
394+
int* ig2gg_d = nullptr;
395+
double* gcar_d = nullptr;
396+
double* tau_d = nullptr;
397+
std::complex<double>* aux_d = nullptr;
398+
double* forcelc_d = nullptr;
399+
double* vloc_d = nullptr;
400+
402401
resmem_int_op()(iat2it_d, this->nat);
403402
resmem_int_op()(ig2gg_d, rho_basis->npw);
404403
resmem_var_op()(gcar_d, rho_basis->npw * 3);
@@ -414,10 +413,7 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
414413
syncmem_complex_h2d_op()(aux_d, aux, rho_basis->npw);
415414
syncmem_var_h2d_op()(forcelc_d, forcelc.c, this->nat * 3);
416415
syncmem_var_h2d_op()(vloc_d, vloc.c, vloc.nr * vloc.nc);
417-
}
418416

419-
if(this->device == base_device::GpuDevice)
420-
{
421417
hamilt::cal_force_loc_op<FPTYPE, Device>()(
422418
this->nat,
423419
rho_basis->npw,
@@ -431,8 +427,16 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
431427
vloc.nc,
432428
forcelc_d);
433429
syncmem_var_d2h_op()(forcelc.c, forcelc_d, this->nat * 3);
430+
431+
delmem_int_op()(iat2it_d);
432+
delmem_int_op()(ig2gg_d);
433+
delmem_var_op()(gcar_d);
434+
delmem_var_op()(tau_d);
435+
delmem_complex_op()(aux_d);
436+
delmem_var_op()(forcelc_d);
437+
delmem_var_op()(vloc_d);
434438
}
435-
else{
439+
else{ // calculate forces on CPU
436440
#ifdef _OPENMP
437441
#pragma omp parallel for
438442
#endif
@@ -457,16 +461,6 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
457461
forcelc(iat, 2) *= (ucell.tpiba * ucell.omega);
458462
}
459463
}
460-
if(this->device == base_device::GpuDevice)
461-
{
462-
delmem_int_op()(iat2it_d);
463-
delmem_int_op()(ig2gg_d);
464-
delmem_var_op()(gcar_d);
465-
delmem_var_op()(tau_d);
466-
delmem_complex_op()(aux_d);
467-
delmem_var_op()(forcelc_d);
468-
delmem_var_op()(vloc_d);
469-
}
470464
// this->print(GlobalV::ofs_running, "local forces", forcelc);
471465
Parallel_Reduce::reduce_pool(forcelc.c, forcelc.nr * forcelc.nc);
472466
delete[] aux;
@@ -482,9 +476,9 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
482476
{
483477
ModuleBase::TITLE("Forces", "cal_force_ew");
484478
ModuleBase::timer::tick("Forces", "cal_force_ew");
485-
479+
this->device = base_device::get_device_type<Device>(this->ctx);
486480
double fact = 2.0;
487-
std::complex<double>* aux = new std::complex<double>[rho_basis->npw];
481+
std::vector<std::complex<double>> aux(rho_basis->npw);
488482

489483
/*
490484
blocking rho_basis->nrxnpwx for data locality.
@@ -494,9 +488,7 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
494488
performance will be better when number of atom is quite huge
495489
*/
496490
const int block_ig = 1024;
497-
#ifdef _OPENMP
498491
#pragma omp parallel for
499-
#endif
500492
for (int igb = 0; igb < rho_basis->npw; igb += block_ig)
501493
{
502494
// calculate the actual task length of this block
@@ -548,9 +540,7 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
548540
* erfc(sqrt(ucell.tpiba2 * rho_basis->ggecut / 4.0 / alpha));
549541
} while (upperbound > 1.0e-6);
550542
const int ig0 = rho_basis->ig_gge0;
551-
#ifdef _OPENMP
552543
#pragma omp parallel for
553-
#endif
554544
for (int ig = 0; ig < rho_basis->npw; ig++)
555545
{
556546
if (ig== ig0)
@@ -566,83 +556,99 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
566556
{
567557
aux[rho_basis->ig_gge0] = std::complex<double>(0.0, 0.0);
568558
}
569-
570-
#ifdef _OPENMP
571-
#pragma omp parallel
559+
if(this->device == base_device::GpuDevice)
572560
{
573-
int num_threads = omp_get_num_threads();
574-
int thread_id = omp_get_thread_num();
575-
#else
576-
int num_threads = 1;
577-
int thread_id = 0;
578-
#endif
579-
580-
/* Here is task distribution for multi-thread,
581-
0. atom will be iterated both in main nat loop and the loop in `if (rho_basis->ig_gge0 >= 0)`.
582-
To avoid syncing, we must calculate work range of each thread by our self
583-
1. Calculate the iat range [iat_beg, iat_end) by each thread
584-
a. when it is single thread stage, [iat_beg, iat_end) will be [0, nat)
585-
2. each thread iterate atoms form `iat_beg` to `iat_end-1`
586-
*/
587-
int iat_beg, iat_end;
588-
int it_beg, ia_beg;
589-
ModuleBase::TASK_DIST_1D(num_threads, thread_id, this->nat, iat_beg, iat_end);
590-
iat_end = iat_beg + iat_end;
591-
ucell.iat2iait(iat_beg, &ia_beg, &it_beg);
592-
593-
int iat = iat_beg;
594-
int it = it_beg;
595-
int ia = ia_beg;
596-
597-
// preprocess ig_gap for skipping the ig point
598-
int ig_gap = (rho_basis->ig_gge0 >= 0 && rho_basis->ig_gge0 < rho_basis->npw) ? rho_basis->ig_gge0 : -1;
599-
600-
double it_fact = 0.;
601-
int last_it = -1;
602-
603-
// iterating atoms
604-
while (iat < iat_end)
561+
std::vector<double> tau_h(this->nat * 3);
562+
std::vector<double> gcar_h(rho_basis->npw * 3);
563+
for(int iat = 0; iat < this->nat; ++iat)
605564
{
606-
if (it != last_it)
607-
{ // calculate it_tact when it is changed
608-
double zv;
609-
{
610-
zv = ucell.atoms[it].ncpp.zv;
611-
}
612-
it_fact = zv * ModuleBase::e2 * ucell.tpiba * ModuleBase::TWO_PI / ucell.omega * fact;
613-
last_it = it;
614-
}
565+
int it = ucell.iat2it[iat];
566+
int ia = ucell.iat2ia[iat];
567+
tau_h[iat * 3] = ucell.atoms[it].tau[ia].x;
568+
tau_h[iat * 3 + 1] = ucell.atoms[it].tau[ia].y;
569+
tau_h[iat * 3 + 2] = ucell.atoms[it].tau[ia].z;
570+
}
571+
for(int ig = 0; ig < rho_basis->npw; ++ig)
572+
{
573+
gcar_h[ig * 3] = rho_basis->gcar[ig].x;
574+
gcar_h[ig * 3 + 1] = rho_basis->gcar[ig].y;
575+
gcar_h[ig * 3 + 2] = rho_basis->gcar[ig].z;
576+
}
577+
std::vector<double> it_fact_h(ucell.ntype);
578+
for(int it = 0; it < ucell.ntype; ++it)
579+
{
580+
it_fact_h[it] = ucell.atoms[it].ncpp.zv * ModuleBase::e2 * ucell.tpiba * ModuleBase::TWO_PI / ucell.omega * fact;
581+
}
615582

616-
if (ucell.atoms[it].na != 0)
617-
{
618-
const auto ig_loop = [&](int ig_beg, int ig_end) {
619-
for (int ig = ig_beg; ig < ig_end; ig++)
620-
{
621-
const ModuleBase::Vector3<double> gcar = rho_basis->gcar[ig];
622-
const double arg = ModuleBase::TWO_PI * (gcar * ucell.atoms[it].tau[ia]);
623-
double sinp, cosp;
624-
ModuleBase::libm::sincos(arg, &sinp, &cosp);
625-
double sumnb = -cosp * aux[ig].imag() + sinp * aux[ig].real();
626-
forceion(iat, 0) += gcar[0] * sumnb;
627-
forceion(iat, 1) += gcar[1] * sumnb;
628-
forceion(iat, 2) += gcar[2] * sumnb;
629-
}
630-
};
583+
int* iat2it_d = nullptr;
584+
double* gcar_d = nullptr;
585+
double* tau_d = nullptr;
586+
double* it_fact_d = nullptr;
587+
std::complex<double>* aux_d = nullptr;
588+
double* forceion_d = nullptr;
589+
resmem_int_op()(iat2it_d, this->nat);
590+
resmem_var_op()(gcar_d, rho_basis->npw * 3);
591+
resmem_var_op()(tau_d, this->nat * 3);
592+
resmem_var_op()(it_fact_d, ucell.ntype);
593+
resmem_complex_op()(aux_d, rho_basis->npw);
594+
resmem_var_op()(forceion_d, this->nat * 3);
631595

632-
// skip ig_gge0 point by separating ig loop into two part
633-
ig_loop(0, ig_gap);
634-
ig_loop(ig_gap + 1, rho_basis->npw);
596+
syncmem_int_h2d_op()(iat2it_d, ucell.iat2it, this->nat);
597+
syncmem_var_h2d_op()(gcar_d, gcar_h.data(), rho_basis->npw * 3);
598+
syncmem_var_h2d_op()(tau_d, tau_h.data(), this->nat * 3);
599+
syncmem_var_h2d_op()(it_fact_d, it_fact_h.data(), ucell.ntype);
600+
syncmem_complex_h2d_op()(aux_d, aux.data(), rho_basis->npw);
601+
syncmem_var_h2d_op()(forceion_d, forceion.c, this->nat * 3);
635602

636-
forceion(iat, 0) *= it_fact;
637-
forceion(iat, 1) *= it_fact;
638-
forceion(iat, 2) *= it_fact;
603+
hamilt::cal_force_ew_op<FPTYPE, Device>()(
604+
this->nat,
605+
rho_basis->npw,
606+
rho_basis->ig_gge0,
607+
iat2it_d,
608+
gcar_d,
609+
tau_d,
610+
it_fact_d,
611+
aux_d,
612+
forceion_d);
613+
614+
syncmem_var_d2h_op()(forceion.c, forceion_d, this->nat * 3);
615+
delmem_int_op()(iat2it_d);
616+
delmem_var_op()(gcar_d);
617+
delmem_var_op()(tau_d);
618+
delmem_var_op()(it_fact_d);
619+
delmem_complex_op()(aux_d);
620+
delmem_var_op()(forceion_d);
621+
} else // calculate forces on CPU
622+
{
623+
#pragma omp parallel for
624+
for(int iat = 0; iat < this->nat; ++iat)
625+
{
626+
const int it = ucell.iat2it[iat];
627+
const int ia = ucell.iat2ia[iat];
628+
double it_fact = ucell.atoms[it].ncpp.zv * ModuleBase::e2 * ucell.tpiba * ModuleBase::TWO_PI / ucell.omega * fact;
639629

640-
++iat;
641-
ucell.step_iait(&ia, &it);
630+
for(int ig = 0; ig < rho_basis->npw; ++ig)
631+
{
632+
if(ig != rho_basis->ig_gge0) // skip G=0
633+
{
634+
const ModuleBase::Vector3<double> gcar = rho_basis->gcar[ig];
635+
const double arg = ModuleBase::TWO_PI * (gcar * ucell.atoms[it].tau[ia]);
636+
double sinp, cosp;
637+
ModuleBase::libm::sincos(arg, &sinp, &cosp);
638+
double sumnb = -cosp * aux[ig].imag() + sinp * aux[ig].real();
639+
forceion(iat, 0) += gcar[0] * sumnb;
640+
forceion(iat, 1) += gcar[1] * sumnb;
641+
forceion(iat, 2) += gcar[2] * sumnb;
642+
}
642643
}
644+
forceion(iat, 0) *= it_fact;
645+
forceion(iat, 1) *= it_fact;
646+
forceion(iat, 2) *= it_fact;
643647
}
644-
645-
// means that the processor contains G=0 term.
648+
}
649+
// means that the processor contains G=0 term.
650+
#pragma omp parallel
651+
{
646652
if (rho_basis->ig_gge0 >= 0)
647653
{
648654
double rmax = 5.0 / (sqrt(alpha) * ucell.lat0);
@@ -651,33 +657,29 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
651657
// output of rgen: the number of vectors in the sphere
652658
const int mxr = 200;
653659
// the maximum number of R vectors included in r
654-
ModuleBase::Vector3<double>* r = new ModuleBase::Vector3<double>[mxr];
655-
double* r2 = new double[mxr];
656-
ModuleBase::GlobalFunc::ZEROS(r2, mxr);
657-
int* irr = new int[mxr];
658-
ModuleBase::GlobalFunc::ZEROS(irr, mxr);
660+
std::vector<ModuleBase::Vector3<double>> r(mxr);
661+
std::vector<double> r2(mxr);
662+
std::vector<int> irr(mxr);
659663
// the square modulus of R_j-tau_s-tau_s'
660664

661-
int iat1 = iat_beg;
662-
int T1 = it_beg;
663-
int I1 = ia_beg;
664665
const double sqa = sqrt(alpha);
665666
const double sq8a_2pi = sqrt(8.0 * alpha / ModuleBase::TWO_PI);
666667

667668
// iterating atoms.
668-
// do not need to sync threads because task range of each thread is isolated
669-
while (iat1 < iat_end)
669+
#pragma omp for
670+
for(int iat1 = 0; iat1 < this->nat; iat1++)
670671
{
671-
int iat2 = 0; // mohan fix bug 2011-06-07
672-
int I2 = 0;
673-
int T2 = 0;
674-
while (iat2 < this->nat)
672+
int T1 = ucell.iat2it[iat1];
673+
int I1 = ucell.iat2ia[iat1];
674+
for(int iat2 = 0; iat2 < this->nat; iat2++)
675675
{
676-
if (iat1 != iat2 && ucell.atoms[T2].na != 0 && ucell.atoms[T1].na != 0)
676+
int T2 = ucell.iat2it[iat2];
677+
int I2 = ucell.iat2ia[iat2];
678+
if (iat1 != iat2)
677679
{
678680
ModuleBase::Vector3<double> d_tau
679681
= ucell.atoms[T1].tau[I1] - ucell.atoms[T2].tau[I2];
680-
H_Ewald_pw::rgen(d_tau, rmax, irr, ucell.latvec, ucell.G, r, r2, nrm);
682+
H_Ewald_pw::rgen(d_tau, rmax, irr.data(), ucell.latvec, ucell.G, r.data(), r2.data(), nrm);
681683

682684
for (int n = 0; n < nrm; n++)
683685
{
@@ -686,39 +688,24 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
686688
double factor;
687689
{
688690
factor = ucell.atoms[T1].ncpp.zv * ucell.atoms[T2].ncpp.zv
689-
* ModuleBase::e2 / (rr * rr)
690-
* (erfc(sqa * rr) / rr + sq8a_2pi * ModuleBase::libm::exp(-alpha * rr * rr))
691-
* ucell.lat0;
691+
* ModuleBase::e2 / (rr * rr)
692+
* (erfc(sqa * rr) / rr + sq8a_2pi * ModuleBase::libm::exp(-alpha * rr * rr))
693+
* ucell.lat0;
692694
}
693-
694695
forceion(iat1, 0) -= factor * r[n].x;
695696
forceion(iat1, 1) -= factor * r[n].y;
696697
forceion(iat1, 2) -= factor * r[n].z;
697698
}
698699
}
699-
++iat2;
700-
ucell.step_iait(&I2, &T2);
701700
} // atom b
702-
++iat1;
703-
ucell.step_iait(&I1, &T1);
704701
} // atom a
705-
706-
delete[] r;
707-
delete[] r2;
708-
delete[] irr;
709702
}
710-
#ifdef _OPENMP
711703
}
712-
#endif
713-
714704
Parallel_Reduce::reduce_pool(forceion.c, forceion.nr * forceion.nc);
715-
716705
// this->print(GlobalV::ofs_running, "ewald forces", forceion);
717706

718707
ModuleBase::timer::tick("Forces", "cal_force_ew");
719708

720-
delete[] aux;
721-
722709
return;
723710
}
724711

0 commit comments

Comments
 (0)