
Commit 55e4c19

[PT | Decomposition] Improve program translation and decomposition (#1551)
1 parent 71cb612 commit 55e4c19

3 files changed: 165 additions, 113 deletions

debug/p2o_infer_debugger.py

Lines changed: 10 additions & 4 deletions
@@ -26,6 +26,7 @@
 from contextlib import contextmanager
 import traceback
 import queue
+import tempfile
 
 current_dir = os.path.dirname(os.path.abspath(__file__))
 tests_dir = os.path.join(current_dir, "..", "tests")
@@ -165,11 +166,16 @@ def str2dtype(dtype: str):
     return tuple(inputs)
 
 
-def save_and_export(program, model_file):
+def save_and_export(program, model_file_path):
+    temp_dir = tempfile.mkdtemp()
+    new_model_file_path = os.path.join(
+        temp_dir, os.path.basename(model_file_path) + "_debug"
+    )
+    new_model_file = new_model_file_path + ".json"
+    new_params_file = new_model_file_path + ".pdiparams"
     paddle2onnx.load_parameter(program)
-    new_model_file = paddle2onnx.save_program(program, model_file)
-    origin_params_file = os.path.splitext(model_file)[0] + ".pdiparams"
-    new_params_file = os.path.splitext(new_model_file)[0] + ".pdiparams"
+    paddle2onnx.save_program(program, new_model_file_path)
+    origin_params_file = os.path.splitext(model_file_path)[0] + ".pdiparams"
     if os.path.exists(origin_params_file):
         shutil.copy(origin_params_file, new_params_file)
     if not os.path.exists(new_params_file):
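
As context for the change above: tempfile.mkdtemp() creates a directory that is never removed automatically, so the debug copies now written by save_and_export persist after the call and must be cleaned up explicitly. A minimal standalone sketch of that behaviour, using only the standard library and a hypothetical file name:

import os
import shutil
import tempfile

# mkdtemp() returns a fresh, unique directory and leaves its removal to the caller.
temp_dir = tempfile.mkdtemp()
debug_base = os.path.join(temp_dir, "inference_debug")  # hypothetical base name

open(debug_base + ".json", "w").close()      # stand-in for the saved program file
print(os.path.exists(debug_base + ".json"))  # True: the file outlives mkdtemp()

# Explicit cleanup, which is the caller's responsibility.
shutil.rmtree(temp_dir, ignore_errors=True)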

paddle2onnx/convert.py

Lines changed: 147 additions & 107 deletions
@@ -21,16 +21,27 @@
 from paddle.decomposition import decomp
 from paddle.base.executor import global_scope
 import shutil
+import traceback
 
+PADDLE2ONNX_EXPORT_TEMP_DIR = None
 
-def load_model(model_filename):
-    """Loads the pir model from json file."""
-    assert os.path.exists(
-        model_filename
-    ), f"Model file {model_filename} does not exist."
-    if model_filename.endswith(".json"):
-        model_filename = model_filename[:-5]
-    return paddle.jit.load(model_filename)
+
+def get_tmp_dir_and_file(model_filename, suffix=""):
+    global PADDLE2ONNX_EXPORT_TEMP_DIR
+    if PADDLE2ONNX_EXPORT_TEMP_DIR is None:
+        PADDLE2ONNX_EXPORT_TEMP_DIR = tempfile.mkdtemp()
+    model_file_path, _ = os.path.splitext(model_filename)
+    new_model_file_path = os.path.join(
+        PADDLE2ONNX_EXPORT_TEMP_DIR, os.path.basename(model_file_path) + suffix
+    )
+    new_model_file_name = new_model_file_path + ".json"
+    new_params_file_name = new_model_file_path + ".pdiparams"
+    return (
+        model_file_path,
+        new_model_file_path,
+        new_model_file_name,
+        new_params_file_name,
+    )
 
 
 def compare_programs(original_program, new_program):
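
To make the helper's naming scheme concrete, here is a small standalone sketch of the path arithmetic get_tmp_dir_and_file performs, with a hypothetical input path and the "_pt" suffix used later in export(); the temporary directory name differs on every run:

import os
import tempfile

model_filename = "/work/resnet50/inference.pdmodel"  # hypothetical input model
suffix = "_pt"

# Same steps as the helper: one shared scratch dir, extension stripped,
# then the suffix and the .json/.pdiparams extensions appended.
temp_dir = tempfile.mkdtemp()
model_file_path, _ = os.path.splitext(model_filename)
new_model_file_path = os.path.join(temp_dir, os.path.basename(model_file_path) + suffix)

print(model_file_path)                     # /work/resnet50/inference
print(new_model_file_path + ".json")       # <temp_dir>/inference_pt.json
print(new_model_file_path + ".pdiparams")  # <temp_dir>/inference_pt.pdiparams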
@@ -40,33 +51,21 @@ def compare_programs(original_program, new_program):
     return original_ops == new_ops
 
 
-def save_program(program, model_file):
-    """Saves the decomposed program to a file."""
+def save_program(program, new_model_file_path):
     place = paddle.CPUPlace()
     exe = paddle.static.Executor(place)
-
-    tmp_dir = tempfile.mkdtemp()
-    filename = os.path.basename(model_file) + "_decompose"
-    filename_without_extension, _ = os.path.splitext(filename)
-    save_dir = os.path.join(tmp_dir, filename_without_extension)
-
     # Find feed and fetch operations
     feed, fetch = [], []
+    # TODO(wangmingkai02): need double check it
     for op in program.global_block().ops:
         if op.name() == "pd_op.feed":
             feed.extend(op.results())
         if op.name() == "pd_op.fetch" or op.name() == "builtin.shadow_output":
             fetch.extend(op.operands_source())
-
     with paddle.pir_utils.IrGuard():
-        paddle.static.save_inference_model(save_dir, feed, fetch, exe, program=program)
-
-    new_model_file = save_dir + ".json"
-    assert os.path.exists(
-        new_model_file
-    ), f"Pir Model file {new_model_file} does not exist."
-    logging.info(f"Decomposed Model file path: {new_model_file}")
-    return new_model_file
+        paddle.static.save_inference_model(
+            new_model_file_path, feed, fetch, exe, program=program
+        )
 
 
 def load_parameter(program):
@@ -79,10 +78,8 @@ def load_parameter(program):
             opts.append(var)
     vars_list = params + opts
    vars = [var for var in vars_list if var.persistable]
-
    if vars is None:
        return
-
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    paddle.base.libpaddle.pir.create_loaded_parameter(
@@ -92,7 +89,10 @@
 
 def decompose_program(model_filename):
     """Decomposes the given pir program."""
-    model = load_model(model_filename)
+    model_file_path, new_model_file_path, new_model_file_name, new_params_file_name = (
+        get_tmp_dir_and_file(model_filename, "_decompose")
+    )
+    model = paddle.jit.load(model_file_path)
     new_program = model.program().clone()
     with decomp.prim_guard():
         decomp.decompose_dist_program(new_program)
@@ -101,7 +101,8 @@
         return model_filename
 
     load_parameter(new_program)
-    return save_program(new_program, model_filename)
+    save_program(new_program, new_model_file_path)
+    return new_model_file_name
 
 
 def get_old_ir_guard():
@@ -136,91 +137,129 @@ def export(
     export_fp16_model=False,
     enable_polygraphy=True,
 ):
+    global PADDLE2ONNX_EXPORT_TEMP_DIR
     # check model_filename
     assert os.path.exists(
         model_filename
     ), f"Model file {model_filename} does not exist."
+    if not os.path.exists(params_filename):
+        logging.warning(
+            f"Params file {params_filename} does not exist, "
+            + "the exported onnx model will not contain weights."
+        )
+        params_filename = ""
 
-    # translate old ir program to pir
-    tmp_dir = tempfile.mkdtemp()
-    dir_and_file, extension = os.path.splitext(model_filename)
-    filename = os.path.basename(model_filename)
-    filename_without_extension, _ = os.path.splitext(filename)
-    save_dir = os.path.join(tmp_dir, filename_without_extension)
-    if model_filename.endswith(".pdmodel"):
-        if os.path.exists(model_filename) and os.path.exists(params_filename):
-            place = paddle.CPUPlace()
-            exe = paddle.static.Executor(place)
-            with paddle.pir_utils.OldIrGuard():
-                [inference_program, feed_target_names, fetch_targets] = (
-                    paddle.static.load_inference_model(dir_and_file, exe)
-                )
-            program = paddle.pir.translate_to_pir(inference_program.desc)
-            for op in program.global_block().ops:
-                if op.name() == "pd_op.feed":
-                    feed = op.results()
-                if op.name() == "pd_op.fetch":
-                    fetch = op.operands_source()
-            with paddle.pir_utils.IrGuard():
-                paddle.static.save_inference_model(
-                    save_dir, feed, fetch, exe, program=program
+    try:
+        if model_filename.endswith(".pdmodel"):
+            # translate old ir program to pir program
+            logging.warning(
+                "The .pdmodel file is deprecated in paddlepaddle 3.0"
+                + " and will be removed in the future."
+                + " Try to convert from .pdmodel file to json file."
+            )
+            (
+                model_file_path,
+                new_model_file_path,
+                new_model_file_name,
+                new_params_file_name,
+            ) = get_tmp_dir_and_file(model_filename, "_pt")
+            if os.path.exists(params_filename):
+                place = paddle.CPUPlace()
+                exe = paddle.static.Executor(place)
+                with paddle.pir_utils.OldIrGuard():
+                    [inference_program, feed_target_names, fetch_targets] = (
+                        paddle.static.load_inference_model(model_file_path, exe)
+                    )
+                program = paddle.pir.translate_to_pir(inference_program.desc)
+                # TODO(wangmingkai02): Do we need to call load_parameter(program) here?
+                load_parameter(program)
+                save_program(program, new_model_file_path)
+                params_filename = new_params_file_name
+                if not os.path.exists(new_params_file_name):
+                    raise RuntimeError(
+                        f"Program Tranlator failed due to params file {new_params_file_name} does not exist."
+                    )
+            else:
+                with paddle.pir_utils.OldIrGuard():
+                    program = paddle.load(model_filename)
+                pir_program = paddle.pir.translate_to_pir(program.desc)
+                with paddle.pir_utils.IrGuard():
+                    paddle.save(pir_program, new_model_file_name)
+                if not os.path.exists(new_model_file_name):
+                    raise RuntimeError(
+                        f"Program Tranlator failed due to json file {new_model_file_name} does not exist."
                     )
-            model_filename = save_dir + ".json"
-            params_filename = save_dir + ".pdiparams"
-            assert os.path.exists(
-                model_filename
-            ), f"Pir Model file {model_filename} does not exist."
-            assert os.path.exists(
-                params_filename
-            ), f"Pir Params file {params_filename} does not exist."
+            model_filename = new_model_file_name
+            if verbose:
+                logging.info("Complete the conversion from .pdmodel to json file.")
+
+        if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
+            if dist_prim_all and auto_upgrade_opset:
+                if verbose:
+                    logging.info("Try to decompose program ...")
+                # TODO(wangmingkai02): Do we need to update params_filename here?
+                model_filename = decompose_program(model_filename)
+                if verbose:
+                    logging.info("Complete the decomposition of combined operators.")
+
+        if verbose and PADDLE2ONNX_EXPORT_TEMP_DIR is not None:
+            logging.info(
+                f"Intermediate model and param files are saved at {PADDLE2ONNX_EXPORT_TEMP_DIR}"
+            )
+
+        deploy_backend = deploy_backend.lower()
+        if custom_op_info is None:
+            onnx_model_str = c_p2o.export(
+                model_filename,
+                params_filename,
+                opset_version,
+                auto_upgrade_opset,
+                verbose,
+                enable_onnx_checker,
+                enable_experimental_op,
+                enable_optimize,
+                {},
+                deploy_backend,
+                calibration_file,
+                external_file,
+                export_fp16_model,
+            )
         else:
-            with paddle.pir_utils.OldIrGuard():
-                program = paddle.load(model_filename)
-            pir_program = paddle.pir.translate_to_pir(program.desc)
-            save_dir = os.path.join(tmp_dir, filename_without_extension)
-            model_filename = save_dir + ".json"
-            with paddle.pir_utils.IrGuard():
-                paddle.save(pir_program, model_filename)
-            assert os.path.exists(
-                model_filename
-            ), f"Pir Model file {model_filename} does not exist."
-    if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
-        if dist_prim_all and auto_upgrade_opset:
-            model_filename = decompose_program(model_filename)
+            onnx_model_str = c_p2o.export(
+                model_filename,
+                params_filename,
+                opset_version,
+                auto_upgrade_opset,
+                verbose,
+                enable_onnx_checker,
+                enable_experimental_op,
+                enable_optimize,
+                custom_op_info,
+                deploy_backend,
+                calibration_file,
+                external_file,
+                export_fp16_model,
+            )
+    except Exception as error:
+        logging.error(f"Failed to convert PaddlePaddle model: {error}.")
+        logging.error(traceback.print_exc())
+    finally:
+        if (
+            os.environ.get("P2O_KEEP_TEMP_MODEL", "0").lower()
+            not in [
+                "1",
+                "true",
+                "on",
+            ]
+            and PADDLE2ONNX_EXPORT_TEMP_DIR is not None
+        ):
+            logging.warning(
+                "Intermediate model and param files will be deleted,"
+                + " if you want to keep them, please set env variable `P2O_KEEP_TEMP_MODEL` to True."
+            )
+            shutil.rmtree(PADDLE2ONNX_EXPORT_TEMP_DIR, ignore_errors=True)
+            PADDLE2ONNX_EXPORT_TEMP_DIR = None
 
-    deploy_backend = deploy_backend.lower()
-    if custom_op_info is None:
-        onnx_model_str = c_p2o.export(
-            model_filename,
-            params_filename,
-            opset_version,
-            auto_upgrade_opset,
-            verbose,
-            enable_onnx_checker,
-            enable_experimental_op,
-            enable_optimize,
-            {},
-            deploy_backend,
-            calibration_file,
-            external_file,
-            export_fp16_model,
-        )
-    else:
-        onnx_model_str = c_p2o.export(
-            model_filename,
-            params_filename,
-            opset_version,
-            auto_upgrade_opset,
-            verbose,
-            enable_onnx_checker,
-            enable_experimental_op,
-            enable_optimize,
-            custom_op_info,
-            deploy_backend,
-            calibration_file,
-            external_file,
-            export_fp16_model,
-        )
     if save_file is not None:
         if enable_polygraphy:
             try:
@@ -277,6 +316,7 @@ def dygraph2onnx(layer, save_file, input_spec=None, opset_version=9, **configs):
             "on",
         ]:
             logging.warning(
-                "Static PaddlePaddle model will be deleted, if you want to keep it, please set env variable `P2O_KEEP_TEMP_MODEL` to True."
+                "Static PaddlePaddle model will be deleted, if you want to keep it,"
+                + " please set env variable `P2O_KEEP_TEMP_MODEL` to True."
             )
             shutil.rmtree(paddle_model_dir, ignore_errors=True)
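
From the caller's side, the cleanup behaviour added above can be sketched roughly as follows. This assumes the export() shown here is what the package exposes as paddle2onnx.export and uses placeholder file names; only the P2O_KEEP_TEMP_MODEL handling and the parameters visible in this diff are taken from the change itself:

import os
import paddle2onnx

# Keep the intermediate json/pdiparams files that export() writes into its
# temporary directory; without this, the finally block above deletes them.
os.environ["P2O_KEEP_TEMP_MODEL"] = "1"

paddle2onnx.export(
    model_filename="inference.pdmodel",     # placeholder legacy-format model
    params_filename="inference.pdiparams",  # placeholder weights file
    save_file="inference.onnx",
    verbose=True,  # verbose=True also logs where the temp files were written
)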

paddle2onnx/utils.py

Lines changed: 8 additions & 2 deletions
@@ -54,6 +54,12 @@ def check_model(onnx_model):
 
 
 levels = {0: "ERROR", 1: "WARNING", 2: "INFO", 3: "DEBUG"}
+level_color = {
+    0: "\033[0;31m",
+    1: "\033[0;33m",
+    2: "\033[0;34m",
+    3: "\033[0;32m",
+}
 
 
 class logging:
@@ -67,8 +73,8 @@ def log(level=2, message="", use_color=False):
         if logging.log_level >= level:
             if use_color:
                 print(
-                    "\033[1;31;40m{} [{}]\t{}\033[0m".format(
-                        current_time, levels[level], message
+                    "{}{} [{}]\t{}\033[0m".format(
+                        level_color[level], current_time, levels[level], message
                     )
                     .encode("utf-8")
                     .decode("latin1")
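
The second hunk swaps the single hard-coded red escape sequence for a per-level prefix taken from level_color. A self-contained sketch of the same ANSI scheme; the colorize helper below is illustrative and not part of the module:

LEVELS = {0: "ERROR", 1: "WARNING", 2: "INFO", 3: "DEBUG"}
# Red for ERROR, yellow for WARNING, blue for INFO, green for DEBUG; \033[0m resets.
LEVEL_COLOR = {0: "\033[0;31m", 1: "\033[0;33m", 2: "\033[0;34m", 3: "\033[0;32m"}


def colorize(level, message):
    # Prefix the message with the level's color code and reset afterwards.
    return "{}[{}]\t{}\033[0m".format(LEVEL_COLOR[level], LEVELS[level], message)


if __name__ == "__main__":
    for lvl in sorted(LEVELS):
        print(colorize(lvl, "sample message"))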
