Skip to content

Commit 1375b35

Browse files
authored
[Dy2ONNX] Improve dygraph2onnx API (#1549)
1 parent 6938791 commit 1375b35

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

paddle2onnx/convert.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from contextlib import contextmanager
2121
from paddle.decomposition import decomp
2222
from paddle.base.executor import global_scope
23+
import shutil
2324

2425

2526
def load_model(model_filename):
@@ -250,31 +251,32 @@ def export(
250251

251252

252253
def dygraph2onnx(layer, save_file, input_spec=None, opset_version=9, **configs):
253-
# Get PaddleInference model file path
254-
dirname = os.path.split(save_file)[0]
255-
paddle_model_dir = os.path.join(dirname, "paddle_model_temp_dir")
256-
model_file = os.path.join(paddle_model_dir, "model.pdmodel")
257-
params_file = os.path.join(paddle_model_dir, "model.pdiparams")
258-
259-
if os.path.exists(paddle_model_dir):
260-
if os.path.isfile(paddle_model_dir):
261-
logging.info("File {} exists, will remove it.".format(paddle_model_dir))
262-
os.remove(paddle_model_dir)
263-
if os.path.isfile(model_file):
264-
os.remove(model_file)
265-
if os.path.isfile(params_file):
266-
os.remove(params_file)
267-
save_configs = paddle_jit_save_configs(configs)
268-
with get_old_ir_guard()():
269-
# In PaddlePaddle 3.0.0b2, PIR becomes the default IR, but PIR export still in development.
270-
# So we need to use the old IR to export the model, avoid make users confused.
271-
# In the future, we will remove this guard and recommend users to use PIR.
254+
paddle_model_dir = tempfile.mkdtemp()
255+
try:
256+
save_configs = paddle_jit_save_configs(configs)
257+
if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
258+
model_file = os.path.join(paddle_model_dir, "model.json")
259+
else:
260+
model_file = os.path.join(paddle_model_dir, "model.pdmodel")
272261
paddle.jit.save(
273262
layer, os.path.join(paddle_model_dir, "model"), input_spec, **save_configs
274263
)
275-
logging.info("Static PaddlePaddle model saved in {}.".format(paddle_model_dir))
276-
if not os.path.isfile(params_file):
277-
params_file = ""
278-
279-
export(model_file, params_file, save_file, opset_version)
280-
logging.info("ONNX model saved in {}.".format(save_file))
264+
if not os.path.isfile(model_file):
265+
raise ValueError("Failed to save static PaddlePaddle model.")
266+
logging.info("Static PaddlePaddle model saved in {}.".format(paddle_model_dir))
267+
params_file = os.path.join(paddle_model_dir, "model.pdiparams")
268+
if not os.path.isfile(params_file):
269+
params_file = ""
270+
export(model_file, params_file, save_file, opset_version)
271+
except Exception as err:
272+
logging.error(f"Failed to convert PaddlePaddle model due to {err}.")
273+
finally:
274+
if os.environ.get("P2O_KEEP_TEMP_MODEL", "0").lower() not in [
275+
"1",
276+
"true",
277+
"on",
278+
]:
279+
logging.warning(
280+
"Static PaddlePaddle model will be deleted, if you want to keep it, please set env variable `P2O_KEEP_TEMP_MODEL` to True."
281+
)
282+
shutil.rmtree(paddle_model_dir, ignore_errors=True)

0 commit comments

Comments
 (0)