Skip to content

Commit 243ebcf

Browse files
authored
Merge pull request #158 from amcadmus/devel
fix bug of test data
2 parents be45422 + a2ac409 commit 243ebcf

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ jobs:
9090
after_success:
9191
- if [[ $TRAVIS_TAG ]]; then python -m twine upload wheelhouse/*; python -m twine upload dist/*.tar.gz; fi
9292
before_install:
93-
- pip install --upgrade pip
93+
#- pip install --upgrade pip
9494
- pip install --upgrade setuptools
9595
- pip install tensorflow==$TENSORFLOW_VERSION
9696
install:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
tf_install_dir = imp.find_module('tensorflow', [site_packages_path])[1]
2121

2222
install_requires=['numpy', 'scipy']
23-
setup_requires=['setuptools_scm']
23+
setup_requires=['setuptools_scm', 'scikit-build', 'cmake']
2424

2525
# add cmake as a build requirement if cmake>3.0 is not installed
2626
try:

source/train/test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import argparse
77
import numpy as np
88

9-
from deepmd.Data import DataSets
109
from deepmd.Data import DeepmdData
1110
from deepmd import DeepEval
1211
from deepmd import DeepPot
@@ -35,13 +34,15 @@ def test_ener (args) :
3534
if args.rand_seed is not None :
3635
np.random.seed(args.rand_seed % (2**32))
3736

38-
data = DataSets (args.system, args.set_prefix, shuffle_test = args.shuffle_test)
37+
dp = DeepPot(args.model)
38+
data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test, type_map = dp.get_type_map())
39+
3940
test_data = data.get_test ()
4041
numb_test = args.numb_test
4142
natoms = len(test_data["type"][0])
4243
nframes = test_data["box"].shape[0]
4344
numb_test = min(nframes, numb_test)
44-
dp = DeepPot(args.model)
45+
4546
coord = test_data["coord"][:numb_test].reshape([numb_test, -1])
4647
box = test_data["box"][:numb_test]
4748
atype = test_data["type"][0]

0 commit comments

Comments
 (0)