Skip to content

Commit 6148d7b

Browse files
committed
add stress_ewa_op for cpu
1 parent 808af53 commit 6148d7b

File tree

3 files changed

+151
-15
lines changed

3 files changed

+151
-15
lines changed

source/module_hamilt_pw/hamilt_pwdft/kernels/stress_op.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,65 @@ template struct cal_force_npw_op<double, base_device::DEVICE_CPU>;
759759
// Explicit instantiations of cal_multi_dot_op for float and double on the CPU device.
template struct cal_multi_dot_op<float, base_device::DEVICE_CPU>;
760760
template struct cal_multi_dot_op<double, base_device::DEVICE_CPU>;
761761

762+
// CPU implementation of Ewald stress sincos operator
763+
template <typename FPTYPE>
764+
struct cal_stress_ewa_sincos_op<FPTYPE, base_device::DEVICE_CPU>
765+
{
766+
void operator()(const base_device::DEVICE_CPU* ctx,
767+
const int& nat,
768+
const int& npw,
769+
const int& ig_gge0,
770+
const FPTYPE* gcar,
771+
const FPTYPE* tau,
772+
const FPTYPE* zv_facts,
773+
FPTYPE* rhostar_real,
774+
FPTYPE* rhostar_imag)
775+
{
776+
const FPTYPE TWO_PI = 2.0 * M_PI;
777+
778+
// Initialize output arrays
779+
std::fill(rhostar_real, rhostar_real + npw, static_cast<FPTYPE>(0.0));
780+
std::fill(rhostar_imag, rhostar_imag + npw, static_cast<FPTYPE>(0.0));
781+
782+
#ifdef _OPENMP
783+
#pragma omp parallel for
784+
#endif
785+
for (int ig = 0; ig < npw; ig++) {
786+
if (ig == ig_gge0) continue; // Skip G=0
787+
788+
FPTYPE local_rhostar_real = 0.0;
789+
FPTYPE local_rhostar_imag = 0.0;
790+
791+
// Double loop: iat -> ig (as requested)
792+
for (int iat = 0; iat < nat; iat++) {
793+
const FPTYPE tau_x = tau[iat * 3 + 0];
794+
const FPTYPE tau_y = tau[iat * 3 + 1];
795+
const FPTYPE tau_z = tau[iat * 3 + 2];
796+
const FPTYPE zv = zv_facts[iat];
797+
798+
// Calculate phase: 2π * (G · τ) - similar to cal_force_ewa phase
799+
const FPTYPE phase = TWO_PI * (gcar[ig * 3 + 0] * tau_x +
800+
gcar[ig * 3 + 1] * tau_y +
801+
gcar[ig * 3 + 2] * tau_z);
802+
803+
// Calculate sincos
804+
FPTYPE sinp, cosp;
805+
ModuleBase::libm::sincos(phase, &sinp, &cosp);
806+
807+
// Accumulate structure factor
808+
local_rhostar_real += zv * cosp;
809+
local_rhostar_imag += zv * sinp;
810+
}
811+
812+
// Store results
813+
rhostar_real[ig] = local_rhostar_real;
814+
rhostar_imag[ig] = local_rhostar_imag;
815+
}
816+
}
817+
};
818+
819+
template struct cal_stress_ewa_sincos_op<float, base_device::DEVICE_CPU>;
820+
template struct cal_stress_ewa_sincos_op<double, base_device::DEVICE_CPU>;
762821

763822
// template struct prepare_vkb_deri_ptr_op<float, base_device::DEVICE_CPU>;
764823
// template struct prepare_vkb_deri_ptr_op<double, base_device::DEVICE_CPU>;

source/module_hamilt_pw/hamilt_pwdft/kernels/stress_op.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,35 @@ struct cal_multi_dot_op{
260260
const std::complex<FPTYPE>* psi);
261261
};
262262

