Skip to content

Commit 0efd605

Browse files
committed
Merge branch 'develop' into opt_paddle_unittest
2 parents 978fb48 + e1dee5b commit 0efd605

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

graph_net/sample_pass/agent_unittest_generator.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def __init__(
210210
self,
211211
framework: str,
212212
model_path: str,
213+
output_name: str,
213214
output_dir: str,
214215
device: Literal["auto", "cpu", "cuda"] = "auto",
215216
generate_main: bool = False,
@@ -219,6 +220,7 @@ def __init__(
219220
):
220221
self.framework = framework
221222
self.model_path = Path(model_path).resolve()
223+
self.output_name = output_name
222224
self.output_dir = Path(output_dir)
223225
self.device = self._choose_device(device)
224226
self.generate_main = generate_main
@@ -283,7 +285,7 @@ def _make_data_input_predicator(
283285
return lambda *args, **kwargs: True
284286

285287
def _write_to_file(self, unittest, output_dir):
286-
output_path = Path(output_dir) / f"{self.model_path.name}_test.py"
288+
output_path = Path(output_dir) / self.output_name
287289
output_path.parent.mkdir(parents=True, exist_ok=True)
288290
output_path.write_text(unittest, encoding="utf-8")
289291
print(
@@ -447,12 +449,25 @@ def declare_config(
447449
def __call__(self, rel_model_path: str):
448450
self.resumable_handle_sample(rel_model_path)
449451

452+
def sample_handled(self, rel_model_path: str) -> bool:
453+
dst_model_path = Path(self.config["output_dir"]) / rel_model_path
454+
if not dst_model_path.exists():
455+
return False
456+
output_name = self._get_output_name(rel_model_path)
457+
num_model_py_files = len(list(dst_model_path.rglob(output_name)))
458+
assert num_model_py_files <= 1
459+
return num_model_py_files == 1
460+
461+
def _get_output_name(self, rel_model_path: str):
462+
return f"{Path(rel_model_path).name}_test.py"
463+
450464
def resume(self, rel_model_path: str):
451465
model_path_prefix = Path(self.config["model_path_prefix"])
452466
output_dir = Path(self.config["output_dir"])
453467
generator = AgentUnittestGenerator(
454468
framework=self.config["framework"],
455469
model_path=str(model_path_prefix / rel_model_path),
470+
output_name=self._get_output_name(rel_model_path),
456471
output_dir=str(output_dir / rel_model_path),
457472
device=self.config["device"],
458473
generate_main=self.config["generate_main"],

graph_net/torch/fx_graph_parse_util.py

100644100755
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ def _rename_placeholder(name, pattern2replacement):
116116
if not (name[:2] == "L_" or name[:2] == "l_"):
117117
return name
118118
name = name[2:]
119-
if name[0] == "l":
120-
name = "L" + name[1:]
119+
if name[:2] == "l_":
120+
name = "L_" + name[2:]
121121
for pattern, replacement in pattern2replacement.items():
122122
name = name.replace(pattern, replacement)
123123
return name
@@ -161,7 +161,7 @@ def get_input_names_from_placeholder():
161161
if node.op != "placeholder":
162162
continue
163163
node.target = _rename_placeholder(node.target, pattern2replacement)
164-
node.name = _rename_placeholder(node.name, pattern2replacement)
164+
node.name = node.target
165165

166166
def get_diff_input_names():
167167
placeholder_names = set(get_input_names_from_placeholder())

0 commit comments

Comments
 (0)