Skip to content

Commit 2ee503e

Browse files
committed
fix bug of thread safety in force and virial computation
1 parent 33b86c3 commit 2ee503e

File tree

2 files changed

+11
-24
lines changed

2 files changed

+11
-24
lines changed

source/op/prod_force.cc

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,13 @@ class ProdForceOp : public OpKernel {
9696
auto force = force_tensor->flat<VALUETYPE>();
9797

9898
// loop over samples
99-
int net_iter = 0;
100-
int in_iter = 0;
101-
int force_iter = 0;
102-
int nlist_iter = 0;
103-
int axis_iter = 0;
104-
10599
#pragma omp parallel for num_threads (num_threads)
106100
for (int kk = 0; kk < nframes; ++kk){
107-
force_iter = kk * nloc * 3;
108-
net_iter = kk * nloc * ndescrpt;
109-
in_iter = kk * nloc * ndescrpt * 12;
110-
nlist_iter = kk * nloc * nnei;
111-
axis_iter = kk * nloc * 4;
101+
int force_iter = kk * nloc * 3;
102+
int net_iter = kk * nloc * ndescrpt;
103+
int in_iter = kk * nloc * ndescrpt * 12;
104+
int nlist_iter = kk * nloc * nnei;
105+
int axis_iter = kk * nloc * 4;
112106

113107
for (int ii = 0; ii < nloc; ++ii){
114108
int i_idx = ii;

source/op/prod_virial.cc

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -103,21 +103,14 @@ class ProdVirialOp : public OpKernel {
103103
auto virial = virial_tensor->flat<VALUETYPE>();
104104

105105
// loop over samples
106-
int net_iter = 0;
107-
int in_iter = 0;
108-
int rij_iter = 0;
109-
int nlist_iter = 0;
110-
int axis_iter = 0;
111-
int virial_iter = 0;
112-
113106
#pragma omp parallel for num_threads (num_threads)
114107
for (int kk = 0; kk < nframes; ++kk){
115-
net_iter = kk * nloc * ndescrpt;
116-
in_iter = kk * nloc * ndescrpt * 12;
117-
rij_iter = kk * nloc * nnei * 3;
118-
nlist_iter = kk * nloc * nnei;
119-
axis_iter = kk * nloc * 4;
120-
virial_iter = kk * 9;
108+
int net_iter = kk * nloc * ndescrpt;
109+
int in_iter = kk * nloc * ndescrpt * 12;
110+
int rij_iter = kk * nloc * nnei * 3;
111+
int nlist_iter = kk * nloc * nnei;
112+
int axis_iter = kk * nloc * 4;
113+
int virial_iter = kk * 9;
121114

122115
for (int ii = 0; ii < 9; ++ ii){
123116
virial (virial_iter + ii) = 0.;

0 commit comments

Comments
 (0)