Commit bd84a36

Fix model_name in decompose config.
1 parent: 3df1072 · commit: bd84a36

2 files changed: +44 −64 lines changed

graph_net/paddle/naive_graph_decomposer.py

Lines changed: 8 additions & 23 deletions
@@ -82,29 +82,14 @@ def do_extract(self, **input_dict):
         model_path = os.path.join(
             self.builtin_extractor.workspace_path, self.builtin_extractor.name
         )
-        for (
-            subgraph_idx,
-            samples,
-        ) in self.builtin_extractor.subgraph_idx2samples.items():
-            for seq_idx in range(len(samples)):
-                if (
-                    self.builtin_extractor.num_samples_of_all_subgraphs == 1
-                    and len(samples) == 1
-                ):
-                    subgraph_path = model_path
-                elif len(samples) == 1:
-                    subgraph_path = os.path.join(model_path, f"subgraph_{subgraph_idx}")
-                else:
-                    subgraph_path = os.path.join(
-                        model_path, f"subgraph_{subgraph_idx}_{seq_idx}"
-                    )
-                self.subgraph_path_list.append(subgraph_path)
-                self.builtin_extractor.write_sample_to_file(
-                    subgraph_path, samples[seq_idx]
-                )
-        print(
-            f"Graph and tensors for '{self.builtin_extractor.name}' extracted successfully to: {model_path}"
-        )
+        assert len(self.builtin_extractor.subgraph_idx2samples) == 1
+
+        samples = self.builtin_extractor.subgraph_idx2samples[0]
+        for seq_idx in range(len(samples)):
+            subgraph_path = f"{model_path}_{seq_idx}"
+            self.subgraph_path_list.append(subgraph_path)
+            self.builtin_extractor.write_sample_to_file(subgraph_path, samples[seq_idx])
+            print(f"Save to {subgraph_path}")
         return static_model
 
     def __call__(self, **input_dict):
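
Note on the rewrite above: `do_extract` now asserts a single subgraph index and flattens the output layout, replacing the nested `subgraph_<idx>[_<seq>]` directories with sibling paths tagged by sequence index. A minimal sketch of the resulting paths (the model name and sample count are hypothetical, not from the commit):

    model_path = "/tmp/workspace/resnet50"

    # Before: one directory per (subgraph_idx, seq_idx), nested under model_path.
    old_paths = [f"{model_path}/subgraph_0_{i}" for i in range(2)]
    # ['/tmp/workspace/resnet50/subgraph_0_0', '/tmp/workspace/resnet50/subgraph_0_1']

    # After: a single subgraph index is asserted, so each sample becomes a flat
    # sibling of model_path, tagged only by its sequence index.
    new_paths = [f"{model_path}_{i}" for i in range(2)]
    # ['/tmp/workspace/resnet50_0', '/tmp/workspace/resnet50_1']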

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 36 additions & 41 deletions
@@ -117,12 +117,18 @@ def load_decompose_config(work_dir: str) -> Dict[str, Any]:
 def save_decompose_config(
     work_dir: str,
     max_subgraph_size: int,
+    tasks_map: Dict[str, Union[int, str, list, dict]],
     incorrect_paths: Union[List[str], Set[str]],
-    active_models_map: Dict[str, str],
-    split_positions_map: Dict[str, List[int]],
     failed_decomposition_models: Union[List[str], Set[str]],
 ):
     """Saves the current state to a JSON file."""
+
+    active_models_map = {}
+    split_positions_map = {}
+    for model_name, task_info in tasks_map.items():
+        active_models_map[model_name] = task_info["original_path"]
+        split_positions_map[model_name] = task_info["split_positions"]
+
     config = {
         "max_subgraph_size": max_subgraph_size,
         "incorrect_models": list(incorrect_paths),
@@ -143,7 +149,7 @@ def get_model_name_with_subgraph_tag(model_path):
     return f"{fields[-2]}_{fields[-1]}" if re.match(pattern, fields[-1]) else fields[-1]
 
 
-def run_decomposer(
+def run_naive_decomposer(
     framework: str,
     model_path: str,
     output_dir: str,
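
Note: the `get_model_name_with_subgraph_tag` helper shown in context above folds a trailing subgraph tag into the model name. `fields` and `pattern` are defined outside this hunk, so the sketch below fills them in with assumed values:

    import re

    def get_model_name_with_subgraph_tag(model_path):
        # Assumed definitions: `fields` splits the path and `pattern` matches a
        # "subgraph_<n>" tag; neither is shown in this diff.
        fields = model_path.rstrip("/").split("/")
        pattern = r"subgraph_\d+"
        return f"{fields[-2]}_{fields[-1]}" if re.match(pattern, fields[-1]) else fields[-1]

    print(get_model_name_with_subgraph_tag("/samples/resnet50"))             # resnet50
    print(get_model_name_with_subgraph_tag("/samples/resnet50/subgraph_3"))  # resnet50_subgraph_3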
@@ -170,8 +176,8 @@ def run_decomposer(
         json.dumps(decorator_config).encode()
     ).decode()
 
-    print(f"[Decomposing] model_path: {model_path}")
-    print(f"[Decomposing] split_positions: {split_positions}")
+    print(f"[Decomposition] model_path: {model_path}")
+    print(f"[Decomposition] split_positions: {split_positions}")
 
     cmd = [
         sys.executable,
@@ -185,13 +191,13 @@
     result = subprocess.run(
         cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
     )
+    # print(result.stdout)
     if result.returncode != 0:
         print(
             f"[ERROR] Decomposition failed for {model_path}\n{result.stderr}",
             flush=True,
         )
         return False
-    # print(result.stdout)
     return True
 
 
@@ -215,8 +221,8 @@ def run_evaluation(
         for item in (f"--{key}", str(value))
     ]
 
-    print(f"[Batch Testing] Logging to: {log_path}")
-    print(f"[Command] {' '.join(cmd)}")
+    print(f"[Evaluation] Logging to: {log_path}")
+    print(f"[Evaluation] command: {' '.join(cmd)}")
 
     os.makedirs(os.path.dirname(log_path), exist_ok=True)
     with open(log_path, "w") as f:
@@ -257,19 +263,18 @@ def generate_initial_tasks(args):
     initial_failures = get_incorrect_models(args.tolerance, args.log_file)
 
     tasks_map = {}
-    active_models_map_for_save = {}
 
     for model_path in initial_failures:
-        model_name = os.path.basename(model_path.rstrip("/"))
-        active_models_map_for_save[model_name] = model_path
+        model_name = get_model_name_with_subgraph_tag(model_path)
         tasks_map[model_name] = {
             "subgraph_path": model_path,
             "original_path": model_path,
             "subgraph_size": [0, kMaxGraphSize],
+            "split_positions": set(),
         }
 
     max_subgraph_size = args.max_subgraph_size
-    return tasks_map, active_models_map_for_save, max_subgraph_size
+    return tasks_map, max_subgraph_size
 
 
 def generate_refined_tasks(base_output_dir, current_pass_id):
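
Note: this is the fix named in the commit message. Keying tasks by `os.path.basename` loses the parent model name when a failing path ends in a subgraph tag, so distinct models could collide on the same `tasks_map` key; the example below uses hypothetical paths:

    import os

    # Hypothetical failing paths that end in a subgraph tag.
    paths = ["/failures/resnet50/subgraph_3", "/failures/bert/subgraph_3"]

    old_keys = [os.path.basename(p.rstrip("/")) for p in paths]
    print(old_keys)  # ['subgraph_3', 'subgraph_3'] -- the second entry overwrites the first

    # get_model_name_with_subgraph_tag keeps the parent model name instead
    # (see the helper sketch earlier), yielding distinct keys:
    # ['resnet50_subgraph_3', 'bert_subgraph_3']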
@@ -286,22 +291,20 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
     prev_max_subgraph_size = prev_config.get("max_subgraph_size")
     max_subgraph_size = prev_max_subgraph_size // 2
 
-    if not prev_incorrect_subgraphs or prev_max_subgraph_size <= 1:
-        return {}, {}, max_subgraph_size
+    if not prev_incorrect_subgraphs:
+        return {}, max_subgraph_size
 
     print("[Analysis] Refining splits based on previous incorrect models ...")
 
     tasks_map = {}
-    active_models_map_for_save = {}
-
     for subgraph_path in prev_incorrect_subgraphs:
         # Parse model name and subgraph index
         model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
         model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
         subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
-
+        print(f"subgraph_path: {subgraph_path}")
+        print(f"model_name: {model_name}, subgraph_idx: {subgraph_idx}")
         assert model_name in prev_active_models_map
-        active_models_map_for_save[model_name] = prev_active_models_map[model_name]
 
         # Reconstruct previous subgraph size to locate the failing segment
         prev_split_positions = prev_split_positions_map.get(model_name, [])
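
Note: a worked example of the name parsing above, with a hypothetical failing subgraph directory:

    import os

    subgraph_path = "/pass_0/decomposed/resnet50_3/"

    leaf = subgraph_path.rstrip("/").split(os.sep)[-1]  # "resnet50_3"
    model_name = "_".join(leaf.split("_")[:-1])         # "resnet50"
    subgraph_idx = int(leaf.split("_")[-1])             # 3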
@@ -315,15 +318,15 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
             "subgraph_path": subgraph_path,
             "original_path": prev_active_models_map[model_name],
             "subgraph_size": subgraph_size[subgraph_idx],
+            "split_positions": set(),
         }
 
-    return tasks_map, active_models_map_for_save, max_subgraph_size
+    return tasks_map, max_subgraph_size
 
 
 def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_work_dir):
     """Executes the decomposition phase (Phase 1)."""
     failed_decomposition = []
-    final_used_splits_map = {}
 
     need_decompose = True if len(tasks_map) > 0 else False
     if need_decompose:
@@ -338,27 +341,27 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_work_dir):
         print(f"decomposed_samples_dir: {decomposed_samples_dir}")
 
         for model_name, task_info in tasks_map.items():
-            print(f"[Decomposing] max_subgraph_size: {max_subgraph_size}")
+            print(f"[Decomposition] max_subgraph_size: {max_subgraph_size}")
             original_path = task_info["original_path"]
             split_positions = calculate_split_positions_for_subgraph(
                 task_info["subgraph_size"], max_subgraph_size
             )
-            final_used_splits_map[model_name] = split_positions
+            task_info["split_positions"] = split_positions
 
             rectified_model_path = get_rectfied_model_path(original_path)
             assert os.path.exists(
                 rectified_model_path
             ), f"{rectified_model_path} does not exist."
 
-            success = run_decomposer(
+            success = run_naive_decomposer(
                 framework, rectified_model_path, decomposed_samples_dir, split_positions
             )
             if not success:
                 failed_decomposition.append(rectified_model_path)
 
         num_decomposed_samples = count_samples(decomposed_samples_dir)
         print(
-            f"[Decomposing] number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
+            f"[Decomposition] number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
             flush=True,
         )
         if (
@@ -381,7 +384,7 @@
     if failed_decomposition:
         print(f"[WARN] {len(failed_decomposition)} models failed to decompose.")
 
-    return failed_decomposition, final_used_splits_map, max_subgraph_size
+    return tasks_map, failed_decomposition, max_subgraph_size
 
 
 def print_final_summary(next_round_models, max_subgraph_size):
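
Note: rather than returning a separate `final_used_splits_map`, the phase now records each model's splits in its `task_info` and returns the updated `tasks_map`. A sketch of the new call contract with hypothetical arguments:

    tasks_map, failed_decomposition, max_subgraph_size = execute_decomposition_phase(
        max_subgraph_size=512,
        tasks_map=tasks_map,
        framework="paddle",
        pass_work_dir="/output/pass_0",
    )
    # Each task now carries the split positions that were actually used:
    for name, task_info in tasks_map.items():
        print(name, task_info["split_positions"])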
@@ -408,17 +411,11 @@ def main(args):
 
     # --- Step 1: Prepare Tasks ---
     if current_pass_id == 0:
-        (
-            tasks_map,
-            active_models_map_for_save,
-            max_subgraph_size,
-        ) = generate_initial_tasks(args)
+        tasks_map, max_subgraph_size = generate_initial_tasks(args)
     else:
-        (
-            tasks_map,
-            active_models_map_for_save,
-            max_subgraph_size,
-        ) = generate_refined_tasks(base_output_dir, current_pass_id)
+        tasks_map, max_subgraph_size = generate_refined_tasks(
+            base_output_dir, current_pass_id
+        )
 
     print(f"[INFO] initial max_subgraph_size: {max_subgraph_size}")
     print(f"[INFO] number of incorrect models: {len(tasks_map)}")
@@ -442,11 +439,10 @@
 
     # --- Step 3: Decomposition ---
     failed_decomposition = []
-    final_used_splits_map = {}
     if task_controller.task_scheduler["run_decomposer"]:
         (
+            tasks_map,
             failed_decomposition,
-            final_used_splits_map,
             max_subgraph_size,
         ) = execute_decomposition_phase(
             max_subgraph_size, tasks_map, args.framework, pass_work_dir
455451
# --- Step 4: Testing ---
456452
pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
457453
if task_controller.task_scheduler["run_evaluation"]:
458-
print("\n--- Phase 2: Batch Testing ---")
454+
print("\n--- Phase 2: Evaluation ---")
459455
run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)
460456

461457
# --- Step 5: Analysis ---
462458
next_round_models = set()
463459
if task_controller.task_scheduler["post_analysis"]:
464460
print("\n--- Phase 3: Analysis ---")
465461
next_round_models = get_incorrect_models(args.tolerance, pass_log_path)
466-
print(f"[Result] Found {len(next_round_models)} incorrect subgraphs.")
462+
print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
467463

468464
# --- Step 6: Save State ---
469465
save_decompose_config(
470466
pass_work_dir,
471467
max_subgraph_size,
468+
tasks_map,
472469
next_round_models,
473-
active_models_map_for_save,
474-
final_used_splits_map,
475470
failed_decomposition,
476471
)
477472
