Skip to content

Commit cbb3522

Browse files
committed
CPU expol gradient implementation
1 parent 6ed07f8 commit cbb3522

File tree

3 files changed

+31
-25
lines changed

3 files changed

+31
-25
lines changed

include/ff/hippo/expol.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace tinker {
66
void expolData(RcOp);
77

88
void alterpol(real (*polscale)[3][3], real (*polinv)[3][3]);
9-
void dexpol(const real (*uind)[3], grad_prec* depx, grad_prec* depy, grad_prec* depz,
9+
void dexpol(const int vers, const real (*uind)[3], grad_prec* depx, grad_prec* depy, grad_prec* depz,
1010
VirialBuffer restrict vir_ep);
1111

1212
enum class ExpolScr

src/acc/hippo/alterpol.cpp

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ void alterpol(real (*polscale)[3][3], real (*polinv)[3][3])
105105
if (r2 <= off2 and incl1 and incl2) {
106106
real r = REAL_SQRT(r2);
107107
real ks2i[3][3], ks2k[3][3];
108-
pair_alterpol(scrtyp, r, r2, 1, cut, off, xr, yr, zr, springi, sizi, alphai, springk, sizk,
108+
pair_alterpol(scrtyp, r, r2, dscale, cut, off, xr, yr, zr, springi, sizi, alphai, springk, sizk,
109109
alphak, ks2i, ks2k);
110110
#pragma acc loop seq
111111
for (int l = 0; l < 3; ++l) {
@@ -139,9 +139,10 @@ void alterpol(real (*polscale)[3][3], real (*polinv)[3][3])
139139
}
140140
}
141141

142-
void dexpol(const real (*uind)[3], grad_prec* depx, grad_prec* depy, grad_prec* depz,
142+
void dexpol(const int vers, const real (*uind)[3], grad_prec* depx, grad_prec* depy, grad_prec* depz,
143143
VirialBuffer restrict vir_ep)
144144
{
145+
auto do_v = vers & calc::virial;
145146
real cut = switchCut(Switch::REPULS);
146147
real off = switchOff(Switch::REPULS);
147148

@@ -157,7 +158,7 @@ void dexpol(const real (*uind)[3], grad_prec* depx, grad_prec* depy, grad_prec*
157158
real xi = x[i];
158159
real yi = y[i];
159160
real zi = z[i];
160-
real springi = kpep[i];
161+
real springi = kpep[i]/polarity[i];
161162
real sizi = prepep[i];
162163
real alphai = dmppep[i];
163164
int epli = lpep[i];
@@ -180,7 +181,7 @@ void dexpol(const real (*uind)[3], grad_prec* depx, grad_prec* depy, grad_prec*
180181
bool incl = (epli || eplk);
181182
if (r2 <= off2 and incl) {
182183
real r = REAL_SQRT(r2);
183-
real springk = kpep[k];
184+
real springk = kpep[k]/polarity[k];
184185
real sizk = prepep[k];
185186
real alphak = dmppep[k];
186187
real ukx = uind[k][0];
@@ -189,22 +190,22 @@ void dexpol(const real (*uind)[3], grad_prec* depx, grad_prec* depy, grad_prec*
189190
real frc[3];
190191
pair_dexpol(scrtyp, r, r2, 1, cut, off, xr, yr, zr, uix, uiy, uiz, ukx, uky, ukz,
191192
springi, sizi, alphai, springk, sizk, alphak, f, frc);
192-
193193
gxi += frc[0];
194194
gyi += frc[1];
195195
gzi += frc[2];
196196
atomic_add(-frc[0], depx, k);
197197
atomic_add(-frc[1], depy, k);
198198
atomic_add(-frc[2], depz, k);
199199

200-
// // add "if CONSTEXPR (do_v)"
201-
// real vxx = -xr * frc[0];
202-
// real vxy = -0.5f * (yr * frc[0] + xr * frc[1]);
203-
// real vxz = -0.5f * (zr * frc[0] + xr * frc[2]);
204-
// real vyy = -yr * frc[1];
205-
// real vyz = -0.5f * (zr * frc[1] + yr * frc[2]);
206-
// real vzz = -zr * frc[2];
207-
// atomic_add(vxx, vxy, vxz, vyy, vyz, vzz, vir_ep, offset);
200+
if (do_v) {
201+
real vxx = -xr * frc[0];
202+
real vxy = -0.5f * (yr * frc[0] + xr * frc[1]);
203+
real vxz = -0.5f * (zr * frc[0] + xr * frc[2]);
204+
real vyy = -yr * frc[1];
205+
real vyz = -0.5f * (zr * frc[1] + yr * frc[2]);
206+
real vzz = -zr * frc[2];
207+
atomic_add(vxx, vxy, vxz, vyy, vyz, vzz, vir_ep, offset);
208+
}
208209
}
209210
}
210211
atomic_add(gxi, depx, i);
@@ -244,7 +245,7 @@ void dexpol(const real (*uind)[3], grad_prec* depx, grad_prec* depy, grad_prec*
244245
if (r2 <= off2 and incl1 and incl2) {
245246
real r = REAL_SQRT(r2);
246247
real frc[3];
247-
pair_dexpol(scrtyp, r, r2, 1, cut, off, xr, yr, zr, uix, uiy, uiz, ukx, uky, ukz, springi,
248+
pair_dexpol(scrtyp, r, r2, dscale, cut, off, xr, yr, zr, uix, uiy, uiz, ukx, uky, ukz, springi,
248249
sizi, alphai, springk, sizk, alphak, f, frc);
249250

250251
atomic_add(frc[0], depx, i);
@@ -253,15 +254,16 @@ void dexpol(const real (*uind)[3], grad_prec* depx, grad_prec* depy, grad_prec*
253254
atomic_add(-frc[0], depx, k);
254255
atomic_add(-frc[1], depy, k);
255256
atomic_add(-frc[2], depz, k);
256-
257-
// // add "if CONSTEXPR (do_v)"
258-
// real vxx = -xr * frc[0];
259-
// real vxy = -0.5f * (yr * frc[0] + xr * frc[1]);
260-
// real vxz = -0.5f * (zr * frc[0] + xr * frc[2]);
261-
// real vyy = -yr * frc[1];
262-
// real vyz = -0.5f * (zr * frc[1] + yr * frc[2]);
263-
// real vzz = -zr * frc[2];
264-
// atomic_add(vxx, vxy, vxz, vyy, vyz, vzz, vir_ep, offset);
257+
258+
if (do_v) {
259+
real vxx = -xr * frc[0];
260+
real vxy = -0.5f * (yr * frc[0] + xr * frc[1]);
261+
real vxz = -0.5f * (zr * frc[0] + xr * frc[2]);
262+
real vyy = -yr * frc[1];
263+
real vyz = -0.5f * (zr * frc[1] + yr * frc[2]);
264+
real vzz = -zr * frc[2];
265+
atomic_add(vxx, vxy, vxz, vyy, vyz, vzz, vir_ep, offset);
266+
}
265267
}
266268
}
267269
}

src/hippo/epolar.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,11 @@ void epolarChgpen(int vers)
238238
if (use_cfgrad)
239239
dcflux(vers, depx, depy, depz, vir_ep);
240240
if (polpot::use_expol)
241-
dexpol(uind, depx, depy, depz, vir_ep);
241+
{
242+
if (do_g || do_v) {
243+
dexpol(vers, uind, depx, depy, depz, vir_ep);
244+
}
245+
}
242246
if (do_v) {
243247
VirialBuffer u2 = vir_trq;
244248
virial_prec v2[9];

0 commit comments

Comments
 (0)