263+
template <typename FPTYPE, typename Device>
264+
struct cal_stress_ewa_sincos_op
265+
{
266+
/// @brief Accumulate the atom-weighted structure factor (rhostar) used by the Ewald stress.
267+
/// In the CPU implementation: rhostar[ig] = sum_iat zv_facts[iat] * exp(i * 2*pi * (G_ig . tau_iat)).
268+
/// Only rhostar is computed; it is not normalized here, the remaining stress terms stay in the caller.
269+
/// Input Parameters
270+
/// @param ctx - which device this function runs on
271+
/// @param nat - total number of atoms
272+
/// @param npw - number of plane waves
273+
/// @param ig_gge0 - index of G=0 vector (-1 if not present); the CPU implementation writes zero at that index
274+
/// @param gcar - G-vector Cartesian coordinates, flattened as [ig * 3 + xyz], npw * 3 values
275+
/// @param tau - atomic positions, flattened as [iat * 3 + xyz], nat * 3 values
276+
/// @param zv_facts - precomputed zv factor for each atom, multiplies that atom's phase factor [nat]
277+
///
278+
/// Output Parameters
279+
/// @param rhostar_real - real part (cosine sum) of structure factor [npw]
280+
/// @param rhostar_imag - imaginary part (sine sum) of structure factor [npw]
281+
void operator()(const Device* ctx,
282+
const int& nat,
283+
const int& npw,
284+
const int& ig_gge0,
285+
const FPTYPE* gcar,
286+
const FPTYPE* tau,
287+
const FPTYPE* zv_facts,
288+
FPTYPE* rhostar_real,
289+
FPTYPE* rhostar_imag);
290+
};
291+
263292

264293
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
265294
template <typename FPTYPE>
@@ -434,6 +463,20 @@ struct cal_multi_dot_op<FPTYPE, base_device::DEVICE_GPU>{
434463
const std::complex<FPTYPE>* psi);
435464
};
436465

