Skip to content

Commit 9528a5a

Browse files
committed
fix bug of std:: string incompatible with TensorFlow 2.3
1 parent b97c00b commit 9528a5a

File tree

4 files changed

+11
-3
lines changed

4 files changed

+11
-3
lines changed

source/lib/include/common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
#pragma once
22

33
#include "tensorflow/core/public/session.h"
4+
#include "tensorflow/core/public/version.h"
45
#include "tensorflow/core/platform/env.h"
56
#include "tensorflow/core/framework/op.h"
67
#include "tensorflow/core/framework/op_kernel.h"
78
#include "tensorflow/core/framework/shape_inference.h"
9+
#include <string>
810

911
using namespace tensorflow;
1012
using namespace std;
1113

14+
#if TF_MAJOR_VERSION >= 2 && TF_MINOR_VERSION >= 2
15+
typedef tensorflow::tstring STRINGTYPE;
16+
#else
17+
typedef std::string STRINGTYPE;
18+
#endif
19+
1220
#include "NNPAtomMap.h"
1321
#include <vector>
1422
#include "version.h"

source/lib/src/DataModifier.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ init (const string & model,
3737
rcut = get_scalar<VALUETYPE>("descrpt_attr/rcut");
3838
cell_size = rcut;
3939
ntypes = get_scalar<int>("descrpt_attr/ntypes");
40-
model_type = get_scalar<string>("model_attr/model_type");
40+
model_type = get_scalar<STRINGTYPE>("model_attr/model_type");
4141
get_vector<int>(sel_type, "model_attr/sel_type");
4242
sort(sel_type.begin(), sel_type.end());
4343
inited = true;

source/lib/src/DeepTensor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ init (const string & model,
3333
rcut = get_scalar<VALUETYPE>("descrpt_attr/rcut");
3434
cell_size = rcut;
3535
ntypes = get_scalar<int>("descrpt_attr/ntypes");
36-
model_type = get_scalar<string>("model_attr/model_type");
36+
model_type = get_scalar<STRINGTYPE>("model_attr/model_type");
3737
odim = get_scalar<int>("model_attr/output_dim");
3838
get_vector<int>(sel_type, "model_attr/sel_type");
3939
inited = true;

source/lib/src/NNPInter.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ compute (ENERGYTYPE & dener,
639639
void
640640
NNPInter::
641641
get_type_map(std::string & type_map){
642-
type_map = get_scalar<std::string>("model_attr/tmap");
642+
type_map = get_scalar<STRINGTYPE>("model_attr/tmap");
643643
}
644644

645645

0 commit comments

Comments
 (0)