@@ -367,13 +367,11 @@ void DeepTensorPT::compute(std::vector<VALUETYPE>& global_tensor,
367367 cpu_virial_.data_ptr <VALUETYPE>() + cpu_virial_.numel ());
368368
369369 // bkw map for forces
370- // Force tensor has shape [nframes, natoms, coord, dipole_components] = [1, 6,
371- // 3, 3] We need to map it as [nframes, natoms, total_force_components]
372- int force_components_per_atom = dforce.size () / nall_real;
373- force.resize (static_cast <size_t >(nframes) * fwd_map.size () *
374- force_components_per_atom);
375- select_map<VALUETYPE>(force, dforce, bkw_map, force_components_per_atom,
376- nframes, fwd_map.size (), nall_real);
370+ force.resize (static_cast <size_t >(nframes) * odim * fwd_map.size () * 3 );
371+ for (int kk = 0 ; kk < odim; ++kk) {
372+ select_map<VALUETYPE>(force.begin () + kk * fwd_map.size () * 3 ,
373+ dforce.begin () + kk * bkw_map.size () * 3 , bkw_map, 3 );
374+ }
377375
378376 // Extract atomic dipoles/polars if available
379377 if (outputs.contains (" dipole" )) {
@@ -415,14 +413,13 @@ void DeepTensorPT::compute(std::vector<VALUETYPE>& global_tensor,
415413 datom_virial.assign (
416414 cpu_atom_virial_.data_ptr <VALUETYPE>(),
417415 cpu_atom_virial_.data_ptr <VALUETYPE>() + cpu_atom_virial_.numel ());
418- // extended_virial shape is [nframes, natoms, task_dim, 9] so total
419- // components is task_dim * 9
420- int total_virial_components = datom_virial.size () / nall_real;
421- atom_virial.resize (static_cast <size_t >(nframes) * fwd_map.size () *
422- total_virial_components);
423- select_map<VALUETYPE>(atom_virial, datom_virial, bkw_map,
424- total_virial_components, nframes, fwd_map.size (),
425- nall_real);
416+ atom_virial.resize (static_cast <size_t >(nframes) * odim * fwd_map.size () *
417+ 9 );
418+ for (int kk = 0 ; kk < odim; ++kk) {
419+ select_map<VALUETYPE>(atom_virial.begin () + kk * fwd_map.size () * 9 ,
420+ datom_virial.begin () + kk * bkw_map.size () * 9 ,
421+ bkw_map, 9 );
422+ }
426423 }
427424}
428425
0 commit comments