11# SPDX-License-Identifier: LGPL-3.0-or-later
2+ import logging
23import os
34import textwrap
5+ from typing import (
6+ Optional ,
7+ )
48
59from google .protobuf import (
610 text_format ,
711)
12+ from packaging .specifiers import (
13+ SpecifierSet ,
14+ )
15+ from packaging .version import parse as parse_version
816
17+ from deepmd import (
18+ __version__ ,
19+ )
920from deepmd .env import (
1021 tf ,
1122)
1223
24+ log = logging .getLogger (__name__ )
25+
1326
1427def detect_model_version (input_model : str ):
1528 """Detect DP graph version.
@@ -20,33 +33,33 @@ def detect_model_version(input_model: str):
2033 filename of the input graph
2134 """
2235 convert_pb_to_pbtxt (input_model , "frozen_model.pbtxt" )
23- version = "undetected"
36+ version = None
2437 with open ("frozen_model.pbtxt" ) as fp :
2538 file_content = fp .read ()
2639 if file_content .find ("DescrptNorot" ) > - 1 :
27- version = "<= 0.12"
40+ version = parse_version ( " 0.12")
2841 elif (
2942 file_content .find ("fitting_attr/dfparam" ) > - 1
3043 and file_content .find ("fitting_attr/daparam" ) == - 1
3144 ):
32- version = "1.0"
45+ version = parse_version ( "1.0" )
3346 elif file_content .find ("model_attr/model_version" ) == - 1 :
3447 name_dsea = file_content .find ('name: "DescrptSeA"' )
3548 post_dsea = file_content [name_dsea :]
3649 post_dsea2 = post_dsea [:300 ].find (r"}" )
3750 search_double = post_dsea [:post_dsea2 ]
3851 if search_double .find ("DT_DOUBLE" ) == - 1 :
39- version = "1.2"
52+ version = parse_version ( "1.2" )
4053 else :
41- version = "1.3"
54+ version = parse_version ( "1.3" )
4255 elif file_content .find ('string_val: "1.0"' ) > - 1 :
43- version = "2.0"
56+ version = parse_version ( "2.0" )
4457 elif file_content .find ('string_val: "1.1"' ) > - 1 :
45- version = ">= 2.1"
58+ version = parse_version ( " 2.1")
4659 return version
4760
4861
49- def convert_to_21 (input_model : str , output_model : str ):
62+ def convert_to_21 (input_model : str , output_model : str , version : Optional [ str ] = None ):
5063 """Convert DP graph to 2.1 graph.
5164
5265 Parameters
@@ -55,37 +68,36 @@ def convert_to_21(input_model: str, output_model: str):
5568 filename of the input graph
5669 output_model : str
5770 filename of the output graph
71+ version : str
72+ version of the input graph, if not specified, it will be detected automatically
5873 """
59- version = detect_model_version (input_model )
60- if version == "<= 0.12" :
74+ if version is None :
75+ version = detect_model_version (input_model )
76+ else :
77+ convert_pb_to_pbtxt (input_model , "frozen_model.pbtxt" )
78+ if version is None :
79+ raise ValueError (
80+ "The version of the DP graph %s cannot be detected. Please do the conversion manually."
81+ % (input_model )
82+ )
83+ if version in SpecifierSet ("<1.0" ):
6184 convert_dp012_to_dp10 ("frozen_model.pbtxt" )
85+ if version in SpecifierSet ("<1.1" ):
6286 convert_dp10_to_dp11 ("frozen_model.pbtxt" )
87+ if version in SpecifierSet ("<1.3" ):
6388 convert_dp12_to_dp13 ("frozen_model.pbtxt" )
89+ if version in SpecifierSet ("<2.0" ):
6490 convert_dp13_to_dp20 ("frozen_model.pbtxt" )
91+ if version in SpecifierSet ("<2.1" ):
6592 convert_dp20_to_dp21 ("frozen_model.pbtxt" )
66- elif version == "1.0" :
67- convert_dp10_to_dp11 ("frozen_model.pbtxt" )
68- convert_dp12_to_dp13 ("frozen_model.pbtxt" )
69- convert_dp13_to_dp20 ("frozen_model.pbtxt" )
70- convert_dp20_to_dp21 ("frozen_model.pbtxt" )
71- elif version == "1.2" :
72- convert_dp12_to_dp13 ("frozen_model.pbtxt" )
73- convert_dp13_to_dp20 ("frozen_model.pbtxt" )
74- convert_dp20_to_dp21 ("frozen_model.pbtxt" )
75- elif version == "1.3" :
76- convert_dp13_to_dp20 ("frozen_model.pbtxt" )
77- convert_dp20_to_dp21 ("frozen_model.pbtxt" )
78- elif version == "2.0" :
79- convert_dp20_to_dp21 ("frozen_model.pbtxt" )
80- elif version == "undetected" :
81- raise ValueError (
82- "The version of the DP graph %s cannot be detected. Please do the conversion manually."
83- % (input_model )
84- )
8593 convert_pbtxt_to_pb ("frozen_model.pbtxt" , output_model )
8694 if os .path .isfile ("frozen_model.pbtxt" ):
8795 os .remove ("frozen_model.pbtxt" )
88- print ("the converted output model (2.1 support) is saved in %s" % output_model )
96+ log .info (
97+ "the converted output model (%s support) is saved in %s" ,
98+ __version__ ,
99+ output_model ,
100+ )
89101
90102
91103def convert_13_to_21 (input_model : str , output_model : str ):
@@ -98,13 +110,7 @@ def convert_13_to_21(input_model: str, output_model: str):
98110 output_model : str
99111 filename of the output graph
100112 """
101- convert_pb_to_pbtxt (input_model , "frozen_model.pbtxt" )
102- convert_dp13_to_dp20 ("frozen_model.pbtxt" )
103- convert_dp20_to_dp21 ("frozen_model.pbtxt" )
104- convert_pbtxt_to_pb ("frozen_model.pbtxt" , output_model )
105- if os .path .isfile ("frozen_model.pbtxt" ):
106- os .remove ("frozen_model.pbtxt" )
107- print ("the converted output model (2.1 support) is saved in %s" % output_model )
113+ convert_to_21 (input_model , output_model , version = "1.3" )
108114
109115
110116def convert_12_to_21 (input_model : str , output_model : str ):
@@ -117,14 +123,7 @@ def convert_12_to_21(input_model: str, output_model: str):
117123 output_model : str
118124 filename of the output graph
119125 """
120- convert_pb_to_pbtxt (input_model , "frozen_model.pbtxt" )
121- convert_dp12_to_dp13 ("frozen_model.pbtxt" )
122- convert_dp13_to_dp20 ("frozen_model.pbtxt" )
123- convert_dp20_to_dp21 ("frozen_model.pbtxt" )
124- convert_pbtxt_to_pb ("frozen_model.pbtxt" , output_model )
125- if os .path .isfile ("frozen_model.pbtxt" ):
126- os .remove ("frozen_model.pbtxt" )
127- print ("the converted output model (2.1 support) is saved in %s" % output_model )
126+ convert_to_21 (input_model , output_model , version = "1.2" )
128127
129128
130129def convert_10_to_21 (input_model : str , output_model : str ):
@@ -137,15 +136,7 @@ def convert_10_to_21(input_model: str, output_model: str):
137136 output_model : str
138137 filename of the output graph
139138 """
140- convert_pb_to_pbtxt (input_model , "frozen_model.pbtxt" )
141- convert_dp10_to_dp11 ("frozen_model.pbtxt" )
142- convert_dp12_to_dp13 ("frozen_model.pbtxt" )
143- convert_dp13_to_dp20 ("frozen_model.pbtxt" )
144- convert_dp20_to_dp21 ("frozen_model.pbtxt" )
145- convert_pbtxt_to_pb ("frozen_model.pbtxt" , output_model )
146- if os .path .isfile ("frozen_model.pbtxt" ):
147- os .remove ("frozen_model.pbtxt" )
148- print ("the converted output model (2.1 support) is saved in %s" % output_model )
139+ convert_to_21 (input_model , output_model , version = "1.0" )
149140
150141
151142def convert_012_to_21 (input_model : str , output_model : str ):
@@ -158,16 +149,7 @@ def convert_012_to_21(input_model: str, output_model: str):
158149 output_model : str
159150 filename of the output graph
160151 """
161- convert_pb_to_pbtxt (input_model , "frozen_model.pbtxt" )
162- convert_dp012_to_dp10 ("frozen_model.pbtxt" )
163- convert_dp10_to_dp11 ("frozen_model.pbtxt" )
164- convert_dp12_to_dp13 ("frozen_model.pbtxt" )
165- convert_dp13_to_dp20 ("frozen_model.pbtxt" )
166- convert_dp20_to_dp21 ("frozen_model.pbtxt" )
167- convert_pbtxt_to_pb ("frozen_model.pbtxt" , output_model )
168- if os .path .isfile ("frozen_model.pbtxt" ):
169- os .remove ("frozen_model.pbtxt" )
170- print ("the converted output model (2.1 support) is saved in %s" % output_model )
152+ convert_to_21 (input_model , output_model , version = "0.12" )
171153
172154
173155def convert_20_to_21 (input_model : str , output_model : str ):
@@ -180,12 +162,7 @@ def convert_20_to_21(input_model: str, output_model: str):
180162 output_model : str
181163 filename of the output graph
182164 """
183- convert_pb_to_pbtxt (input_model , "frozen_model.pbtxt" )
184- convert_dp20_to_dp21 ("frozen_model.pbtxt" )
185- convert_pbtxt_to_pb ("frozen_model.pbtxt" , output_model )
186- if os .path .isfile ("frozen_model.pbtxt" ):
187- os .remove ("frozen_model.pbtxt" )
188- print ("the converted output model (2.1 support) is saved in %s" % output_model )
165+ convert_to_21 (input_model , output_model , version = "2.0" )
189166
190167
191168def convert_pb_to_pbtxt (pbfile : str , pbtxtfile : str ):
0 commit comments