|
20 | 20 | from contextlib import contextmanager |
21 | 21 | from paddle.decomposition import decomp |
22 | 22 | from paddle.base.executor import global_scope |
| 23 | +import shutil |
23 | 24 |
|
24 | 25 |
|
25 | 26 | def load_model(model_filename): |
@@ -250,31 +251,32 @@ def export( |
250 | 251 |
|
251 | 252 |
|
252 | 253 | def dygraph2onnx(layer, save_file, input_spec=None, opset_version=9, **configs): |
253 | | - # Get PaddleInference model file path |
254 | | - dirname = os.path.split(save_file)[0] |
255 | | - paddle_model_dir = os.path.join(dirname, "paddle_model_temp_dir") |
256 | | - model_file = os.path.join(paddle_model_dir, "model.pdmodel") |
257 | | - params_file = os.path.join(paddle_model_dir, "model.pdiparams") |
258 | | - |
259 | | - if os.path.exists(paddle_model_dir): |
260 | | - if os.path.isfile(paddle_model_dir): |
261 | | - logging.info("File {} exists, will remove it.".format(paddle_model_dir)) |
262 | | - os.remove(paddle_model_dir) |
263 | | - if os.path.isfile(model_file): |
264 | | - os.remove(model_file) |
265 | | - if os.path.isfile(params_file): |
266 | | - os.remove(params_file) |
267 | | - save_configs = paddle_jit_save_configs(configs) |
268 | | - with get_old_ir_guard()(): |
269 | | - # In PaddlePaddle 3.0.0b2, PIR becomes the default IR, but PIR export still in development. |
270 | | - # So we need to use the old IR to export the model, avoid make users confused. |
271 | | - # In the future, we will remove this guard and recommend users to use PIR. |
| 254 | + paddle_model_dir = tempfile.mkdtemp() |
| 255 | + try: |
| 256 | + save_configs = paddle_jit_save_configs(configs) |
| 257 | + if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]: |
| 258 | + model_file = os.path.join(paddle_model_dir, "model.json") |
| 259 | + else: |
| 260 | + model_file = os.path.join(paddle_model_dir, "model.pdmodel") |
272 | 261 | paddle.jit.save( |
273 | 262 | layer, os.path.join(paddle_model_dir, "model"), input_spec, **save_configs |
274 | 263 | ) |
275 | | - logging.info("Static PaddlePaddle model saved in {}.".format(paddle_model_dir)) |
276 | | - if not os.path.isfile(params_file): |
277 | | - params_file = "" |
278 | | - |
279 | | - export(model_file, params_file, save_file, opset_version) |
280 | | - logging.info("ONNX model saved in {}.".format(save_file)) |
| 264 | + if not os.path.isfile(model_file): |
| 265 | + raise ValueError("Failed to save static PaddlePaddle model.") |
| 266 | + logging.info("Static PaddlePaddle model saved in {}.".format(paddle_model_dir)) |
| 267 | + params_file = os.path.join(paddle_model_dir, "model.pdiparams") |
| 268 | + if not os.path.isfile(params_file): |
| 269 | + params_file = "" |
| 270 | + export(model_file, params_file, save_file, opset_version) |
| 271 | + except Exception as err: |
| 272 | + logging.error(f"Failed to convert PaddlePaddle model due to {err}.") |
| 273 | + finally: |
| 274 | + if os.environ.get("P2O_KEEP_TEMP_MODEL", "0").lower() not in [ |
| 275 | + "1", |
| 276 | + "true", |
| 277 | + "on", |
| 278 | + ]: |
| 279 | + logging.warning( |
| 280 | + "Static PaddlePaddle model will be deleted, if you want to keep it, please set env variable `P2O_KEEP_TEMP_MODEL` to True." |
| 281 | + ) |
| 282 | + shutil.rmtree(paddle_model_dir, ignore_errors=True) |
0 commit comments