Skip to content

Commit dc258a5

Browse files
authored
Perf: Better OpenMP Parallelization for cal_force_cc and cal_force_scc (#1840)
1 parent 728ca21 commit dc258a5

File tree

1 file changed

+98
-87
lines changed

1 file changed

+98
-87
lines changed

source/module_hamilt_pw/hamilt_pwdft/forces.cpp

Lines changed: 98 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,6 @@ void Forces<FPTYPE, Device>::cal_force_cc(ModuleBase::matrix& forcecc, ModulePW:
727727
{
728728
ModuleBase::TITLE("Forces", "cal_force_cc");
729729
// recalculate the exchange-correlation potential.
730-
ModuleBase::TITLE("Forces", "cal_force_cc");
731730
ModuleBase::timer::tick("Forces", "cal_force_cc");
732731

733732
int total_works = 0;
@@ -800,99 +799,79 @@ void Forces<FPTYPE, Device>::cal_force_cc(ModuleBase::matrix& forcecc, ModulePW:
800799
// to G space
801800
rho_basis->real2recip(psiv, psiv);
802801

803-
#ifdef _OPENMP
804-
#pragma omp parallel
805-
{
806-
int num_threads = omp_get_num_threads();
807-
int thread_id = omp_get_thread_num();
808-
#else
809-
int num_threads = 1;
810-
int thread_id = 0;
811-
#endif
812-
813802
// psiv contains now Vxc(G)
814803
double* rhocg = new double[rho_basis->ngg];
815804
ModuleBase::GlobalFunc::ZEROS(rhocg, rho_basis->ngg);
816805

817-
/* Here is task distribution for multi-thread,
818-
0. For load balancing, we distribute `total_works` instead of `nat`.
819-
So simply using `#pragma omp parallel for` is not suitable
820-
1. Calculate the work range [work_beg, work_end) for each thread
821-
a. in the single-threaded case, [work_beg, work_end) will be [0, total_works)
822-
2. each thread iterates over atoms from `work_beg` to `work_end-1`
823-
*/
824-
int work, work_end;
825-
ModuleBase::TASK_DIST_1D(num_threads, thread_id, total_works, work, work_end);
826-
work_end = work + work_end;
827-
828-
int it = 0;
829-
int ia = 0;
830-
// We have to map the work beginning position to `it` and `ia` beginning position of this thread
831-
for (int work_off = 0; it < GlobalC::ucell.ntype; it++)
806+
for (int it = 0; it < GlobalC::ucell.ntype; ++it)
832807
{
833808
if (GlobalC::ucell.atoms[it].ncpp.nlcc)
834809
{
835-
if (work_off + GlobalC::ucell.atoms[it].na > work)
836-
{
837-
ia = work - work_off;
838-
break;
839-
}
840-
work_off += GlobalC::ucell.atoms[it].na;
841-
}
842-
}
843-
844-
int last_it = -1;
845-
while (work < work_end)
846-
{
847-
if (it != last_it)
848-
{
849-
// call drhoc when `it` is changed
850810
chr->non_linear_core_correction(GlobalC::ppcell.numeric,
851811
GlobalC::ucell.atoms[it].ncpp.msh,
852812
GlobalC::ucell.atoms[it].ncpp.r,
853813
GlobalC::ucell.atoms[it].ncpp.rab,
854814
GlobalC::ucell.atoms[it].ncpp.rho_atc,
855815
rhocg,
856816
rho_basis);
857-
last_it = it;
858-
}
859-
860-
// get iat from table
861-
int iat = GlobalC::ucell.itia2iat(it, ia);
862-
for (int ig = 0; ig < rho_basis->npw; ig++)
863-
{
864-
const ModuleBase::Vector3<double> gv = rho_basis->gcar[ig];
865-
const ModuleBase::Vector3<double> pos = GlobalC::ucell.atoms[it].tau[ia];
866-
const double rhocgigg = rhocg[rho_basis->ig2igg[ig]];
867-
const std::complex<double> psiv_conj = conj(psiv[ig]);
868-
869-
const double arg = ModuleBase::TWO_PI * (gv.x * pos.x + gv.y * pos.y + gv.z * pos.z);
870-
double sinp, cosp;
871-
ModuleBase::libm::sincos(arg, &sinp, &cosp);
872-
const std::complex<double> expiarg = std::complex<double>(sinp, cosp);
817+
#ifdef _OPENMP
818+
#pragma omp parallel
819+
{
820+
#endif
821+
for (int ia = 0; ia < GlobalC::ucell.atoms[it].na; ++ia)
822+
{
823+
// get iat from table
824+
int iat = GlobalC::ucell.itia2iat(it, ia);
825+
double force[3] = {0, 0, 0};
826+
#ifdef _OPENMP
827+
#pragma omp for nowait
828+
#endif
829+
for (int ig = 0; ig < rho_basis->npw; ig++)
830+
{
831+
const ModuleBase::Vector3<double> gv = rho_basis->gcar[ig];
832+
const ModuleBase::Vector3<double> pos = GlobalC::ucell.atoms[it].tau[ia];
833+
const double rhocgigg = rhocg[rho_basis->ig2igg[ig]];
834+
const std::complex<double> psiv_conj = conj(psiv[ig]);
873835

874-
auto ipol0 = GlobalC::ucell.tpiba * GlobalC::ucell.omega * rhocgigg * gv.x * psiv_conj * expiarg;
875-
forcecc(iat, 0) += ipol0.real();
836+
const double arg = ModuleBase::TWO_PI * (gv.x * pos.x + gv.y * pos.y + gv.z * pos.z);
837+
double sinp, cosp;
838+
ModuleBase::libm::sincos(arg, &sinp, &cosp);
839+
const std::complex<double> expiarg = std::complex<double>(sinp, cosp);
876840

877-
auto ipol1 = GlobalC::ucell.tpiba * GlobalC::ucell.omega * rhocgigg * gv.y * psiv_conj * expiarg;
878-
forcecc(iat, 1) += ipol1.real();
841+
auto ipol0 = GlobalC::ucell.tpiba * GlobalC::ucell.omega * rhocgigg * gv.x * psiv_conj * expiarg;
842+
force[0] += ipol0.real();
879843

880-
auto ipol2 = GlobalC::ucell.tpiba * GlobalC::ucell.omega * rhocgigg * gv.z * psiv_conj * expiarg;
881-
forcecc(iat, 2) += ipol2.real();
882-
}
844+
auto ipol1 = GlobalC::ucell.tpiba * GlobalC::ucell.omega * rhocgigg * gv.y * psiv_conj * expiarg;
845+
force[1] += ipol1.real();
883846

884-
++work;
885-
if (GlobalC::ucell.step_ia(it, &ia))
886-
{
887-
// search for next effective `it`
888-
while (!GlobalC::ucell.step_it(&it) && !GlobalC::ucell.atoms[it].ncpp.nlcc);
847+
auto ipol2 = GlobalC::ucell.tpiba * GlobalC::ucell.omega * rhocgigg * gv.z * psiv_conj * expiarg;
848+
force[2] += ipol2.real();
849+
}
850+
#ifdef _OPENMP
851+
if (omp_get_num_threads() > 1)
852+
{
853+
#pragma omp atomic
854+
forcecc(iat, 0) += force[0];
855+
#pragma omp atomic
856+
forcecc(iat, 1) += force[1];
857+
#pragma omp atomic
858+
forcecc(iat, 2) += force[2];
859+
}
860+
else
861+
#endif
862+
{
863+
forcecc(iat, 0) += force[0];
864+
forcecc(iat, 1) += force[1];
865+
forcecc(iat, 2) += force[2];
866+
}
867+
}
868+
#ifdef _OPENMP
869+
} // omp parallel
870+
#endif
889871
}
890872
}
891873

892874
delete[] rhocg;
893-
#ifdef _OPENMP
894-
} // omp parallel
895-
#endif
896875

897876
delete[] psiv; // mohan fix bug 2012-03-22
898877
Parallel_Reduce::reduce_double_pool(forcecc.c, forcecc.nr * forcecc.nc); // qianrui fix a bug for kpar > 1
@@ -1161,24 +1140,30 @@ void Forces<FPTYPE, Device>::cal_force_scc(ModuleBase::matrix& forcescc, ModuleP
11611140
if (rho_basis->gg_uniq[0] < 1.0e-8)
11621141
igg0 = 1;
11631142

1143+
double* rhocgnt = new double[rho_basis->ngg];
1144+
11641145
#ifdef _OPENMP
11651146
#pragma omp parallel
11661147
{
1148+
int num_threads = omp_get_num_threads();
1149+
int thread_id = omp_get_thread_num();
1150+
#else
1151+
int num_threads = 1;
1152+
int thread_id = 0;
11671153
#endif
11681154

11691155
// thread local work space
11701156
double *aux = new double[ndm];
11711157
ModuleBase::GlobalFunc::ZEROS(aux, ndm);
1172-
double* rhocgnt = new double[rho_basis->ngg];
1173-
ModuleBase::GlobalFunc::ZEROS(rhocgnt, rho_basis->ngg);
1158+
1159+
int ig_beg, ig_length, ig_end;
1160+
ModuleBase::TASK_DIST_1D(num_threads, thread_id, rho_basis->ngg, ig_beg, ig_length);
1161+
ModuleBase::GlobalFunc::ZEROS(rhocgnt + ig_beg, ig_length);
1162+
ig_end = ig_beg + ig_length;
11741163

11751164
double fact = 2.0;
11761165
int last_it = -1;
11771166

1178-
#ifdef _OPENMP
1179-
// use no wait to avoid syncing
1180-
#pragma omp for nowait
1181-
#endif
11821167
for (int iat = 0; iat < GlobalC::ucell.nat; ++iat)
11831168
{
11841169
int it = GlobalC::ucell.iat2it[iat];
@@ -1190,7 +1175,7 @@ void Forces<FPTYPE, Device>::cal_force_scc(ModuleBase::matrix& forcescc, ModuleP
11901175
// Here we compute the G.ne.0 term
11911176
const int mesh = GlobalC::ucell.atoms[it].ncpp.msh;
11921177

1193-
for (int ig = igg0; ig < rho_basis->ngg; ++ig)
1178+
for (int ig = std::max(ig_beg, igg0); ig < ig_end; ++ig)
11941179
{
11951180
const double gx = sqrt(rho_basis->gg_uniq[ig]) * GlobalC::ucell.tpiba;
11961181
for (int ir = 0; ir < mesh; ir++)
@@ -1210,13 +1195,20 @@ void Forces<FPTYPE, Device>::cal_force_scc(ModuleBase::matrix& forcescc, ModuleP
12101195

12111196
// record it
12121197
last_it = it;
1198+
1199+
// wait for all threads to finish the rhocgnt update
1200+
#ifdef _OPENMP
1201+
#pragma omp barrier
1202+
#endif
12131203
}
12141204

12151205
const ModuleBase::Vector3<double> pos = GlobalC::ucell.atoms[it].tau[ia];
12161206

1217-
const auto ig_loop = [&](int ig_beg, int ig_end)
1207+
double force[3] = {0, 0, 0};
1208+
1209+
const auto ig_loop = [&](int start, int stop)
12181210
{
1219-
for (int ig = ig_beg; ig < ig_end; ++ig)
1211+
for (int ig = start; ig < stop; ++ig)
12201212
{
12211213
const ModuleBase::Vector3<double> gv = rho_basis->gcar[ig];
12221214
const double rhocgntigg = rhocgnt[GlobalC::rhopw->ig2igg[ig]];
@@ -1225,26 +1217,45 @@ void Forces<FPTYPE, Device>::cal_force_scc(ModuleBase::matrix& forcescc, ModuleP
12251217
ModuleBase::libm::sincos(arg, &sinp, &cosp);
12261218
const std::complex<double> cpm = std::complex<double>(sinp, cosp) * conj(psic[ig]);
12271219

1228-
forcescc(iat, 0) += fact * rhocgntigg * GlobalC::ucell.tpiba * gv.x * cpm.real();
1229-
forcescc(iat, 1) += fact * rhocgntigg * GlobalC::ucell.tpiba * gv.y * cpm.real();
1230-
forcescc(iat, 2) += fact * rhocgntigg * GlobalC::ucell.tpiba * gv.z * cpm.real();
1220+
force[0] += fact * rhocgntigg * GlobalC::ucell.tpiba * gv.x * cpm.real();
1221+
force[1] += fact * rhocgntigg * GlobalC::ucell.tpiba * gv.y * cpm.real();
1222+
force[2] += fact * rhocgntigg * GlobalC::ucell.tpiba * gv.z * cpm.real();
12311223
}
12321224
};
12331225

1234-
ig_loop(0, ig_gap);
1235-
ig_loop(ig_gap + 1, rho_basis->npw);
1226+
ig_loop(ig_beg, std::min(ig_gap, ig_end));
1227+
ig_loop(ig_gap + 1, ig_end);
1228+
1229+
#ifdef _OPENMP
1230+
if (num_threads > 1)
1231+
{
1232+
#pragma omp atomic
1233+
forcescc(iat, 0) += force[0];
1234+
#pragma omp atomic
1235+
forcescc(iat, 1) += force[1];
1236+
#pragma omp atomic
1237+
forcescc(iat, 2) += force[2];
1238+
}
1239+
else
1240+
#endif
1241+
{
1242+
forcescc(iat, 0) += force[0];
1243+
forcescc(iat, 1) += force[1];
1244+
forcescc(iat, 2) += force[2];
1245+
}
12361246

12371247
// std::cout << " forcescc = " << forcescc(iat,0) << " " << forcescc(iat,1) << " " <<
12381248
// forcescc(iat,2) << std::endl;
12391249
}
12401250

12411251
delete[] aux;
1242-
delete[] rhocgnt;
12431252

12441253
#ifdef _OPENMP
12451254
}
12461255
#endif
12471256

1257+
delete[] rhocgnt;
1258+
12481259
Parallel_Reduce::reduce_double_pool(forcescc.c, forcescc.nr * forcescc.nc);
12491260

12501261
delete[] psic; // mohan fix bug 2012-03-22

0 commit comments

Comments
 (0)