Skip to content

Commit 0f9edb4

Browse files
authored
Merge pull request #159 from amcadmus/master
fix bug of test data, a few compiling bugs
2 parents bb8bf21 + 3b82e23 commit 0f9edb4

File tree

5 files changed

+14
-13
lines changed

5 files changed

+14
-13
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ matrix:
7272
- CXX=g++-8
7373
- TENSORFLOW_VERSION=2.0
7474
before_install:
75-
- pip install --upgrade pip
75+
# - pip install --upgrade pip
7676
- pip install --upgrade setuptools
7777
- pip install tensorflow==$TENSORFLOW_VERSION
7878
install:

source/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ if (BUILD_CPP_IF)
173173
if (USE_CUDA_TOOLKIT)
174174
set (LIB_DEEPMD_OP_CUDA "deepmd_op_cuda")
175175
else()
176-
set (LIB_DEEPMD_OP_CUDA "")
176+
set (LIB_DEEPMD_OP_CUDA "deepmd_op")
177177
endif()
178178
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 4.9)
179179
set (LIB_DEEPMD_NATIVE "deepmd_native_md")

source/lib/include/NNPInter.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ class NNPInter
112112
GraphDef graph_def;
113113
bool inited;
114114
template<class VT> VT get_scalar(const string & name) const;
115-
VALUETYPE get_rcut () const;
116-
int get_ntypes () const;
115+
// VALUETYPE get_rcut () const;
116+
// int get_ntypes () const;
117117
VALUETYPE rcut;
118118
VALUETYPE cell_size;
119119
int ntypes;
@@ -210,8 +210,8 @@ class NNPInterModelDevi
210210
vector<GraphDef> graph_defs;
211211
bool inited;
212212
template<class VT> VT get_scalar(const string name) const;
213-
VALUETYPE get_rcut () const;
214-
int get_ntypes () const;
213+
// VALUETYPE get_rcut () const;
214+
// int get_ntypes () const;
215215
VALUETYPE rcut;
216216
VALUETYPE cell_size;
217217
int ntypes;

source/lib/src/NNPInter.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -849,8 +849,8 @@ init (const string & model, const int & gpu_rank)
849849
ntypes = get_scalar<int>("descrpt_attr/ntypes");
850850
dfparam = get_scalar<int>("fitting_attr/dfparam");
851851
daparam = get_scalar<int>("fitting_attr/daparam");
852-
assert(rcut == get_rcut());
853-
assert(ntypes == get_ntypes());
852+
// assert(rcut == get_rcut());
853+
// assert(ntypes == get_ntypes());
854854
if (dfparam < 0) dfparam = 0;
855855
if (daparam < 0) daparam = 0;
856856
inited = true;
@@ -880,8 +880,8 @@ init (const string & model, const int & gpu_rank)
880880
ntypes = get_scalar<int>("descrpt_attr/ntypes");
881881
dfparam = get_scalar<int>("fitting_attr/dfparam");
882882
daparam = get_scalar<int>("fitting_attr/daparam");
883-
assert(rcut == get_rcut());
884-
assert(ntypes == get_ntypes());
883+
// assert(rcut == get_rcut());
884+
// assert(ntypes == get_ntypes());
885885
if (dfparam < 0) dfparam = 0;
886886
if (daparam < 0) daparam = 0;
887887
// rcut = get_rcut();

source/train/test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import argparse
77
import numpy as np
88

9-
from deepmd.Data import DataSets
109
from deepmd.Data import DeepmdData
1110
from deepmd import DeepEval
1211
from deepmd import DeepPot
@@ -35,13 +34,15 @@ def test_ener (args) :
3534
if args.rand_seed is not None :
3635
np.random.seed(args.rand_seed % (2**32))
3736

38-
data = DataSets (args.system, args.set_prefix, shuffle_test = args.shuffle_test)
37+
dp = DeepPot(args.model)
38+
data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test, type_map = dp.get_type_map())
39+
3940
test_data = data.get_test ()
4041
numb_test = args.numb_test
4142
natoms = len(test_data["type"][0])
4243
nframes = test_data["box"].shape[0]
4344
numb_test = min(nframes, numb_test)
44-
dp = DeepPot(args.model)
45+
4546
coord = test_data["coord"][:numb_test].reshape([numb_test, -1])
4647
box = test_data["box"][:numb_test]
4748
atype = test_data["type"][0]

0 commit comments

Comments
 (0)