Skip to content

Commit a959868

Browse files
author
Han Wang
committed
fix bug of test data
1 parent be45422 commit a959868

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

source/train/test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import argparse
77
import numpy as np
88

9-
from deepmd.Data import DataSets
109
from deepmd.Data import DeepmdData
1110
from deepmd import DeepEval
1211
from deepmd import DeepPot
@@ -35,13 +34,15 @@ def test_ener (args) :
3534
if args.rand_seed is not None :
3635
np.random.seed(args.rand_seed % (2**32))
3736

38-
data = DataSets (args.system, args.set_prefix, shuffle_test = args.shuffle_test)
37+
dp = DeepPot(args.model)
38+
data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test, type_map = dp.get_type_map())
39+
3940
test_data = data.get_test ()
4041
numb_test = args.numb_test
4142
natoms = len(test_data["type"][0])
4243
nframes = test_data["box"].shape[0]
4344
numb_test = min(nframes, numb_test)
44-
dp = DeepPot(args.model)
45+
4546
coord = test_data["coord"][:numb_test].reshape([numb_test, -1])
4647
box = test_data["box"][:numb_test]
4748
atype = test_data["type"][0]

0 commit comments

Comments
 (0)