Skip to content

Commit acaa1fc

Browse files
authored
Perf: add cuda support for cal_force_loc (#6427)
* add cuda support for cal_force_loc * fix compilation bug * fix bug
1 parent 940a664 commit acaa1fc

File tree

3 files changed

+261
-21
lines changed

3 files changed

+261
-21
lines changed

source/source_pw/module_pwdft/forces.cpp

Lines changed: 95 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
329329
{
330330
ModuleBase::TITLE("Forces", "cal_force_loc");
331331
ModuleBase::timer::tick("Forces", "cal_force_loc");
332-
332+
this->device = base_device::get_device_type<Device>(this->ctx);
333333
std::complex<double>* aux = new std::complex<double>[rho_basis->nmaxgr];
334334
// now, in all pools , the charge are the same,
335335
// so, the force calculated by each pool is equal.
@@ -368,30 +368,105 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
368368
// to G space. maybe need fftw with OpenMP
369369
rho_basis->real2recip(aux, aux);
370370

371-
#ifdef _OPENMP
372-
#pragma omp parallel for
373-
#endif
374-
for (int iat = 0; iat < this->nat; ++iat)
371+
std::vector<double> tau_h;
372+
std::vector<double> gcar_h;
373+
if(this->device == base_device::GpuDevice)
375374
{
376-
// read `it` `ia` from the table
377-
int it = ucell.iat2it[iat];
378-
int ia = ucell.iat2ia[iat];
379-
for (int ig = 0; ig < rho_basis->npw; ig++)
375+
tau_h.resize(this->nat * 3);
376+
for(int iat = 0; iat < this->nat; ++iat)
380377
{
381-
const double phase = ModuleBase::TWO_PI * (rho_basis->gcar[ig] * ucell.atoms[it].tau[ia]);
382-
double sinp, cosp;
383-
ModuleBase::libm::sincos(phase, &sinp, &cosp);
384-
const double factor
385-
= vloc(it, rho_basis->ig2igg[ig]) * (cosp * aux[ig].imag() + sinp * aux[ig].real());
386-
forcelc(iat, 0) += rho_basis->gcar[ig][0] * factor;
387-
forcelc(iat, 1) += rho_basis->gcar[ig][1] * factor;
388-
forcelc(iat, 2) += rho_basis->gcar[ig][2] * factor;
378+
int it = ucell.iat2it[iat];
379+
int ia = ucell.iat2ia[iat];
380+
tau_h[iat * 3] = ucell.atoms[it].tau[ia].x;
381+
tau_h[iat * 3 + 1] = ucell.atoms[it].tau[ia].y;
382+
tau_h[iat * 3 + 2] = ucell.atoms[it].tau[ia].z;
389383
}
390-
forcelc(iat, 0) *= (ucell.tpiba * ucell.omega);
391-
forcelc(iat, 1) *= (ucell.tpiba * ucell.omega);
392-
forcelc(iat, 2) *= (ucell.tpiba * ucell.omega);
384+
385+
gcar_h.resize(rho_basis->npw * 3);
386+
for(int ig = 0; ig < rho_basis->npw; ++ig)
387+
{
388+
gcar_h[ig * 3] = rho_basis->gcar[ig].x;
389+
gcar_h[ig * 3 + 1] = rho_basis->gcar[ig].y;
390+
gcar_h[ig * 3 + 2] = rho_basis->gcar[ig].z;
391+
}
392+
}
393+
int* iat2it_d = nullptr;
394+
int* ig2gg_d = nullptr;
395+
double* gcar_d = nullptr;
396+
double* tau_d = nullptr;
397+
std::complex<double>* aux_d = nullptr;
398+
double* forcelc_d = nullptr;
399+
double* vloc_d = nullptr;
400+
if(this->device == base_device::GpuDevice)
401+
{
402+
resmem_int_op()(iat2it_d, this->nat);
403+
resmem_int_op()(ig2gg_d, rho_basis->npw);
404+
resmem_var_op()(gcar_d, rho_basis->npw * 3);
405+
resmem_var_op()(tau_d, this->nat * 3);
406+
resmem_complex_op()(aux_d, rho_basis->npw);
407+
resmem_var_op()(forcelc_d, this->nat * 3);
408+
resmem_var_op()(vloc_d, vloc.nr * vloc.nc);
409+
410+
syncmem_int_h2d_op()(iat2it_d, ucell.iat2it, this->nat);
411+
syncmem_int_h2d_op()(ig2gg_d, rho_basis->ig2igg, rho_basis->npw);
412+
syncmem_var_h2d_op()(gcar_d, gcar_h.data(), rho_basis->npw * 3);
413+
syncmem_var_h2d_op()(tau_d, tau_h.data(), this->nat * 3);
414+
syncmem_complex_h2d_op()(aux_d, aux, rho_basis->npw);
415+
syncmem_var_h2d_op()(forcelc_d, forcelc.c, this->nat * 3);
416+
syncmem_var_h2d_op()(vloc_d, vloc.c, vloc.nr * vloc.nc);
393417
}
394418

419+
if(this->device == base_device::GpuDevice)
420+
{
421+
hamilt::cal_force_loc_op<FPTYPE, Device>()(
422+
this->nat,
423+
rho_basis->npw,
424+
ucell.tpiba * ucell.omega,
425+
iat2it_d,
426+
ig2gg_d,
427+
gcar_d,
428+
tau_d,
429+
aux_d,
430+
vloc_d,
431+
vloc.nc,
432+
forcelc_d);
433+
syncmem_var_d2h_op()(forcelc.c, forcelc_d, this->nat * 3);
434+
}
435+
else{
436+
#ifdef _OPENMP
437+
#pragma omp parallel for
438+
#endif
439+
for (int iat = 0; iat < this->nat; ++iat)
440+
{
441+
// read `it` `ia` from the table
442+
int it = ucell.iat2it[iat];
443+
int ia = ucell.iat2ia[iat];
444+
for (int ig = 0; ig < rho_basis->npw; ig++)
445+
{
446+
const double phase = ModuleBase::TWO_PI * (rho_basis->gcar[ig] * ucell.atoms[it].tau[ia]);
447+
double sinp, cosp;
448+
ModuleBase::libm::sincos(phase, &sinp, &cosp);
449+
const double factor
450+
= vloc(it, rho_basis->ig2igg[ig]) * (cosp * aux[ig].imag() + sinp * aux[ig].real());
451+
forcelc(iat, 0) += rho_basis->gcar[ig][0] * factor;
452+
forcelc(iat, 1) += rho_basis->gcar[ig][1] * factor;
453+
forcelc(iat, 2) += rho_basis->gcar[ig][2] * factor;
454+
}
455+
forcelc(iat, 0) *= (ucell.tpiba * ucell.omega);
456+
forcelc(iat, 1) *= (ucell.tpiba * ucell.omega);
457+
forcelc(iat, 2) *= (ucell.tpiba * ucell.omega);
458+
}
459+
}
460+
if(this->device == base_device::GpuDevice)
461+
{
462+
delmem_int_op()(iat2it_d);
463+
delmem_int_op()(ig2gg_d);
464+
delmem_var_op()(gcar_d);
465+
delmem_var_op()(tau_d);
466+
delmem_complex_op()(aux_d);
467+
delmem_var_op()(forcelc_d);
468+
delmem_var_op()(vloc_d);
469+
}
395470
// this->print(GlobalV::ofs_running, "local forces", forcelc);
396471
Parallel_Reduce::reduce_pool(forcelc.c, forcelc.nr * forcelc.nc);
397472
delete[] aux;

source/source_pw/module_pwdft/kernels/cuda/force_op.cu

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "source_pw/module_pwdft/kernels/force_op.h"
22
// #include "source_psi/kernels/device.h"
33
#include "source_base/module_device/types.h"
4+
#include "source_base/constants.h"
45

56
#include <complex>
67

@@ -10,9 +11,19 @@
1011
#include <source_base/module_device/device.h>
1112

1213
#define THREADS_PER_BLOCK 256
14+
#define FULL_MASK 0xffffffff
15+
#define WARP_SIZE 32
1316

1417
namespace hamilt {
1518

19+
template <typename FPTYPE>
20+
__forceinline__
21+
__device__
22+
void warp_reduce(FPTYPE & val) {
23+
for (int offset = 16; offset > 0; offset >>= 1) {
24+
val += __shfl_down_sync(FULL_MASK, val, offset);
25+
}
26+
}
1627

1728
template <typename FPTYPE>
1829
__global__ void cal_vkb1_nl(
@@ -606,15 +617,139 @@ void revertVkbValues(
606617
cudaCheckOnDebug();
607618
}
608619

620+
template <typename FPTYPE>
621+
__global__ void force_loc_kernel(
622+
const int nat,
623+
const int npw,
624+
const FPTYPE tpiba_omega,
625+
const int* iat2it,
626+
const int* ig2gg_d,
627+
const FPTYPE* gcar_d,
628+
const FPTYPE* tau_d,
629+
const thrust::complex<FPTYPE>* aux_d,
630+
const FPTYPE* vloc_d,
631+
const int vloc_nc,
632+
FPTYPE* forcelc_d)
633+
{
634+
const int iat = blockIdx.x;
635+
const int tid = threadIdx.x;
636+
const int warp_id = tid / WARP_SIZE;
637+
const int lane_id = tid % WARP_SIZE;
638+
639+
if (iat >= nat) return;
640+
const int it = iat2it[iat]; // get the type of atom
641+
642+
// Initialize force components
643+
FPTYPE force_x = 0.0;
644+
FPTYPE force_y = 0.0;
645+
FPTYPE force_z = 0.0;
646+
647+
const auto tau_x = tau_d[iat * 3 + 0];
648+
const auto tau_y = tau_d[iat * 3 + 1];
649+
const auto tau_z = tau_d[iat * 3 + 2];
650+
651+
// Process all plane waves in chunks of blockDim.x
652+
for (int ig = tid; ig < npw; ig += blockDim.x) {
653+
const auto gcar_x = gcar_d[ig * 3 + 0];
654+
const auto gcar_y = gcar_d[ig * 3 + 1];
655+
const auto gcar_z = gcar_d[ig * 3 + 2];
656+
657+
// Calculate phase factor
658+
const FPTYPE phase = ModuleBase::TWO_PI * (gcar_x * tau_x +
659+
gcar_y * tau_y +
660+
gcar_z * tau_z);
661+
FPTYPE sinp, cosp;
662+
sincos(phase, &sinp, &cosp);
663+
664+
// Get vloc value
665+
const FPTYPE vloc_val = vloc_d[it * vloc_nc + ig2gg_d[ig]];
666+
667+
// Calculate factor
668+
const auto aux_val = aux_d[ig];
669+
const FPTYPE factor = vloc_val * (cosp * aux_val.imag() + sinp * aux_val.real());
670+
671+
// Multiply by gcar components
672+
force_x += gcar_x * factor;
673+
force_y += gcar_y * factor;
674+
force_z += gcar_z * factor;
675+
}
676+
677+
// Warp-level reduction
678+
warp_reduce<FPTYPE>(force_x);
679+
warp_reduce<FPTYPE>(force_y);
680+
warp_reduce<FPTYPE>(force_z);
681+
682+
// First thread in each warp writes to shared memory
683+
__shared__ FPTYPE warp_sums_x[THREADS_PER_BLOCK / WARP_SIZE]; // 256 threads / 32 = 8 warps
684+
__shared__ FPTYPE warp_sums_y[THREADS_PER_BLOCK / WARP_SIZE];
685+
__shared__ FPTYPE warp_sums_z[THREADS_PER_BLOCK / WARP_SIZE];
686+
687+
if (lane_id == 0) {
688+
warp_sums_x[warp_id] = force_x;
689+
warp_sums_y[warp_id] = force_y;
690+
warp_sums_z[warp_id] = force_z;
691+
}
692+
693+
__syncthreads();
694+
695+
// Final reduction by first warp
696+
if (warp_id == 0) {
697+
FPTYPE final_x = (lane_id < blockDim.x/WARP_SIZE) ? warp_sums_x[lane_id] : 0.0;
698+
FPTYPE final_y = (lane_id < blockDim.x/WARP_SIZE) ? warp_sums_y[lane_id] : 0.0;
699+
FPTYPE final_z = (lane_id < blockDim.x/WARP_SIZE) ? warp_sums_z[lane_id] : 0.0;
700+
701+
warp_reduce<FPTYPE>(final_x);
702+
warp_reduce<FPTYPE>(final_y);
703+
warp_reduce<FPTYPE>(final_z);
704+
705+
if (lane_id == 0) {
706+
forcelc_d[iat * 3 + 0] = final_x * tpiba_omega;
707+
forcelc_d[iat * 3 + 1] = final_y * tpiba_omega;
708+
forcelc_d[iat * 3 + 2] = final_z * tpiba_omega;
709+
}
710+
}
711+
}
712+
713+
template <typename FPTYPE>
714+
void cal_force_loc_op<FPTYPE, base_device::DEVICE_GPU>::operator()(
715+
const int nat,
716+
const int npw,
717+
const FPTYPE tpiba_omega,
718+
const int* iat2it,
719+
const int* ig2igg,
720+
const FPTYPE* gcar,
721+
const FPTYPE* tau,
722+
const std::complex<FPTYPE>* aux,
723+
const FPTYPE* vloc,
724+
const int vloc_nc,
725+
FPTYPE* forcelc)
726+
{
727+
force_loc_kernel<FPTYPE>
728+
<<<nat, THREADS_PER_BLOCK>>>(nat,
729+
npw,
730+
tpiba_omega,
731+
iat2it,
732+
ig2igg,
733+
gcar,
734+
tau,
735+
reinterpret_cast<const thrust::complex<FPTYPE>*>(aux),
736+
vloc,
737+
vloc_nc,
738+
forcelc); // array of data
739+
740+
}
741+
742+
609743
// for revertVkbValues functions instantiation
610744
template void revertVkbValues<double>(const int *gcar_zero_ptrs, std::complex<double> *vkb_ptr, const std::complex<double> *vkb_save_ptr, int nkb, int gcar_zero_count, int npw, int ipol, int npwx, const std::complex<double> coeff);
611745
// for saveVkbValues functions instantiation
612746
template void saveVkbValues<double>(const int *gcar_zero_ptrs, const std::complex<double> *vkb_ptr, std::complex<double> *vkb_save_ptr, int nkb, int gcar_zero_count, int npw, int ipol, int npwx);
613747

614748
template struct cal_vkb1_nl_op<float, base_device::DEVICE_GPU>;
615749
template struct cal_force_nl_op<float, base_device::DEVICE_GPU>;
750+
template struct cal_force_loc_op<float, base_device::DEVICE_GPU>;
616751

617752
template struct cal_vkb1_nl_op<double, base_device::DEVICE_GPU>;
618753
template struct cal_force_nl_op<double, base_device::DEVICE_GPU>;
619-
754+
template struct cal_force_loc_op<double, base_device::DEVICE_GPU>;
620755
} // namespace hamilt

source/source_pw/module_pwdft/kernels/force_op.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,21 @@ struct cal_force_nl_op
149149
FPTYPE* force);
150150
};
151151

152+
/**
 * @brief Functor interface for the local-pseudopotential force computation.
 *
 * Primary template: a deliberate no-op. It exists so that code templated on
 * `Device` compiles for every device type; only device-specific
 * specializations do real work. Calling this version leaves `forcelc`
 * untouched.
 *
 * @tparam FPTYPE  floating-point type of the data
 * @tparam Device  device tag type
 *
 * @param nat          number of atoms
 * @param npw          number of plane waves
 * @param tpiba_omega  scale factor applied to the accumulated force
 * @param iat2it       [nat] atom index -> atom type index
 * @param ig2igg       [npw] plane-wave index -> column index into vloc
 * @param gcar         [npw * 3] G vectors, packed (x, y, z)
 * @param tau          [nat * 3] atomic positions, packed (x, y, z)
 * @param aux          [npw] complex coefficients
 * @param vloc         row-major table addressed as vloc[it * vloc_nc + igg]
 * @param vloc_nc      row length (number of columns) of vloc
 * @param forcelc      [nat * 3] force components
 */
template <typename FPTYPE, typename Device>
struct cal_force_loc_op
{
    void operator()(const int nat,
                    const int npw,
                    const FPTYPE tpiba_omega,
                    const int* iat2it,
                    const int* ig2igg,
                    const FPTYPE* gcar,
                    const FPTYPE* tau,
                    const std::complex<FPTYPE>* aux,
                    const FPTYPE* vloc,
                    const int vloc_nc,
                    FPTYPE* forcelc)
    {
        // Intentionally empty: generic fallback.
    }
};
152167
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
153168
template <typename FPTYPE>
154169
struct cal_vkb1_nl_op<FPTYPE, base_device::DEVICE_GPU>
@@ -275,6 +290,21 @@ void saveVkbValues(const int* gcar_zero_ptrs,
275290
int ipol,
276291
int npwx);
277292

293+
template <typename FPTYPE>
294+
struct cal_force_loc_op<FPTYPE, base_device::DEVICE_GPU>{
295+
void operator()(
296+
const int nat,
297+
const int npw,
298+
const FPTYPE tpiba_omega,
299+
const int* iat2it,
300+
const int* ig2igg,
301+
const FPTYPE* gcar,
302+
const FPTYPE* tau,
303+
const std::complex<FPTYPE>* aux,
304+
const FPTYPE* vloc,
305+
const int vloc_nr,
306+
FPTYPE* forcelc);
307+
};
278308
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
279309
} // namespace hamilt
280310
#endif // W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H

0 commit comments

Comments
 (0)