55 convert_12_to_21 ,
66 convert_13_to_21 ,
77 convert_20_to_21 ,
8+ convert_pb_to_pbtxt ,
89 convert_pbtxt_to_pb ,
910 convert_to_21 ,
1011)
@@ -17,20 +18,26 @@ def convert(
1718 output_model : str ,
1819 ** kwargs ,
1920):
20- if FROM == "auto" :
21- convert_to_21 (input_model , output_model )
22- elif FROM == "0.12" :
23- convert_012_to_21 (input_model , output_model )
24- elif FROM == "1.0" :
25- convert_10_to_21 (input_model , output_model )
26- elif FROM in ["1.1" , "1.2" ]:
27- # no difference between 1.1 and 1.2
28- convert_12_to_21 (input_model , output_model )
29- elif FROM == "1.3" :
30- convert_13_to_21 (input_model , output_model )
31- elif FROM == "2.0" :
32- convert_20_to_21 (input_model , output_model )
33- elif FROM == "pbtxt" :
34- convert_pbtxt_to_pb (input_model , output_model )
21+ if output_model [- 6 :] == ".pbtxt" :
22+ if input_model [- 6 :] != ".pbtxt" :
23+ convert_pb_to_pbtxt (input_model , output_model )
24+ else :
25+ raise RuntimeError ("input model is already pbtxt" )
3526 else :
36- raise RuntimeError ("unsupported model version " + FROM )
27+ if FROM == "auto" :
28+ convert_to_21 (input_model , output_model )
29+ elif FROM == "0.12" :
30+ convert_012_to_21 (input_model , output_model )
31+ elif FROM == "1.0" :
32+ convert_10_to_21 (input_model , output_model )
33+ elif FROM in ["1.1" , "1.2" ]:
34+ # no difference between 1.1 and 1.2
35+ convert_12_to_21 (input_model , output_model )
36+ elif FROM == "1.3" :
37+ convert_13_to_21 (input_model , output_model )
38+ elif FROM == "2.0" :
39+ convert_20_to_21 (input_model , output_model )
40+ elif FROM == "pbtxt" :
41+ convert_pbtxt_to_pb (input_model , output_model )
42+ else :
43+ raise RuntimeError ("unsupported model version " + FROM )
0 commit comments