Skip to content

Commit d7c4fc2

Browse files
authored
Merge pull request #369 from denghuilu/1.2
convert dp1.2 model to dp1.3 model
2 parents 46929a1 + 58da46c commit d7c4fc2

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

source/train/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
configure_file("RunOptions.py.in" "${CMAKE_CURRENT_BINARY_DIR}/RunOptions.py" @ONLY)
44

5-
file(GLOB LIB_PY main.py common.py env.py compat.py calculator.py Network.py Deep*.py Data.py DataSystem.py Model*.py Descrpt*.py Fitting.py Loss.py LearningRate.py Trainer.py TabInter.py EwaldRecp.py DataModifier.py ${CMAKE_CURRENT_BINARY_DIR}/RunOptions.py transform.py)
5+
file(GLOB LIB_PY main.py common.py env.py compat.py calculator.py Network.py Deep*.py Data.py DataSystem.py Model*.py Descrpt*.py Fitting.py Loss.py LearningRate.py Trainer.py TabInter.py EwaldRecp.py DataModifier.py ${CMAKE_CURRENT_BINARY_DIR}/RunOptions.py transform.py convert_to_13.py)
66

77
file(GLOB CLS_PY Local.py Slurm.py)
88

source/train/convert_to_13.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from deepmd.env import tf
2+
from google.protobuf import text_format
3+
from tensorflow.python.platform import gfile
4+
from tensorflow.python import pywrap_tensorflow
5+
from tensorflow.python.framework import graph_util
6+
7+
def convert_to_13(args):
8+
convert_pb_to_pbtxt(args.input_model, 'frozen_model.pbtxt')
9+
convert_to_dp13('frozen_model.pbtxt')
10+
convert_pbtxt_to_pb('frozen_model.pbtxt', args.output_model)
11+
print("the converted output model(1.3 support) is saved in %s" % args.output_model)
12+
13+
def convert_pb_to_pbtxt(pbfile, pbtxtfile):
14+
with gfile.FastGFile(pbfile, 'rb') as f:
15+
graph_def = tf.GraphDef()
16+
graph_def.ParseFromString(f.read())
17+
tf.import_graph_def(graph_def, name='')
18+
tf.train.write_graph(graph_def, './', pbtxtfile, as_text=True)
19+
20+
def convert_pbtxt_to_pb(pbtxtfile, pbfile):
21+
with tf.gfile.FastGFile(pbtxtfile, 'r') as f:
22+
graph_def = tf.GraphDef()
23+
file_content = f.read()
24+
# Merges the human-readable string in `file_content` into `graph_def`.
25+
text_format.Merge(file_content, graph_def)
26+
tf.train.write_graph(graph_def, './', pbfile, as_text=False)
27+
28+
def convert_to_dp13(file):
29+
file_data = ""
30+
with open(file, "r", encoding="utf-8") as f:
31+
ii = 0
32+
lines = f.readlines()
33+
while (ii < len(lines)):
34+
line = lines[ii]
35+
file_data += line
36+
ii+=1
37+
if 'name' in line and ('DescrptSeA' in line or 'ProdForceSeA' in line or 'ProdVirialSeA' in line):
38+
while not('attr' in lines[ii] and '{' in lines[ii]):
39+
file_data += lines[ii]
40+
ii+=1
41+
file_data += ' attr {\n'
42+
file_data += ' key: \"T\"\n'
43+
file_data += ' value {\n'
44+
file_data += ' type: DT_DOUBLE\n'
45+
file_data += ' }\n'
46+
file_data += ' }\n'
47+
with open(file, "w", encoding="utf-8") as f:
48+
f.write(file_data)

source/train/main.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .config import config
66
from .test import test
77
from .transform import transform
8+
from .convert_to_13 import convert_to_13
89

910
def main () :
1011
parser = argparse.ArgumentParser(
@@ -57,6 +58,11 @@ def main () :
5758
parser_tst.add_argument("-d", "--detail-file", type=str,
5859
help="The file containing details of energy force and virial accuracy")
5960

61+
parser_transform = subparsers.add_parser('convert-to-1.3', help='convert dp-1.2 model to dp-1.3 model')
62+
parser_transform.add_argument('-i', "--input-model", default = "frozen_model.pb", type=str,
63+
help = "the input dp-1.2 model")
64+
parser_transform.add_argument("-o","--output-model", default = "frozen_model_1.3.pb", type=str,
65+
help='the converted dp-1.3 model')
6066
args = parser.parse_args()
6167

6268
if args.command is None :
@@ -72,5 +78,7 @@ def main () :
7278
test(args)
7379
elif args.command == 'transform' :
7480
transform(args)
81+
elif args.command == 'convert-to-1.3' :
82+
convert_to_13(args)
7583
else :
7684
raise RuntimeError('unknown command ' + args.command)

0 commit comments

Comments
 (0)