@@ -106,13 +106,45 @@ void PairNNP::compute(int eflag, int vflag)
106106 LammpsNeighborList lmp_list (list->inum , list->ilist , list->numneigh , list->firstneigh );
107107 if (numb_models == 1 ) {
108108 if ( ! (eflag_atom || vflag_atom) ) {
109+ #ifdef HIGH_PREC
109110 nnp_inter.compute (dener, dforce, dvirial, dcoord, dtype, dbox, nghost, lmp_list);
111+ #else
112+ vector<float > dcoord_ (dcoord.size ());
113+ vector<float > dbox_ (dbox.size ());
114+ for (unsigned dd = 0 ; dd < dcoord.size (); ++dd) dcoord_[dd] = dcoord[dd];
115+ for (unsigned dd = 0 ; dd < dbox.size (); ++dd) dbox_[dd] = dbox[dd];
116+ vector<float > dforce_ (dforce.size (), 0 );
117+ vector<float > dvirial_ (dvirial.size (), 0 );
118+ float dener_ = 0 ;
119+ nnp_inter.compute (dener_, dforce_, dvirial_, dcoord_, dtype, dbox_, nghost, lmp_list);
120+ for (unsigned dd = 0 ; dd < dforce.size (); ++dd) dforce[dd] = dforce_[dd];
121+ for (unsigned dd = 0 ; dd < dvirial.size (); ++dd) dvirial[dd] = dvirial_[dd];
122+ dener = dener_;
123+ #endif
110124 }
111125 // do atomic energy and virial
112126 else {
113127 vector<double > deatom (nall * 1 , 0 );
114128 vector<double > dvatom (nall * 9 , 0 );
129+ #ifdef HIGH_PREC
115130 nnp_inter.compute (dener, dforce, dvirial, deatom, dvatom, dcoord, dtype, dbox, nghost, lmp_list);
131+ #else
132+ vector<float > dcoord_ (dcoord.size ());
133+ vector<float > dbox_ (dbox.size ());
134+ for (unsigned dd = 0 ; dd < dcoord.size (); ++dd) dcoord_[dd] = dcoord[dd];
135+ for (unsigned dd = 0 ; dd < dbox.size (); ++dd) dbox_[dd] = dbox[dd];
136+ vector<float > dforce_ (dforce.size (), 0 );
137+ vector<float > dvirial_ (dvirial.size (), 0 );
138+ vector<float > deatom_ (dforce.size (), 0 );
139+ vector<float > dvatom_ (dforce.size (), 0 );
140+ float dener_ = 0 ;
141+ nnp_inter.compute (dener_, dforce_, dvirial_, deatom_, dvatom_, dcoord_, dtype, dbox_, nghost, lmp_list);
142+ for (unsigned dd = 0 ; dd < dforce.size (); ++dd) dforce[dd] = dforce_[dd];
143+ for (unsigned dd = 0 ; dd < dvirial.size (); ++dd) dvirial[dd] = dvirial_[dd];
144+ for (unsigned dd = 0 ; dd < deatom.size (); ++dd) deatom[dd] = deatom_[dd];
145+ for (unsigned dd = 0 ; dd < dvatom.size (); ++dd) dvatom[dd] = dvatom_[dd];
146+ dener = dener_;
147+ #endif
116148 if (eflag_atom) {
117149 for (int ii = 0 ; ii < nlocal; ++ii) eatom[ii] += deatom[ii];
118150 }
@@ -129,6 +161,9 @@ void PairNNP::compute(int eflag, int vflag)
129161 }
130162 }
131163 else {
164+ vector<double > deatom (nall * 1 , 0 );
165+ vector<double > dvatom (nall * 9 , 0 );
166+ #ifdef HIGH_PREC
132167 vector<double > all_energy;
133168 vector<vector<double >> all_virial;
134169 vector<vector<double >> all_atom_energy;
@@ -137,10 +172,42 @@ void PairNNP::compute(int eflag, int vflag)
137172 nnp_inter_model_devi.compute_avg (dener, all_energy);
138173 nnp_inter_model_devi.compute_avg (dforce, all_force);
139174 nnp_inter_model_devi.compute_avg (dvirial, all_virial);
140- vector<double > deatom (nall * 1 , 0 );
141- vector<double > dvatom (nall * 9 , 0 );
142175 nnp_inter_model_devi.compute_avg (deatom, all_atom_energy);
143176 nnp_inter_model_devi.compute_avg (dvatom, all_atom_virial);
177+ #else
178+ vector<float > dcoord_ (dcoord.size ());
179+ vector<float > dbox_ (dbox.size ());
180+ for (unsigned dd = 0 ; dd < dcoord.size (); ++dd) dcoord_[dd] = dcoord[dd];
181+ for (unsigned dd = 0 ; dd < dbox.size (); ++dd) dbox_[dd] = dbox[dd];
182+ vector<float > dforce_ (dforce.size (), 0 );
183+ vector<float > dvirial_ (dvirial.size (), 0 );
184+ vector<float > deatom_ (dforce.size (), 0 );
185+ vector<float > dvatom_ (dforce.size (), 0 );
186+ float dener_ = 0 ;
187+ vector<float > all_energy_;
188+ vector<vector<float >> all_force_;
189+ vector<vector<float >> all_virial_;
190+ vector<vector<float >> all_atom_energy_;
191+ vector<vector<float >> all_atom_virial_;
192+ nnp_inter_model_devi.compute (all_energy_, all_force_, all_virial_, all_atom_energy_, all_atom_virial_, dcoord_, dtype, dbox_, nghost, lmp_list);
193+ nnp_inter_model_devi.compute_avg (dener_, all_energy_);
194+ nnp_inter_model_devi.compute_avg (dforce_, all_force_);
195+ nnp_inter_model_devi.compute_avg (dvirial_, all_virial_);
196+ nnp_inter_model_devi.compute_avg (deatom_, all_atom_energy_);
197+ nnp_inter_model_devi.compute_avg (dvatom_, all_atom_virial_);
198+ dener = dener_;
199+ for (unsigned dd = 0 ; dd < dforce.size (); ++dd) dforce[dd] = dforce_[dd];
200+ for (unsigned dd = 0 ; dd < dvirial.size (); ++dd) dvirial[dd] = dvirial_[dd];
201+ for (unsigned dd = 0 ; dd < deatom.size (); ++dd) deatom[dd] = deatom_[dd];
202+ for (unsigned dd = 0 ; dd < dvatom.size (); ++dd) dvatom[dd] = dvatom_[dd];
203+ all_force.resize (all_force_.size ());
204+ for (unsigned ii = 0 ; ii < all_force_.size (); ++ii){
205+ all_force[ii].resize (all_force_[ii].size ());
206+ for (unsigned jj = 0 ; jj < all_force_[ii].size (); ++jj){
207+ all_force[ii][jj] = all_force_[ii][jj];
208+ }
209+ }
210+ #endif
144211 if (eflag_atom) {
145212 for (int ii = 0 ; ii < nlocal; ++ii) eatom[ii] += deatom[ii];
146213 }
@@ -160,10 +227,23 @@ void PairNNP::compute(int eflag, int vflag)
160227 if (newton_pair) {
161228 comm->reverse_comm_pair (this );
162229 }
163- vector<double > tmp_avg_f;
164230 vector<double > std_f;
231+ #ifdef HIGH_PREC
232+ vector<double > tmp_avg_f;
165233 nnp_inter_model_devi.compute_avg (tmp_avg_f, all_force);
166234 nnp_inter_model_devi.compute_std_f (std_f, tmp_avg_f, all_force);
235+ #else
236+ vector<float > tmp_avg_f_, std_f_;
237+ for (unsigned ii = 0 ; ii < all_force_.size (); ++ii){
238+ for (unsigned jj = 0 ; jj < all_force_[ii].size (); ++jj){
239+ all_force_[ii][jj] = all_force[ii][jj];
240+ }
241+ }
242+ nnp_inter_model_devi.compute_avg (tmp_avg_f_, all_force_);
243+ nnp_inter_model_devi.compute_std_f (std_f_, tmp_avg_f_, all_force_);
244+ std_f.resize (std_f_.size ());
245+ for (int dd = 0 ; dd < std_f_.size (); ++dd) std_f[dd] = std_f_[dd];
246+ #endif
167247 double min = 0 , max = 0 , avg = 0 ;
168248 ana_st (max, min, avg, std_f, nlocal);
169249 int all_nlocal = 0 ;
@@ -174,10 +254,18 @@ void PairNNP::compute(int eflag, int vflag)
174254 MPI_Reduce (&avg, &all_f_avg, 1 , MPI_DOUBLE, MPI_SUM, 0 , world);
175255 all_f_avg /= double (all_nlocal);
176256 // std energy
177- vector<double > tmp_avg_e;
178257 vector<double > std_e;
258+ #ifdef HIGH_PREC
259+ vector<double > tmp_avg_e;
179260 nnp_inter_model_devi.compute_avg (tmp_avg_e, all_atom_energy);
180261 nnp_inter_model_devi.compute_std_e (std_e, tmp_avg_e, all_atom_energy);
262+ #else
263+ vector<float > tmp_avg_e_, std_e_;
264+ nnp_inter_model_devi.compute_avg (tmp_avg_e_, all_atom_energy_);
265+ nnp_inter_model_devi.compute_std_e (std_e_, tmp_avg_e_, all_atom_energy_);
266+ std_e.resize (std_e_.size ());
267+ for (int dd = 0 ; dd < std_e_.size (); ++dd) std_e[dd] = std_e_[dd];
268+ #endif
181269 min = max = avg = 0 ;
182270 ana_st (max, min, avg, std_e, nlocal);
183271 double all_e_min = 0 , all_e_max = 0 , all_e_avg = 0 ;
@@ -209,7 +297,21 @@ void PairNNP::compute(int eflag, int vflag)
209297 }
210298 else {
211299 if (numb_models == 1 ) {
300+ #ifdef HIGH_PREC
212301 nnp_inter.compute (dener, dforce, dvirial, dcoord, dtype, dbox, nghost);
302+ #else
303+ vector<float > dcoord_ (dcoord.size ());
304+ vector<float > dbox_ (dbox.size ());
305+ for (unsigned dd = 0 ; dd < dcoord.size (); ++dd) dcoord_[dd] = dcoord[dd];
306+ for (unsigned dd = 0 ; dd < dbox.size (); ++dd) dbox_[dd] = dbox[dd];
307+ vector<float > dforce_ (dforce.size (), 0 );
308+ vector<float > dvirial_ (dvirial.size (), 0 );
309+ float dener_ = 0 ;
310+ nnp_inter.compute (dener_, dforce_, dvirial_, dcoord_, dtype, dbox_, nghost);
311+ for (unsigned dd = 0 ; dd < dforce.size (); ++dd) dforce[dd] = dforce_[dd];
312+ for (unsigned dd = 0 ; dd < dvirial.size (); ++dd) dvirial[dd] = dvirial_[dd];
313+ dener = dener_;
314+ #endif
213315 }
214316 else {
215317 error->all (FLERR," Serial version does not support model devi" );
0 commit comments