Skip to content

Commit a7cff68

Browse files
authored
Perf: optimize cuda version of cal_force_cc (#6450)
* refactor cal_force_ew for brevity * add cuda version of cal_force_ew * optimize cuda version of cal_force_cc * a small fix * fix bug and do little optimization
1 parent 320226d commit a7cff68

File tree

5 files changed

+108
-144
lines changed

5 files changed

+108
-144
lines changed

source/source_pw/module_pwdft/forces_cc.cpp

Lines changed: 65 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -108,44 +108,48 @@ void Forces<FPTYPE, Device>::cal_force_cc(ModuleBase::matrix& forcecc,
108108
double* rhocg = new double[rho_basis->ngg];
109109
ModuleBase::GlobalFunc::ZEROS(rhocg, rho_basis->ngg);
110110

111-
std::vector<double> gv_x(rho_basis->npw);
112-
std::vector<double> gv_y(rho_basis->npw);
113-
std::vector<double> gv_z(rho_basis->npw);
111+
std::vector<double> gv_h(3 * rho_basis->npw);
112+
std::vector<double> tau_h(3 * this->nat);
114113
std::vector<double> rhocgigg_vec(rho_basis->npw);
115-
double *gv_x_d = nullptr;
116-
double *gv_y_d = nullptr;
117-
double *gv_z_d = nullptr;
114+
double *gv_d = nullptr;
115+
double *tau_d = nullptr;
118116
double *force_d = nullptr;
119117
double *rhocgigg_vec_d = nullptr;
120118
std::complex<FPTYPE>* psiv_d = nullptr;
121119
this->device = base_device::get_device_type<Device>(this->ctx);
122120

123121

124-
#ifdef _OPENMP
125-
#pragma omp parallel for
126-
#endif
127122
for (int ig = 0; ig < rho_basis->npw; ig++)
128123
{
129-
gv_x[ig] = rho_basis->gcar[ig].x;
130-
gv_y[ig] = rho_basis->gcar[ig].y;
131-
gv_z[ig] = rho_basis->gcar[ig].z;
124+
gv_h[3 * ig] = rho_basis->gcar[ig].x;
125+
gv_h[3 * ig + 1] = rho_basis->gcar[ig].y;
126+
gv_h[3 * ig + 2] = rho_basis->gcar[ig].z;
127+
}
128+
129+
for (int iat = 0; iat < this->nat; iat++)
130+
{
131+
int it = ucell_in.iat2it[iat];
132+
int ia = ucell_in.iat2ia[iat];
133+
tau_h[iat * 3] = ucell_in.atoms[it].tau[ia].x;
134+
tau_h[iat * 3 + 1] = ucell_in.atoms[it].tau[ia].y;
135+
tau_h[iat * 3 + 2] = ucell_in.atoms[it].tau[ia].z;
132136
}
133137

134138
if(this->device == base_device::GpuDevice ) {
135-
resmem_var_op()(gv_x_d, rho_basis->npw);
136-
resmem_var_op()(gv_y_d, rho_basis->npw);
137-
resmem_var_op()(gv_z_d, rho_basis->npw);
139+
resmem_var_op()(gv_d, rho_basis->npw * 3);
140+
resmem_var_op()(tau_d, this->nat * 3);
138141
resmem_var_op()(rhocgigg_vec_d, rho_basis->npw);
139142
resmem_complex_op()(psiv_d, rho_basis->nmaxgr);
140-
resmem_var_op()(force_d, 3);
143+
resmem_var_op()(force_d, 3 * this->nat);
141144

142-
syncmem_var_h2d_op()(gv_x_d, gv_x.data(), rho_basis->npw);
143-
syncmem_var_h2d_op()(gv_y_d, gv_y.data(), rho_basis->npw);
144-
syncmem_var_h2d_op()(gv_z_d, gv_z.data(), rho_basis->npw);
145+
syncmem_var_h2d_op()(gv_d, gv_h.data(), rho_basis->npw * 3);
146+
syncmem_var_h2d_op()(tau_d, tau_h.data(), this->nat * 3);
145147
syncmem_complex_h2d_op()(psiv_d, psiv, rho_basis->nmaxgr);
148+
syncmem_var_h2d_op()(force_d, forcecc.c, 3 * this->nat);
146149
}
147150

148-
151+
double* tau_it_d = tau_d; // the start address of each atom type's tau
152+
double* force_it_d = force_d;
149153
for (int it = 0; it < ucell_in.ntype; ++it)
150154
{
151155
if (ucell_in.atoms[it].ncpp.nlcc)
@@ -166,10 +170,7 @@ void Forces<FPTYPE, Device>::cal_force_cc(ModuleBase::matrix& forcecc,
166170
rho_basis,
167171
1,
168172
ucell_in);
169-
170-
#ifdef _OPENMP
171-
#pragma omp parallel for
172-
#endif
173+
173174
for (int ig = 0; ig < rho_basis->npw; ig++)
174175
{
175176
rhocgigg_vec[ig] = rhocg[rho_basis->ig2igg[ig]];
@@ -178,42 +179,53 @@ void Forces<FPTYPE, Device>::cal_force_cc(ModuleBase::matrix& forcecc,
178179
if(this->device == base_device::GpuDevice ) {
179180
syncmem_var_h2d_op()(rhocgigg_vec_d, rhocgigg_vec.data(), rho_basis->npw);
180181
}
181-
for (int ia = 0; ia < ucell_in.atoms[it].na; ++ia)
182-
{
183-
const ModuleBase::Vector3<double> pos = ucell_in.atoms[it].tau[ia];
184-
// get iat form table
185-
int iat = ucell_in.itia2iat(it, ia);
186-
double force[3] = {0, 0, 0};
187-
188-
if(this->device == base_device::GpuDevice ) {
189-
syncmem_var_h2d_op()(force_d, force, 3);
190-
hamilt::cal_force_npw_op<FPTYPE, Device>()(
191-
psiv_d, gv_x_d, gv_y_d, gv_z_d, rhocgigg_vec_d, force_d, pos.x, pos.y, pos.z,
192-
rho_basis->npw, ucell_in.omega, ucell_in.tpiba
193-
);
194-
syncmem_var_d2h_op()(force, force_d, 3);
195-
196-
} else {
197-
hamilt::cal_force_npw_op<FPTYPE, Device>()(
198-
psiv, gv_x.data(), gv_y.data(), gv_z.data(), rhocgigg_vec.data(), force, pos.x, pos.y, pos.z,
199-
rho_basis->npw, ucell_in.omega, ucell_in.tpiba
200-
);
201-
}
202182

183+
if(this->device == base_device::GpuDevice ) {
184+
hamilt::cal_force_npw_op<FPTYPE, Device>()(
185+
psiv_d, gv_d, rhocgigg_vec_d, force_it_d, tau_it_d,
186+
rho_basis->npw, ucell_in.omega, ucell_in.tpiba, ucell_in.atoms[it].na
187+
);
188+
} else {
189+
#pragma omp for
190+
for(int ia = 0; ia < ucell_in.atoms[it].na; ia++)
203191
{
204-
forcecc(iat, 0) += force[0];
205-
forcecc(iat, 1) += force[1];
206-
forcecc(iat, 2) += force[2];
192+
double fx = 0.0, fy = 0.0, fz = 0.0;
193+
int iat = ucell_in.itia2iat(it, ia);
194+
for (int ig = 0; ig < rho_basis->npw; ig++)
195+
{
196+
const std::complex<double> psiv_conj = conj(psiv[ig]);
197+
198+
const double arg = ModuleBase::TWO_PI * (gv_h[ig * 3] * tau_h[iat * 3]
199+
+ gv_h[ig * 3 + 1] * tau_h[iat * 3 + 1] + gv_h[ig * 3 + 2] * tau_h[iat * 3 + 2]);
200+
double sinp, cosp;
201+
ModuleBase::libm::sincos(arg, &sinp, &cosp);
202+
const std::complex<double> expiarg = std::complex<double>(sinp, cosp);
203+
204+
const std::complex<double> tmp_var = psiv_conj * expiarg * ucell_in.tpiba * ucell_in.omega * rhocgigg_vec[ig];
205+
206+
const std::complex<double> ipol0 = tmp_var * gv_h[ig * 3];
207+
fx += ipol0.real();
208+
209+
const std::complex<double> ipol1 = tmp_var * gv_h[ig * 3 + 1];
210+
fy += ipol1.real();
211+
212+
const std::complex<double> ipol2 = tmp_var * gv_h[ig * 3 + 2];
213+
fz += ipol2.real();
214+
}
215+
forcecc(iat, 0) += fx;
216+
forcecc(iat, 1) += fy;
217+
forcecc(iat, 2) += fz;
207218
}
208219
}
209-
210220
}
221+
tau_it_d += 3 * ucell_in.atoms[it].na; // update the start address of each atom type's tau
222+
force_it_d += 3 * ucell_in.atoms[it].na;
211223
}
212-
if (this->device == base_device::GpuDevice)
224+
if(this->device == base_device::GpuDevice)
213225
{
214-
delmem_var_op()(gv_x_d);
215-
delmem_var_op()(gv_y_d);
216-
delmem_var_op()(gv_z_d);
226+
syncmem_var_d2h_op()(forcecc.c, force_d, 3 * nat);
227+
delmem_var_op()(gv_d);
228+
delmem_var_op()(tau_d);
217229
delmem_var_op()(force_d);
218230
delmem_var_op()(rhocgigg_vec_d);
219231
delmem_complex_op()(psiv_d);

source/source_pw/module_pwdft/kernels/cuda/stress_op.cu

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
#include "source_pw/module_pwdft/kernels/stress_op.h"
2+
#include "source_base/constants.h"
3+
#include "source_base/module_device/device.h"
24
#include "vnl_tools_cu.hpp"
35
#include "source_base/module_device/types.h"
46

57
#include <complex>
68
#include <thrust/complex.h>
79
#include <base/macros/macros.h>
8-
#include <source_base/module_device/device.h>
910

1011
#include <cuda_runtime.h>
1112

@@ -703,43 +704,51 @@ __global__ void cal_stress_drhoc_aux3(
703704
template <typename FPTYPE>
704705
__global__ void cal_force_npw(
705706
const thrust::complex<FPTYPE> *psiv,
706-
const FPTYPE* gv_x, const FPTYPE* gv_y, const FPTYPE* gv_z,
707+
const FPTYPE* gv,
707708
const FPTYPE* rhocgigg_vec,
708709
FPTYPE* force,
709-
const FPTYPE pos_x, const FPTYPE pos_y, const FPTYPE pos_z,
710+
const FPTYPE* tau,
710711
const int npw,
711712
const FPTYPE omega, const FPTYPE tpiba
712713
){
713-
const double TWO_PI = 2.0 * 3.14159265358979323846;
714-
int tid = blockIdx.x * blockDim.x + threadIdx.x;
715-
int begin_idx = tid * 1024;
716-
if(begin_idx > npw) return;
714+
int ia = blockIdx.x;
715+
int tid = threadIdx.x;
716+
if(tid > npw) return;
717717

718+
FPTYPE pos_x = tau[ia * 3];
719+
FPTYPE pos_y = tau[ia * 3 + 1];
720+
FPTYPE pos_z = tau[ia * 3 + 2];
718721
FPTYPE t_force0 = 0;
719722
FPTYPE t_force1 = 0;
720723
FPTYPE t_force2 = 0;
721-
for(int ig = begin_idx; ig<begin_idx+1024 && ig<npw;ig++) {
724+
for(int ig = tid; ig<npw;ig += blockDim.x) {
722725
const thrust::complex<FPTYPE> psiv_conj = conj(psiv[ig]);
723726

724-
const FPTYPE arg = TWO_PI * (gv_x[ig] * pos_x + gv_y[ig] * pos_y + gv_z[ig] * pos_z);
725-
const FPTYPE sinp = sin(arg);
726-
const FPTYPE cosp = cos(arg);
727+
const FPTYPE arg = ModuleBase::TWO_PI * (gv[ig * 3] * pos_x + gv[ig * 3 + 1] * pos_y + gv[ig * 3 + 2] * pos_z);
728+
FPTYPE sinp, cosp;
729+
sincos(arg, &sinp, &cosp);
727730
const thrust::complex<FPTYPE> expiarg = thrust::complex<FPTYPE>(sinp, cosp);
728731

729732
const thrust::complex<FPTYPE> tmp_var = psiv_conj * expiarg * tpiba * omega * rhocgigg_vec[ig];
730733

731-
const thrust::complex<FPTYPE> ipol0 = tmp_var * gv_x[ig];
734+
const thrust::complex<FPTYPE> ipol0 = tmp_var * gv[ig * 3];
732735
t_force0 += ipol0.real();
733736

734-
const thrust::complex<FPTYPE> ipol1 = tmp_var * gv_y[ig];
737+
const thrust::complex<FPTYPE> ipol1 = tmp_var * gv[ig * 3 + 1];
735738
t_force1 += ipol1.real();
736739

737-
const thrust::complex<FPTYPE> ipol2 = tmp_var * gv_z[ig];
740+
const thrust::complex<FPTYPE> ipol2 = tmp_var * gv[ig * 3 + 2];
738741
t_force2 += ipol2.real();
739742
}
740-
atomicAdd(&force[0], t_force0);
741-
atomicAdd(&force[1], t_force1);
742-
atomicAdd(&force[2], t_force2);
743+
__syncwarp();
744+
warp_reduce(t_force0);
745+
warp_reduce(t_force1);
746+
warp_reduce(t_force2);
747+
if (threadIdx.x % WARP_SIZE == 0) {
748+
atomicAdd(&force[ia * 3], t_force0);
749+
atomicAdd(&force[ia * 3 + 1], t_force1);
750+
atomicAdd(&force[ia * 3 + 2], t_force2);
751+
}
743752
}
744753

745754
template <typename FPTYPE>
@@ -880,22 +889,17 @@ void cal_stress_drhoc_aux_op<FPTYPE, base_device::DEVICE_GPU>::operator()(
880889
template <typename FPTYPE>
881890
void cal_force_npw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(
882891
const std::complex<FPTYPE> *psiv,
883-
const FPTYPE* gv_x, const FPTYPE* gv_y, const FPTYPE* gv_z,
892+
const FPTYPE* gv,
884893
const FPTYPE* rhocgigg_vec,
885894
FPTYPE* force,
886-
const FPTYPE pos_x, const FPTYPE pos_y, const FPTYPE pos_z,
895+
const FPTYPE* tau,
887896
const int npw,
888-
const FPTYPE omega, const FPTYPE tpiba
897+
const FPTYPE omega, const FPTYPE tpiba, const int na
889898
)
890899
{
891-
// Divide the npw size range into blocksize 1024 blocks
892-
int t_size = 1024;
893-
int t_num = (npw%t_size) ? (npw/t_size + 1) : (npw/t_size);
894-
dim3 npwgrid(((t_num%THREADS_PER_BLOCK) ? (t_num/THREADS_PER_BLOCK + 1) : (t_num/THREADS_PER_BLOCK)));
895-
896-
cal_force_npw <<< npwgrid, THREADS_PER_BLOCK >>> (
900+
cal_force_npw <<<na, THREADS_PER_BLOCK >>> (
897901
reinterpret_cast<const thrust::complex<FPTYPE>*>(psiv),
898-
gv_x, gv_y, gv_z, rhocgigg_vec, force, pos_x, pos_y, pos_z,
902+
gv, rhocgigg_vec, force, tau,
899903
npw, omega, tpiba
900904
);
901905
return ;

source/source_pw/module_pwdft/kernels/rocm/stress_op.hip.cu

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -844,23 +844,17 @@ void cal_stress_drhoc_aux_op<FPTYPE, base_device::DEVICE_GPU>::operator()(
844844
template <typename FPTYPE>
845845
void cal_force_npw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(
846846
const std::complex<FPTYPE> *psiv,
847-
const FPTYPE* gv_x, const FPTYPE* gv_y, const FPTYPE* gv_z,
847+
const FPTYPE* gv,
848848
const FPTYPE* rhocgigg_vec,
849849
FPTYPE* force,
850-
const FPTYPE pos_x, const FPTYPE pos_y, const FPTYPE pos_z,
850+
const FPTYPE* tau,
851851
const int npw,
852-
const FPTYPE omega, const FPTYPE tpiba
852+
const FPTYPE omega, const FPTYPE tpiba, const int na
853853
)
854854
{
855-
int t_size = 1024;
856-
int t_num = (npw%t_size) ? (npw/t_size + 1) : (npw/t_size);
857-
858-
dim3 npwgrid(((t_num%THREADS_PER_BLOCK) ? (t_num/THREADS_PER_BLOCK + 1) : (t_num/THREADS_PER_BLOCK)));
859-
860-
861-
hipLaunchKernelGGL(HIP_KERNEL_NAME(cal_force_npw<FPTYPE>), npwgrid, THREADS_PER_BLOCK,0,0,
855+
hipLaunchKernelGGL(HIP_KERNEL_NAME(cal_force_npw<FPTYPE>), na, THREADS_PER_BLOCK,0,0,
862856
reinterpret_cast<const thrust::complex<FPTYPE>*>(psiv),
863-
gv_x, gv_y, gv_z, rhocgigg_vec, force, pos_x, pos_y, pos_z,
857+
gv, rhocgigg_vec, force, tau,
864858
npw, omega, tpiba
865859
);
866860

source/source_pw/module_pwdft/kernels/stress_op.cpp

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -629,49 +629,6 @@ struct cal_stress_drhoc_aux_op<FPTYPE, base_device::DEVICE_CPU> {
629629
}
630630
};
631631

632-
633-
template <typename FPTYPE>
634-
struct cal_force_npw_op<FPTYPE, base_device::DEVICE_CPU> {
635-
void operator()(const std::complex<FPTYPE>* psiv,
636-
const FPTYPE* gv_x,
637-
const FPTYPE* gv_y,
638-
const FPTYPE* gv_z,
639-
const FPTYPE* rhocgigg_vec,
640-
FPTYPE* force,
641-
const FPTYPE pos_x,
642-
const FPTYPE pos_y,
643-
const FPTYPE pos_z,
644-
const int npw,
645-
const FPTYPE omega,
646-
const FPTYPE tpiba)
647-
{
648-
649-
#ifdef _OPENMP
650-
#pragma omp for nowait
651-
#endif
652-
for (int ig = 0; ig < npw; ig++)
653-
{
654-
const std::complex<FPTYPE> psiv_conj = conj(psiv[ig]);
655-
656-
const FPTYPE arg = ModuleBase::TWO_PI * (gv_x[ig] * pos_x + gv_y[ig] * pos_y + gv_z[ig] * pos_z);
657-
FPTYPE sinp, cosp;
658-
ModuleBase::libm::sincos(arg, &sinp, &cosp);
659-
const std::complex<FPTYPE> expiarg = std::complex<FPTYPE>(sinp, cosp);
660-
661-
const std::complex<FPTYPE> tmp_var = psiv_conj * expiarg * tpiba * omega * rhocgigg_vec[ig];
662-
663-
const std::complex<FPTYPE> ipol0 = tmp_var * gv_x[ig];
664-
force[0] += ipol0.real();
665-
666-
const std::complex<FPTYPE> ipol1 = tmp_var * gv_y[ig];
667-
force[1] += ipol1.real();
668-
669-
const std::complex<FPTYPE> ipol2 = tmp_var * gv_z[ig];
670-
force[2] += ipol2.real();
671-
}
672-
}
673-
};
674-
675632
template <typename FPTYPE>
676633
struct cal_multi_dot_op<FPTYPE, base_device::DEVICE_CPU> {
677634
FPTYPE operator()(const int& npw,
@@ -768,9 +725,6 @@ template struct cal_vq_deri_op<double, base_device::DEVICE_CPU>;
768725
template struct cal_stress_drhoc_aux_op<float, base_device::DEVICE_CPU>;
769726
template struct cal_stress_drhoc_aux_op<double, base_device::DEVICE_CPU>;
770727

771-
template struct cal_force_npw_op<float, base_device::DEVICE_CPU>;
772-
template struct cal_force_npw_op<double, base_device::DEVICE_CPU>;
773-
774728
template struct cal_multi_dot_op<float, base_device::DEVICE_CPU>;
775729
template struct cal_multi_dot_op<double, base_device::DEVICE_CPU>;
776730

source/source_pw/module_pwdft/kernels/stress_op.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -241,13 +241,13 @@ struct cal_stress_drhoc_aux_op{
241241
template <typename FPTYPE, typename Device>
242242
struct cal_force_npw_op{
243243
void operator()(const std::complex<FPTYPE> *psiv,
244-
const FPTYPE* gv_x, const FPTYPE* gv_y, const FPTYPE* gv_z,
244+
const FPTYPE* gv,
245245
const FPTYPE* rhocgigg_vec,
246246
FPTYPE* force,
247-
const FPTYPE pos_x, const FPTYPE pos_y, const FPTYPE pos_xz,
247+
const FPTYPE* tau_x,
248248
const int npw,
249-
const FPTYPE omega, const FPTYPE tpiba
250-
);
249+
const FPTYPE omega, const FPTYPE tpiba, const int na
250+
) {}
251251
};
252252

253253
template <typename FPTYPE, typename Device>
@@ -480,12 +480,12 @@ struct cal_stress_drhoc_aux_op<FPTYPE, base_device::DEVICE_GPU>{
480480
template <typename FPTYPE>
481481
struct cal_force_npw_op<FPTYPE, base_device::DEVICE_GPU>{
482482
void operator()(const std::complex<FPTYPE> *psiv,
483-
const FPTYPE* gv_x, const FPTYPE* gv_y, const FPTYPE* gv_z,
483+
const FPTYPE* gv,
484484
const FPTYPE* rhocgigg_vec,
485485
FPTYPE* force,
486-
const FPTYPE pos_x, const FPTYPE pos_y, const FPTYPE pos_xz,
486+
const FPTYPE* tau,
487487
const int npw,
488-
const FPTYPE omega, const FPTYPE tpiba
488+
const FPTYPE omega, const FPTYPE tpiba, const int na
489489
);
490490
};
491491

0 commit comments

Comments
 (0)