 from paddle.decomposition import decomp
 from paddle.base.executor import global_scope
 import shutil
+import traceback
 
+PADDLE2ONNX_EXPORT_TEMP_DIR = None
 
-def load_model(model_filename):
-    """Loads the pir model from json file."""
-    assert os.path.exists(
-        model_filename
-    ), f"Model file {model_filename} does not exist."
-    if model_filename.endswith(".json"):
-        model_filename = model_filename[:-5]
-    return paddle.jit.load(model_filename)
+
+def get_tmp_dir_and_file(model_filename, suffix=""):
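+    """Create the temp dir on first use and derive temp .json/.pdiparams paths."""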
+    global PADDLE2ONNX_EXPORT_TEMP_DIR
+    if PADDLE2ONNX_EXPORT_TEMP_DIR is None:
+        PADDLE2ONNX_EXPORT_TEMP_DIR = tempfile.mkdtemp()
+    model_file_path, _ = os.path.splitext(model_filename)
+    new_model_file_path = os.path.join(
+        PADDLE2ONNX_EXPORT_TEMP_DIR, os.path.basename(model_file_path) + suffix
+    )
+    new_model_file_name = new_model_file_path + ".json"
+    new_params_file_name = new_model_file_path + ".pdiparams"
+    return (
+        model_file_path,
+        new_model_file_path,
+        new_model_file_name,
+        new_params_file_name,
+    )
 
 
 def compare_programs(original_program, new_program):
@@ -40,33 +51,21 @@ def compare_programs(original_program, new_program):
     return original_ops == new_ops
 
 
-def save_program(program, model_file):
-    """Saves the decomposed program to a file."""
+def save_program(program, new_model_file_path):
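+    """Save the PIR program as an inference model at ``new_model_file_path``."""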
     place = paddle.CPUPlace()
     exe = paddle.static.Executor(place)
-
-    tmp_dir = tempfile.mkdtemp()
-    filename = os.path.basename(model_file) + "_decompose"
-    filename_without_extension, _ = os.path.splitext(filename)
-    save_dir = os.path.join(tmp_dir, filename_without_extension)
-
     # Find feed and fetch operations
     feed, fetch = [], []
+    # TODO(wangmingkai02): need to double-check this
     for op in program.global_block().ops:
         if op.name() == "pd_op.feed":
             feed.extend(op.results())
         if op.name() == "pd_op.fetch" or op.name() == "builtin.shadow_output":
             fetch.extend(op.operands_source())
-
     with paddle.pir_utils.IrGuard():
-        paddle.static.save_inference_model(save_dir, feed, fetch, exe, program=program)
-
-    new_model_file = save_dir + ".json"
-    assert os.path.exists(
-        new_model_file
-    ), f"Pir Model file {new_model_file} does not exist."
-    logging.info(f"Decomposed Model file path: {new_model_file}")
-    return new_model_file
+        paddle.static.save_inference_model(
+            new_model_file_path, feed, fetch, exe, program=program
+        )
 
 
 def load_parameter(program):
@@ -79,10 +78,8 @@ def load_parameter(program):
             opts.append(var)
     vars_list = params + opts
     vars = [var for var in vars_list if var.persistable]
-
     if vars is None:
         return
-
     place = paddle.CPUPlace()
     exe = paddle.static.Executor(place)
     paddle.base.libpaddle.pir.create_loaded_parameter(
@@ -92,7 +89,10 @@ def load_parameter(program):
 
 def decompose_program(model_filename):
     """Decomposes the given pir program."""
-    model = load_model(model_filename)
+    model_file_path, new_model_file_path, new_model_file_name, new_params_file_name = (
+        get_tmp_dir_and_file(model_filename, "_decompose")
+    )
+    model = paddle.jit.load(model_file_path)
     new_program = model.program().clone()
     with decomp.prim_guard():
         decomp.decompose_dist_program(new_program)
@@ -101,7 +101,8 @@ def decompose_program(model_filename):
         return model_filename
 
     load_parameter(new_program)
-    return save_program(new_program, model_filename)
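+    # Write the decomposed program to the shared temp dir and return its json path.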
+    save_program(new_program, new_model_file_path)
+    return new_model_file_name
 
 
 def get_old_ir_guard():
@@ -136,91 +137,129 @@ def export(
     export_fp16_model=False,
     enable_polygraphy=True,
 ):
+    global PADDLE2ONNX_EXPORT_TEMP_DIR
     # check model_filename
     assert os.path.exists(
         model_filename
     ), f"Model file {model_filename} does not exist."
+    if not os.path.exists(params_filename):
+        logging.warning(
+            f"Params file {params_filename} does not exist; "
+            + "the exported ONNX model will not contain weights."
+        )
+        params_filename = ""
 
-    # translate old ir program to pir
-    tmp_dir = tempfile.mkdtemp()
-    dir_and_file, extension = os.path.splitext(model_filename)
-    filename = os.path.basename(model_filename)
-    filename_without_extension, _ = os.path.splitext(filename)
-    save_dir = os.path.join(tmp_dir, filename_without_extension)
-    if model_filename.endswith(".pdmodel"):
-        if os.path.exists(model_filename) and os.path.exists(params_filename):
-            place = paddle.CPUPlace()
-            exe = paddle.static.Executor(place)
-            with paddle.pir_utils.OldIrGuard():
-                [inference_program, feed_target_names, fetch_targets] = (
-                    paddle.static.load_inference_model(dir_and_file, exe)
-                )
-            program = paddle.pir.translate_to_pir(inference_program.desc)
-            for op in program.global_block().ops:
-                if op.name() == "pd_op.feed":
-                    feed = op.results()
-                if op.name() == "pd_op.fetch":
-                    fetch = op.operands_source()
-            with paddle.pir_utils.IrGuard():
-                paddle.static.save_inference_model(
-                    save_dir, feed, fetch, exe, program=program
+    try:
+        if model_filename.endswith(".pdmodel"):
+            # translate old ir program to pir program
+            logging.warning(
+                "The .pdmodel file is deprecated in paddlepaddle 3.0"
+                + " and will be removed in the future."
+                + " Trying to convert the .pdmodel file to a json file."
+            )
+            (
+                model_file_path,
+                new_model_file_path,
+                new_model_file_name,
+                new_params_file_name,
+            ) = get_tmp_dir_and_file(model_filename, "_pt")
+            if os.path.exists(params_filename):
+                place = paddle.CPUPlace()
+                exe = paddle.static.Executor(place)
+                with paddle.pir_utils.OldIrGuard():
+                    [inference_program, feed_target_names, fetch_targets] = (
+                        paddle.static.load_inference_model(model_file_path, exe)
+                    )
+                program = paddle.pir.translate_to_pir(inference_program.desc)
+                # TODO(wangmingkai02): Do we need to call load_parameter(program) here?
+                load_parameter(program)
+                save_program(program, new_model_file_path)
+                params_filename = new_params_file_name
+                if not os.path.exists(new_params_file_name):
+                    raise RuntimeError(
+                        f"Program translation failed: params file {new_params_file_name} does not exist."
+                    )
+            else:
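+                # No params file: translate the bare program and save it as a json file directly.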
+                with paddle.pir_utils.OldIrGuard():
+                    program = paddle.load(model_filename)
+                pir_program = paddle.pir.translate_to_pir(program.desc)
+                with paddle.pir_utils.IrGuard():
+                    paddle.save(pir_program, new_model_file_name)
+            if not os.path.exists(new_model_file_name):
+                raise RuntimeError(
+                    f"Program translation failed: json file {new_model_file_name} does not exist."
                 )
-            model_filename = save_dir + ".json"
-            params_filename = save_dir + ".pdiparams"
-            assert os.path.exists(
-                model_filename
-            ), f"Pir Model file {model_filename} does not exist."
-            assert os.path.exists(
-                params_filename
-            ), f"Pir Params file {params_filename} does not exist."
+            model_filename = new_model_file_name
+            if verbose:
+                logging.info("Completed the conversion from .pdmodel to json file.")
+
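+        # When the PIR API is enabled, optionally decompose combined operators before export.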
+        if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
+            if dist_prim_all and auto_upgrade_opset:
+                if verbose:
+                    logging.info("Trying to decompose the program ...")
+                # TODO(wangmingkai02): Do we need to update params_filename here?
+                model_filename = decompose_program(model_filename)
+                if verbose:
+                    logging.info("Completed the decomposition of combined operators.")
+
+        if verbose and PADDLE2ONNX_EXPORT_TEMP_DIR is not None:
+            logging.info(
+                f"Intermediate model and param files are saved at {PADDLE2ONNX_EXPORT_TEMP_DIR}"
+            )
+
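+        # Invoke the core exporter; forward custom_op_info when given, otherwise an empty dict.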
+        deploy_backend = deploy_backend.lower()
+        if custom_op_info is None:
+            onnx_model_str = c_p2o.export(
+                model_filename,
+                params_filename,
+                opset_version,
+                auto_upgrade_opset,
+                verbose,
+                enable_onnx_checker,
+                enable_experimental_op,
+                enable_optimize,
+                {},
+                deploy_backend,
+                calibration_file,
+                external_file,
+                export_fp16_model,
+            )
         else:
-            with paddle.pir_utils.OldIrGuard():
-                program = paddle.load(model_filename)
-            pir_program = paddle.pir.translate_to_pir(program.desc)
-            save_dir = os.path.join(tmp_dir, filename_without_extension)
-            model_filename = save_dir + ".json"
-            with paddle.pir_utils.IrGuard():
-                paddle.save(pir_program, model_filename)
-            assert os.path.exists(
-                model_filename
-            ), f"Pir Model file {model_filename} does not exist."
-    if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
-        if dist_prim_all and auto_upgrade_opset:
-            model_filename = decompose_program(model_filename)
+            onnx_model_str = c_p2o.export(
+                model_filename,
+                params_filename,
+                opset_version,
+                auto_upgrade_opset,
+                verbose,
+                enable_onnx_checker,
+                enable_experimental_op,
+                enable_optimize,
+                custom_op_info,
+                deploy_backend,
+                calibration_file,
+                external_file,
+                export_fp16_model,
+            )
+    except Exception as error:
+        logging.error(f"Failed to convert PaddlePaddle model: {error}.")
+        logging.error(traceback.format_exc())
+    finally:
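+        # Remove the temp dir unless the user asked to keep it via P2O_KEEP_TEMP_MODEL.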
+        if (
+            os.environ.get("P2O_KEEP_TEMP_MODEL", "0").lower()
+            not in [
+                "1",
+                "true",
+                "on",
+            ]
+            and PADDLE2ONNX_EXPORT_TEMP_DIR is not None
+        ):
+            logging.warning(
+                "Intermediate model and param files will be deleted;"
+                + " if you want to keep them, please set env variable `P2O_KEEP_TEMP_MODEL` to True."
+            )
+            shutil.rmtree(PADDLE2ONNX_EXPORT_TEMP_DIR, ignore_errors=True)
+            PADDLE2ONNX_EXPORT_TEMP_DIR = None
 
-    deploy_backend = deploy_backend.lower()
-    if custom_op_info is None:
-        onnx_model_str = c_p2o.export(
-            model_filename,
-            params_filename,
-            opset_version,
-            auto_upgrade_opset,
-            verbose,
-            enable_onnx_checker,
-            enable_experimental_op,
-            enable_optimize,
-            {},
-            deploy_backend,
-            calibration_file,
-            external_file,
-            export_fp16_model,
-        )
-    else:
-        onnx_model_str = c_p2o.export(
-            model_filename,
-            params_filename,
-            opset_version,
-            auto_upgrade_opset,
-            verbose,
-            enable_onnx_checker,
-            enable_experimental_op,
-            enable_optimize,
-            custom_op_info,
-            deploy_backend,
-            calibration_file,
-            external_file,
-            export_fp16_model,
-        )
     if save_file is not None:
         if enable_polygraphy:
             try:
@@ -277,6 +316,7 @@ def dygraph2onnx(layer, save_file, input_spec=None, opset_version=9, **configs):
         "on",
     ]:
         logging.warning(
-            "Static PaddlePaddle model will be deleted, if you want to keep it, please set env variable `P2O_KEEP_TEMP_MODEL` to True."
+            "Static PaddlePaddle model will be deleted; if you want to keep it,"
+            + " please set env variable `P2O_KEEP_TEMP_MODEL` to True."
         )
         shutil.rmtree(paddle_model_dir, ignore_errors=True)