Skip to content

Commit f4bd9a2

Browse files
committed
fix select_map for force and virial
1 parent b664ded commit f4bd9a2

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

source/api_cc/src/DeepTensorPT.cc

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -367,13 +367,11 @@ void DeepTensorPT::compute(std::vector<VALUETYPE>& global_tensor,
367367
cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
368368

369369
// bkw map for forces
370-
// Force tensor has shape [nframes, natoms, coord, dipole_components] = [1, 6,
371-
// 3, 3] We need to map it as [nframes, natoms, total_force_components]
372-
int force_components_per_atom = dforce.size() / nall_real;
373-
force.resize(static_cast<size_t>(nframes) * fwd_map.size() *
374-
force_components_per_atom);
375-
select_map<VALUETYPE>(force, dforce, bkw_map, force_components_per_atom,
376-
nframes, fwd_map.size(), nall_real);
370+
force.resize(static_cast<size_t>(nframes) * odim * fwd_map.size() * 3);
371+
for (int kk = 0; kk < odim; ++kk) {
372+
select_map<VALUETYPE>(force.begin() + kk * fwd_map.size() * 3,
373+
dforce.begin() + kk * bkw_map.size() * 3, bkw_map, 3);
374+
}
377375

378376
// Extract atomic dipoles/polars if available
379377
if (outputs.contains("dipole")) {
@@ -415,14 +413,13 @@ void DeepTensorPT::compute(std::vector<VALUETYPE>& global_tensor,
415413
datom_virial.assign(
416414
cpu_atom_virial_.data_ptr<VALUETYPE>(),
417415
cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());
418-
// extended_virial shape is [nframes, natoms, task_dim, 9] so total
419-
// components is task_dim * 9
420-
int total_virial_components = datom_virial.size() / nall_real;
421-
atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() *
422-
total_virial_components);
423-
select_map<VALUETYPE>(atom_virial, datom_virial, bkw_map,
424-
total_virial_components, nframes, fwd_map.size(),
425-
nall_real);
416+
atom_virial.resize(static_cast<size_t>(nframes) * odim * fwd_map.size() *
417+
9);
418+
for (int kk = 0; kk < odim; ++kk) {
419+
select_map<VALUETYPE>(atom_virial.begin() + kk * fwd_map.size() * 9,
420+
datom_virial.begin() + kk * bkw_map.size() * 9,
421+
bkw_map, 9);
422+
}
426423
}
427424
}
428425

0 commit comments

Comments
 (0)