Skip to content

Commit 37b8c72

Browse files
committed
exchpol openacc gradient and virial implementation
1 parent cbb3522 commit 37b8c72

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

include/seq/pair_alterpol.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ inline void pair_alterpol(ExpolScr scrtyp, real r, real r2, real pscale, real cu
4747
}
4848
}
4949

50+
#pragma acc routine seq
51+
SEQ_CUDA
5052
inline void pair_dexpol(ExpolScr scrtyp, real r, real r2, real pscale, real cut, real off, real xr,
5153
real yr, real zr, real uix, real uiy, real uiz, real ukx, real uky, real ukz, real springi,
5254
real sizi, real alphai, real springk, real sizk, real alphak, const real f, real frc[3])
@@ -62,8 +64,8 @@ inline void pair_dexpol(ExpolScr scrtyp, real r, real r2, real pscale, real cut,
6264
if (r2 > cut2) {
6365
real taper, dtaper;
6466
switchTaper5<1>(r, cut, off, taper, dtaper);
65-
s2 = s2 * taper;
6667
ds2 = ds2 * taper + s2 * dtaper;
68+
s2 = s2 * taper;
6769
}
6870
real s2i = springi * s2 * pscale;
6971
real s2k = springk * s2 * pscale;

src/acc/hippo/alterpol.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ void alterpol(real (*polscale)[3][3], real (*polinv)[3][3])
105105
if (r2 <= off2 and incl1 and incl2) {
106106
real r = REAL_SQRT(r2);
107107
real ks2i[3][3], ks2k[3][3];
108-
pair_alterpol(scrtyp, r, r2, dscale, cut, off, xr, yr, zr, springi, sizi, alphai, springk, sizk,
109-
alphak, ks2i, ks2k);
108+
pair_alterpol(scrtyp, r, r2, dscale, cut, off, xr, yr, zr, springi, sizi, alphai, springk,
109+
sizk, alphak, ks2i, ks2k);
110110
#pragma acc loop seq
111111
for (int l = 0; l < 3; ++l) {
112112
#pragma acc loop seq
@@ -139,8 +139,8 @@ void alterpol(real (*polscale)[3][3], real (*polinv)[3][3])
139139
}
140140
}
141141

142-
void dexpol(const int vers, const real (*uind)[3], grad_prec* depx, grad_prec* depy, grad_prec* depz,
143-
VirialBuffer restrict vir_ep)
142+
void dexpol(const int vers, const real (*uind)[3], grad_prec* depx, grad_prec* depy,
143+
grad_prec* depz, VirialBuffer restrict vir_ep)
144144
{
145145
auto do_v = vers & calc::virial;
146146
real cut = switchCut(Switch::REPULS);
@@ -154,11 +154,15 @@ void dexpol(const int vers, const real (*uind)[3], grad_prec* depx, grad_prec* d
154154

155155
const real f = 0.5f * electric / dielec;
156156

157+
MAYBE_UNUSED int GRID_DIM = gpuGridSize(BLOCK_DIM);
158+
#pragma acc parallel async num_gangs(GRID_DIM) vector_length(BLOCK_DIM)\
159+
deviceptr(x,y,z,polarity,kpep,prepep,dmppep,lpep,uind,depx,depy,depz,vir_ep,mlst,polscale)
160+
#pragma acc loop gang independent
157161
for (int i = 0; i < n; ++i) {
158162
real xi = x[i];
159163
real yi = y[i];
160164
real zi = z[i];
161-
real springi = kpep[i]/polarity[i];
165+
real springi = kpep[i] / polarity[i];
162166
real sizi = prepep[i];
163167
real alphai = dmppep[i];
164168
int epli = lpep[i];
@@ -170,6 +174,7 @@ void dexpol(const int vers, const real (*uind)[3], grad_prec* depx, grad_prec* d
170174

171175
int nmlsti = mlst->nlst[i];
172176
int base = i * maxnlst;
177+
#pragma acc loop vector independent
173178
for (int kk = 0; kk < nmlsti; ++kk) {
174179
int offset = kk & (bufsize - 1);
175180
int k = mlst->lst[base + kk];
@@ -181,7 +186,7 @@ void dexpol(const int vers, const real (*uind)[3], grad_prec* depx, grad_prec* d
181186
bool incl = (epli || eplk);
182187
if (r2 <= off2 and incl) {
183188
real r = REAL_SQRT(r2);
184-
real springk = kpep[k]/polarity[k];
189+
real springk = kpep[k] / polarity[k];
185190
real sizk = prepep[k];
186191
real alphak = dmppep[k];
187192
real ukx = uind[k][0];
@@ -213,6 +218,9 @@ void dexpol(const int vers, const real (*uind)[3], grad_prec* depx, grad_prec* d
213218
atomic_add(gzi, depz, i);
214219
}
215220

221+
#pragma acc parallel loop independent async\
222+
deviceptr(x,y,z,polarity,kpep,prepep,dmppep,lpep,uind,depx,depy,depz,\
223+
vir_ep,mlst,mdwexclude,mdwexclude_scale,polscale)
216224
for (int ii = 0; ii < nmdwexclude; ++ii) {
217225
int offset = ii & (bufsize - 1);
218226

@@ -245,16 +253,16 @@ void dexpol(const int vers, const real (*uind)[3], grad_prec* depx, grad_prec* d
245253
if (r2 <= off2 and incl1 and incl2) {
246254
real r = REAL_SQRT(r2);
247255
real frc[3];
248-
pair_dexpol(scrtyp, r, r2, dscale, cut, off, xr, yr, zr, uix, uiy, uiz, ukx, uky, ukz, springi,
249-
sizi, alphai, springk, sizk, alphak, f, frc);
256+
pair_dexpol(scrtyp, r, r2, dscale, cut, off, xr, yr, zr, uix, uiy, uiz, ukx, uky, ukz,
257+
springi, sizi, alphai, springk, sizk, alphak, f, frc);
250258

251259
atomic_add(frc[0], depx, i);
252260
atomic_add(frc[1], depy, i);
253261
atomic_add(frc[2], depz, i);
254262
atomic_add(-frc[0], depx, k);
255263
atomic_add(-frc[1], depy, k);
256264
atomic_add(-frc[2], depz, k);
257-
265+
258266
if (do_v) {
259267
real vxx = -xr * frc[0];
260268
real vxy = -0.5f * (yr * frc[0] + xr * frc[1]);

0 commit comments

Comments
 (0)