Skip to content

Commit 75bddf9

Browse files
author
Han Wang
committed
fix bug in dp test
1 parent 7971a89 commit 75bddf9

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

source/train/Data.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,14 @@ def __init__ (self,
359359
self.has_fparam = 0
360360
else :
361361
self.has_fparam = -1
362+
# check aparam
363+
has_aparam = [ os.path.isfile(os.path.join(ii, 'aparam.npy')) for ii in self.dirs ]
364+
if any(has_aparam) and (not all(has_aparam)) :
365+
raise RuntimeError("system %s: if any set has frame parameter, then all sets should have frame parameter" % sys_path)
366+
if all(has_aparam) :
367+
self.has_aparam = 0
368+
else :
369+
self.has_aparam = -1
362370
# energy norm
363371
self.eavg = self.stats_energy()
364372
# load sets
@@ -463,6 +471,12 @@ def load_set(self, set_name, shuffle = True):
463471
self.has_fparam = data["fparam"].shape[1]
464472
else :
465473
assert self.has_fparam == data["fparam"].shape[1]
474+
if self.has_aparam >= 0:
475+
data["aparam"] = self.load_data(set_name, "aparam", [nframe, -1])
476+
if self.has_aparam == 0 :
477+
self.has_aparam = data["aparam"].shape[1] // (ncoord//3)
478+
else :
479+
assert self.has_aparam == data["aparam"].shape[1] // (ncoord//3)
466480
data["prop_c"] = np.zeros(5)
467481
data["prop_c"][0], data["energy"], data["prop_c"][3], data["atom_ener"] \
468482
= self.load_energy (set_name, nframe, ncoord // 3, "energy", "atom_ener")
@@ -573,3 +587,6 @@ def get_ener (self) :
573587
def numb_fparam(self) :
574588
return self.has_fparam
575589

590+
def numb_aparam(self) :
591+
return self.has_aparam
592+

source/train/test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,15 @@ def test_ener (args) :
4242
coord = test_data["coord"][:numb_test].reshape([numb_test, -1])
4343
box = test_data["box"][:numb_test]
4444
atype = test_data["type"][0]
45-
energy, force, virial, ae, av = dp.eval(coord, box, atype, fparam = (test_data["fparam"] if "fparam" in test_data else None), atomic = True)
45+
if dp.get_dim_fparam() > 0:
46+
fparam = test_data["fparam"][:numb_test]
47+
else :
48+
fparam = None
49+
if dp.get_dim_aparam() > 0:
50+
aparam = test_data["aparam"][:numb_test]
51+
else :
52+
aparam = None
53+
energy, force, virial, ae, av = dp.eval(coord, box, atype, fparam = fparam, aparam = aparam, atomic = True)
4654
energy = energy.reshape([numb_test,1])
4755
force = force.reshape([numb_test,-1])
4856
virial = virial.reshape([numb_test,9])

0 commit comments

Comments
 (0)