466+
/// GPU specialization of cal_stress_ewa_sincos_op; only the declaration is given here.
/// The parameter list mirrors the generic operator: nat/npw/ig_gge0 sizes and index,
/// gcar/tau/zv_facts inputs, rhostar_real/rhostar_imag outputs.
/// NOTE(review): presumably the pointer arguments must be device-accessible; the caller in
/// stress_func_ewa.cpp passes std::vector::data() buffers - confirm against the GPU kernel.
template <typename FPTYPE>
467+
struct cal_stress_ewa_sincos_op<FPTYPE, base_device::DEVICE_GPU>
468+
{
469+
void operator()(const base_device::DEVICE_GPU* ctx,
470+
const int& nat,
471+
const int& npw,
472+
const int& ig_gge0,
473+
const FPTYPE* gcar,
474+
const FPTYPE* tau,
475+
const FPTYPE* zv_facts,
476+
FPTYPE* rhostar_real,
477+
FPTYPE* rhostar_imag);
478+
};
479+
437480
/**
438481
* The operator is used to compute the auxiliary amount of stress /force
439482
* in parallel on the GPU. They identify type with the type provided and

source/module_hamilt_pw/hamilt_pwdft/stress_func_ewa.cpp

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "module_base/tool_threading.h"
55
#include "module_base/libm/libm.h"
66
#include "module_hamilt_pw/hamilt_pwdft/global.h"
7+
#include "kernels/stress_op.h"
78

89
#ifdef _OPENMP
910
#include <omp.h>
@@ -56,6 +57,45 @@ void Stress_Func<FPTYPE, Device>::stress_ewa(const UnitCell& ucell,
5657
if (PARAM.globalv.gamma_only_pw && is_pw) fact=2.0;
5758
// else fact=1.0;
5859

60+
// Prepare data for the sincos op
61+
std::vector<FPTYPE> zv_facts_host(ucell.nat);
62+
std::vector<FPTYPE> tau_flat(ucell.nat * 3);
63+
64+
for (int iat = 0; iat < ucell.nat; iat++) {
65+
int it = ucell.iat2it[iat];
66+
int ia = ucell.iat2ia[iat];
67+
68+
zv_facts_host[iat] = static_cast<FPTYPE>(ucell.atoms[it].ncpp.zv);
69+
70+
tau_flat[iat * 3 + 0] = static_cast<FPTYPE>(ucell.atoms[it].tau[ia][0]);
71+
tau_flat[iat * 3 + 1] = static_cast<FPTYPE>(ucell.atoms[it].tau[ia][1]);
72+
tau_flat[iat * 3 + 2] = static_cast<FPTYPE>(ucell.atoms[it].tau[ia][2]);
73+
}
74+
75+
std::vector<FPTYPE> gcar_flat(rho_basis->npw * 3);
76+
for (int ig = 0; ig < rho_basis->npw; ig++) {
77+
gcar_flat[ig * 3 + 0] = static_cast<FPTYPE>(rho_basis->gcar[ig][0]);
78+
gcar_flat[ig * 3 + 1] = static_cast<FPTYPE>(rho_basis->gcar[ig][1]);
79+
gcar_flat[ig * 3 + 2] = static_cast<FPTYPE>(rho_basis->gcar[ig][2]);
80+
}
81+
82+
// Allocate result arrays
83+
std::vector<FPTYPE> rhostar_real_host(rho_basis->npw);
84+
std::vector<FPTYPE> rhostar_imag_host(rho_basis->npw);
85+
86+
// Call sincos op (outside OpenMP parallel region, op has its own parallelization)
87+
hamilt::cal_stress_ewa_sincos_op<FPTYPE, Device>()(
88+
this->ctx,
89+
ucell.nat,
90+
rho_basis->npw,
91+
rho_basis->ig_gge0,
92+
gcar_flat.data(),
93+
tau_flat.data(),
94+
zv_facts_host.data(),
95+
rhostar_real_host.data(),
96+
rhostar_imag_host.data()
97+
);
98+
5999
#ifdef _OPENMP
60100
#pragma omp parallel
61101
{
@@ -76,27 +116,21 @@ void Stress_Func<FPTYPE, Device>::stress_ewa(const UnitCell& ucell,
76116
ig_end = ig + ig_end;
77117

78118
FPTYPE g2,g2a;
79-
FPTYPE arg;
80-
std::complex<FPTYPE> rhostar;
81119
FPTYPE sewald;
82120
for(; ig < ig_end; ig++)
83121
{
84122
if(ig == ig0) continue;
85123
g2 = rho_basis->gg[ig]* ucell.tpiba2;
86124
g2a = g2 /4.0/alpha;
87-
rhostar=std::complex<FPTYPE>(0.0,0.0);
88-
for(int it=0; it < ucell.ntype; it++)
89-
{
90-
for(int i=0; i<ucell.atoms[it].na; i++)
91-
{
92-
arg = (rho_basis->gcar[ig] * ucell.atoms[it].tau[i]) * (ModuleBase::TWO_PI);
93-
FPTYPE sinp, cosp;
94-
ModuleBase::libm::sincos(arg, &sinp, &cosp);
95-
rhostar = rhostar + std::complex<FPTYPE>(ucell.atoms[it].ncpp.zv * cosp,ucell.atoms[it].ncpp.zv * sinp);
96-
}
97-
}
98-
rhostar /= ucell.omega;
99-
sewald = fact* (ModuleBase::TWO_PI) * ModuleBase::e2 * ModuleBase::libm::exp(-g2a) / g2 * pow(std::abs(rhostar),2);
125+
126+
// Use precomputed rhostar values
127+
FPTYPE rhostar_real = rhostar_real_host[ig] / ucell.omega;
128+
FPTYPE rhostar_imag = rhostar_imag_host[ig] / ucell.omega;
129+
130+
// Calculate |rhostar|² - mathematically equivalent to pow(std::abs(rhostar), 2)
131+
FPTYPE rhostar_abs2 = rhostar_real * rhostar_real + rhostar_imag * rhostar_imag;
132+
133+
sewald = fact* (ModuleBase::TWO_PI) * ModuleBase::e2 * ModuleBase::libm::exp(-g2a) / g2 * rhostar_abs2;
100134
local_sdewald -= sewald;
101135
for(int l=0;l<3;l++)
102136
{

0 commit comments

Comments
 (0)