Skip to content

Commit 73cd77f

Browse files
authored
Merge pull request #103 from amcadmus/devel
fix bug in dp test
2 parents 7971a89 + 4f2576e commit 73cd77f

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

source/lmp/pair_nnp.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <iostream>
22
#include <iomanip>
3+
#include <limits>
34
#include "atom.h"
45
#include "domain.h"
56
#include "comm.h"
@@ -86,7 +87,7 @@ ana_st (double & max,
8687
const vector<double> & vec,
8788
const int & nloc)
8889
{
89-
if (vec.size() == 0) return;
90+
if (nloc == 0) return;
9091
max = vec[0];
9192
min = vec[0];
9293
sum = vec[0];
@@ -438,7 +439,7 @@ void PairNNP::compute(int eflag, int vflag)
438439
std_f.resize(std_f_.size());
439440
for (int dd = 0; dd < std_f_.size(); ++dd) std_f[dd] = std_f_[dd];
440441
#endif
441-
double min = 0, max = 0, avg = 0;
442+
double min = numeric_limits<double>::max(), max = 0, avg = 0;
442443
ana_st(max, min, avg, std_f, nlocal);
443444
int all_nlocal = 0;
444445
MPI_Reduce (&nlocal, &all_nlocal, 1, MPI_INT, MPI_SUM, 0, world);
@@ -460,7 +461,8 @@ void PairNNP::compute(int eflag, int vflag)
460461
std_e.resize(std_e_.size());
461462
for (int dd = 0; dd < std_e_.size(); ++dd) std_e[dd] = std_e_[dd];
462463
#endif
463-
min = max = avg = 0;
464+
max = avg = 0;
465+
min = numeric_limits<double>::max();
464466
ana_st(max, min, avg, std_e, nlocal);
465467
double all_e_min = 0, all_e_max = 0, all_e_avg = 0;
466468
MPI_Reduce (&min, &all_e_min, 1, MPI_DOUBLE, MPI_MIN, 0, world);

source/train/Data.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,14 @@ def __init__ (self,
359359
self.has_fparam = 0
360360
else :
361361
self.has_fparam = -1
362+
# check aparam
363+
has_aparam = [ os.path.isfile(os.path.join(ii, 'aparam.npy')) for ii in self.dirs ]
364+
if any(has_aparam) and (not all(has_aparam)) :
365+
raise RuntimeError("system %s: if any set has frame parameter, then all sets should have frame parameter" % sys_path)
366+
if all(has_aparam) :
367+
self.has_aparam = 0
368+
else :
369+
self.has_aparam = -1
362370
# energy norm
363371
self.eavg = self.stats_energy()
364372
# load sets
@@ -463,6 +471,12 @@ def load_set(self, set_name, shuffle = True):
463471
self.has_fparam = data["fparam"].shape[1]
464472
else :
465473
assert self.has_fparam == data["fparam"].shape[1]
474+
if self.has_aparam >= 0:
475+
data["aparam"] = self.load_data(set_name, "aparam", [nframe, -1])
476+
if self.has_aparam == 0 :
477+
self.has_aparam = data["aparam"].shape[1] // (ncoord//3)
478+
else :
479+
assert self.has_aparam == data["aparam"].shape[1] // (ncoord//3)
466480
data["prop_c"] = np.zeros(5)
467481
data["prop_c"][0], data["energy"], data["prop_c"][3], data["atom_ener"] \
468482
= self.load_energy (set_name, nframe, ncoord // 3, "energy", "atom_ener")
@@ -573,3 +587,6 @@ def get_ener (self) :
573587
def numb_fparam(self) :
574588
return self.has_fparam
575589

590+
def numb_aparam(self) :
591+
return self.has_aparam
592+

source/train/test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,15 @@ def test_ener (args) :
4242
coord = test_data["coord"][:numb_test].reshape([numb_test, -1])
4343
box = test_data["box"][:numb_test]
4444
atype = test_data["type"][0]
45-
energy, force, virial, ae, av = dp.eval(coord, box, atype, fparam = (test_data["fparam"] if "fparam" in test_data else None), atomic = True)
45+
if dp.get_dim_fparam() > 0:
46+
fparam = test_data["fparam"][:numb_test]
47+
else :
48+
fparam = None
49+
if dp.get_dim_aparam() > 0:
50+
aparam = test_data["aparam"][:numb_test]
51+
else :
52+
aparam = None
53+
energy, force, virial, ae, av = dp.eval(coord, box, atype, fparam = fparam, aparam = aparam, atomic = True)
4654
energy = energy.reshape([numb_test,1])
4755
force = force.reshape([numb_test,-1])
4856
virial = virial.reshape([numb_test,9])

0 commit comments

Comments
 (0)