Skip to content

Commit c2c6476

Browse files
authored
refactor convert (#2854)
Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 9d774a4 commit c2c6476

File tree

2 files changed

+54
-76
lines changed

2 files changed

+54
-76
lines changed

deepmd/utils/convert.py

Lines changed: 47 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,28 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import logging
23
import os
34
import textwrap
5+
from typing import (
6+
Optional,
7+
)
48

59
from google.protobuf import (
610
text_format,
711
)
12+
from packaging.specifiers import (
13+
SpecifierSet,
14+
)
15+
from packaging.version import parse as parse_version
816

17+
from deepmd import (
18+
__version__,
19+
)
920
from deepmd.env import (
1021
tf,
1122
)
1223

24+
log = logging.getLogger(__name__)
25+
1326

1427
def detect_model_version(input_model: str):
1528
"""Detect DP graph version.
@@ -20,33 +33,33 @@ def detect_model_version(input_model: str):
2033
filename of the input graph
2134
"""
2235
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
23-
version = "undetected"
36+
version = None
2437
with open("frozen_model.pbtxt") as fp:
2538
file_content = fp.read()
2639
if file_content.find("DescrptNorot") > -1:
27-
version = "<= 0.12"
40+
version = parse_version("0.12")
2841
elif (
2942
file_content.find("fitting_attr/dfparam") > -1
3043
and file_content.find("fitting_attr/daparam") == -1
3144
):
32-
version = "1.0"
45+
version = parse_version("1.0")
3346
elif file_content.find("model_attr/model_version") == -1:
3447
name_dsea = file_content.find('name: "DescrptSeA"')
3548
post_dsea = file_content[name_dsea:]
3649
post_dsea2 = post_dsea[:300].find(r"}")
3750
search_double = post_dsea[:post_dsea2]
3851
if search_double.find("DT_DOUBLE") == -1:
39-
version = "1.2"
52+
version = parse_version("1.2")
4053
else:
41-
version = "1.3"
54+
version = parse_version("1.3")
4255
elif file_content.find('string_val: "1.0"') > -1:
43-
version = "2.0"
56+
version = parse_version("2.0")
4457
elif file_content.find('string_val: "1.1"') > -1:
45-
version = ">= 2.1"
58+
version = parse_version("2.1")
4659
return version
4760

4861

49-
def convert_to_21(input_model: str, output_model: str):
62+
def convert_to_21(input_model: str, output_model: str, version: Optional[str] = None):
5063
"""Convert DP graph to 2.1 graph.
5164
5265
Parameters
@@ -55,37 +68,36 @@ def convert_to_21(input_model: str, output_model: str):
5568
filename of the input graph
5669
output_model : str
5770
filename of the output graph
71+
version : str
72+
version of the input graph, if not specified, it will be detected automatically
5873
"""
59-
version = detect_model_version(input_model)
60-
if version == "<= 0.12":
74+
if version is None:
75+
version = detect_model_version(input_model)
76+
else:
77+
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
78+
if version is None:
79+
raise ValueError(
80+
"The version of the DP graph %s cannot be detected. Please do the conversion manually."
81+
% (input_model)
82+
)
83+
if version in SpecifierSet("<1.0"):
6184
convert_dp012_to_dp10("frozen_model.pbtxt")
85+
if version in SpecifierSet("<1.1"):
6286
convert_dp10_to_dp11("frozen_model.pbtxt")
87+
if version in SpecifierSet("<1.3"):
6388
convert_dp12_to_dp13("frozen_model.pbtxt")
89+
if version in SpecifierSet("<2.0"):
6490
convert_dp13_to_dp20("frozen_model.pbtxt")
91+
if version in SpecifierSet("<2.1"):
6592
convert_dp20_to_dp21("frozen_model.pbtxt")
66-
elif version == "1.0":
67-
convert_dp10_to_dp11("frozen_model.pbtxt")
68-
convert_dp12_to_dp13("frozen_model.pbtxt")
69-
convert_dp13_to_dp20("frozen_model.pbtxt")
70-
convert_dp20_to_dp21("frozen_model.pbtxt")
71-
elif version == "1.2":
72-
convert_dp12_to_dp13("frozen_model.pbtxt")
73-
convert_dp13_to_dp20("frozen_model.pbtxt")
74-
convert_dp20_to_dp21("frozen_model.pbtxt")
75-
elif version == "1.3":
76-
convert_dp13_to_dp20("frozen_model.pbtxt")
77-
convert_dp20_to_dp21("frozen_model.pbtxt")
78-
elif version == "2.0":
79-
convert_dp20_to_dp21("frozen_model.pbtxt")
80-
elif version == "undetected":
81-
raise ValueError(
82-
"The version of the DP graph %s cannot be detected. Please do the conversion manually."
83-
% (input_model)
84-
)
8593
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
8694
if os.path.isfile("frozen_model.pbtxt"):
8795
os.remove("frozen_model.pbtxt")
88-
print("the converted output model (2.1 support) is saved in %s" % output_model)
96+
log.info(
97+
"the converted output model (%s support) is saved in %s",
98+
__version__,
99+
output_model,
100+
)
89101

90102

91103
def convert_13_to_21(input_model: str, output_model: str):
@@ -98,13 +110,7 @@ def convert_13_to_21(input_model: str, output_model: str):
98110
output_model : str
99111
filename of the output graph
100112
"""
101-
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
102-
convert_dp13_to_dp20("frozen_model.pbtxt")
103-
convert_dp20_to_dp21("frozen_model.pbtxt")
104-
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
105-
if os.path.isfile("frozen_model.pbtxt"):
106-
os.remove("frozen_model.pbtxt")
107-
print("the converted output model (2.1 support) is saved in %s" % output_model)
113+
convert_to_21(input_model, output_model, version="1.3")
108114

109115

110116
def convert_12_to_21(input_model: str, output_model: str):
@@ -117,14 +123,7 @@ def convert_12_to_21(input_model: str, output_model: str):
117123
output_model : str
118124
filename of the output graph
119125
"""
120-
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
121-
convert_dp12_to_dp13("frozen_model.pbtxt")
122-
convert_dp13_to_dp20("frozen_model.pbtxt")
123-
convert_dp20_to_dp21("frozen_model.pbtxt")
124-
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
125-
if os.path.isfile("frozen_model.pbtxt"):
126-
os.remove("frozen_model.pbtxt")
127-
print("the converted output model (2.1 support) is saved in %s" % output_model)
126+
convert_to_21(input_model, output_model, version="1.2")
128127

129128

130129
def convert_10_to_21(input_model: str, output_model: str):
@@ -137,15 +136,7 @@ def convert_10_to_21(input_model: str, output_model: str):
137136
output_model : str
138137
filename of the output graph
139138
"""
140-
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
141-
convert_dp10_to_dp11("frozen_model.pbtxt")
142-
convert_dp12_to_dp13("frozen_model.pbtxt")
143-
convert_dp13_to_dp20("frozen_model.pbtxt")
144-
convert_dp20_to_dp21("frozen_model.pbtxt")
145-
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
146-
if os.path.isfile("frozen_model.pbtxt"):
147-
os.remove("frozen_model.pbtxt")
148-
print("the converted output model (2.1 support) is saved in %s" % output_model)
139+
convert_to_21(input_model, output_model, version="1.0")
149140

150141

151142
def convert_012_to_21(input_model: str, output_model: str):
@@ -158,16 +149,7 @@ def convert_012_to_21(input_model: str, output_model: str):
158149
output_model : str
159150
filename of the output graph
160151
"""
161-
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
162-
convert_dp012_to_dp10("frozen_model.pbtxt")
163-
convert_dp10_to_dp11("frozen_model.pbtxt")
164-
convert_dp12_to_dp13("frozen_model.pbtxt")
165-
convert_dp13_to_dp20("frozen_model.pbtxt")
166-
convert_dp20_to_dp21("frozen_model.pbtxt")
167-
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
168-
if os.path.isfile("frozen_model.pbtxt"):
169-
os.remove("frozen_model.pbtxt")
170-
print("the converted output model (2.1 support) is saved in %s" % output_model)
152+
convert_to_21(input_model, output_model, version="0.12")
171153

172154

173155
def convert_20_to_21(input_model: str, output_model: str):
@@ -180,12 +162,7 @@ def convert_20_to_21(input_model: str, output_model: str):
180162
output_model : str
181163
filename of the output graph
182164
"""
183-
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
184-
convert_dp20_to_dp21("frozen_model.pbtxt")
185-
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
186-
if os.path.isfile("frozen_model.pbtxt"):
187-
os.remove("frozen_model.pbtxt")
188-
print("the converted output model (2.1 support) is saved in %s" % output_model)
165+
convert_to_21(input_model, output_model, version="2.0")
189166

190167

191168
def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str):

source/tests/test_deeppot_a.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
run_dp,
99
tests_path,
1010
)
11+
from packaging.version import parse as parse_version
1112

1213
from deepmd.env import (
1314
GLOBAL_NP_FLOAT_PRECISION,
@@ -750,33 +751,33 @@ def test_detect(self):
750751
new_model_pb = "deeppot_new.pb"
751752
convert_pbtxt_to_pb(str(tests_path / "infer" / "sea_012.pbtxt"), old_model)
752753
version = detect_model_version(old_model)
753-
self.assertEqual(version, "<= 0.12")
754+
self.assertEqual(version, parse_version("0.12"))
754755
os.remove(old_model)
755756
shutil.copyfile(str(tests_path / "infer" / "sea_012.pbtxt"), new_model_txt)
756757
convert_dp012_to_dp10(new_model_txt)
757758
convert_pbtxt_to_pb(new_model_txt, new_model_pb)
758759
version = detect_model_version(new_model_pb)
759-
self.assertEqual(version, "1.0")
760+
self.assertEqual(version, parse_version("1.0"))
760761
os.remove(new_model_pb)
761762
convert_dp10_to_dp11(new_model_txt)
762763
convert_pbtxt_to_pb(new_model_txt, new_model_pb)
763764
version = detect_model_version(new_model_pb)
764-
self.assertEqual(version, "1.3")
765+
self.assertEqual(version, parse_version("1.3"))
765766
os.remove(new_model_pb)
766767
convert_dp12_to_dp13(new_model_txt)
767768
convert_pbtxt_to_pb(new_model_txt, new_model_pb)
768769
version = detect_model_version(new_model_pb)
769-
self.assertEqual(version, "1.3")
770+
self.assertEqual(version, parse_version("1.3"))
770771
os.remove(new_model_pb)
771772
convert_dp13_to_dp20(new_model_txt)
772773
convert_pbtxt_to_pb(new_model_txt, new_model_pb)
773774
version = detect_model_version(new_model_pb)
774-
self.assertEqual(version, "2.0")
775+
self.assertEqual(version, parse_version("2.0"))
775776
os.remove(new_model_pb)
776777
convert_dp20_to_dp21(new_model_txt)
777778
convert_pbtxt_to_pb(new_model_txt, new_model_pb)
778779
version = detect_model_version(new_model_pb)
779-
self.assertEqual(version, ">= 2.1")
780+
self.assertEqual(version, parse_version("2.1"))
780781
os.remove(new_model_pb)
781782
os.remove(new_model_txt)
782783

0 commit comments

Comments
 (0)