Commit e7deedc

Add the definition of ModelRecord to refactor some code.
1 parent 7cfd4eb commit e7deedc

File tree

1 file changed (+133, -66 lines)

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 133 additions & 66 deletions
@@ -12,7 +12,7 @@
 from graph_net.analysis_util import get_incorrect_models
 from graph_net import path_utils
 
-kMaxGraphSize = 4096
+MAX_GRAPH_SIZE = 4096
 
 
 def convert_b64_string_to_json(b64str):
@@ -109,9 +109,30 @@ def _print(self):
         print()
 
 
+@dataclass
+class ModelRecord:
+    original_path: str
+    uniform_split_positions: List[int] = field(default_factory=list)
+    subgraph_paths: List[str] = field(default_factory=list)
+    incorrect_subgraph_idxs: List[int] = field(default_factory=list)
+
+    def get_split_positions(self, decompose_method):
+        if decompose_method == "fixed-start":
+            assert (
+                len(self.uniform_split_positions) >= 2
+            ), f"{self.uniform_split_positions=}"
+            return [0, self.uniform_split_positions[1]]
+        return self.uniform_split_positions
+
+    def update_for_next_decompose(self, subgraph_idx, max_subgraph_size):
+        self.uniform_split_positions = reconstruct_split_positions_for_subgraphs(
+            self.uniform_split_positions, subgraph_idx, max_subgraph_size
+        )
+
+
 @dataclass
 class DecomposeConfig:
-    method: str
+    decompose_method: str
     tolerance: int | List[int]
     max_subgraph_size: int = -1
     tasks_map: Dict[str, Union[int, str, list, dict]] = field(default_factory=dict)
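For illustration, a minimal sketch of how the new ModelRecord behaves; the path, the split positions, and the "uniform" method name below are hypothetical:

    record = ModelRecord(
        original_path="/samples/model_a",
        uniform_split_positions=[0, 1024, 2048, 3072, 4096],
    )

    # "fixed-start" keeps only the first window [0, positions[1]] ...
    assert record.get_split_positions("fixed-start") == [0, 1024]
    # ... while any other method returns the uniform positions unchanged.
    assert record.get_split_positions("uniform") == [0, 1024, 2048, 3072, 4096]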
@@ -145,18 +166,28 @@ def get_incorrect_models(self, pass_id):
         assert pass_key in self.running_states
         return self.running_states[pass_key]["incorrect_models"]
 
-    def update_running_states(self, pass_id, **kwargs):
-        pass_key = get_pass_name(pass_id)
-        if self.running_states.get(pass_key, None) is None:
+    def update_running_states(self, pass_id, incorrect_models, model_name2record):
+        assert pass_id == "initial" or isinstance(pass_id, int)
+        pass_key = get_pass_name(pass_id) if isinstance(pass_id, int) else pass_id
+        if pass_key not in self.running_states:
             self.running_states[pass_key] = {}
 
-        for key, value in kwargs.items():
-            assert key in [
-                "num_incorrect_models",
-                "incorrect_models",
-                "failed_decomposition_models",
-            ]
-            self.running_states[pass_key][key] = value
+        self.running_states[pass_key]["incorrect_models_from_log"] = list(
+            sorted(incorrect_models)
+        )
+        if model_name2record:
+            target_model_names = list(model_name2record.keys())
+            model_name2subgraph_idxs = collect_incorrect_subgraph_idxs(
+                self.decompose_method,
+                target_model_names,
+                incorrect_models,
+                model_name2record,
+            )
+            for model_name, model_record in sorted(model_name2record.items()):
+                model_record.incorrect_subgraph_idxs = model_name2subgraph_idxs[
+                    model_name
+                ]
+                self.running_states[pass_key][model_name] = model_record.__dict__
 
 
 def get_rectfied_model_path(model_path):
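The keyword-based update_running_states is replaced by a positional signature that also serializes per-model records. A hedged sketch of the new call shape, where config, record, and all argument values are hypothetical:

    config.update_running_states(
        pass_id=0,  # an int, or the string "initial"
        incorrect_models=["/samples/model_a/subgraph_1"],
        model_name2record={"model_a": record},
    )
    # The pass entry (presumably "pass_0" via get_pass_name) now holds
    # "incorrect_models_from_log" plus one ModelRecord.__dict__ per model,
    # with incorrect_subgraph_idxs filled in by collect_incorrect_subgraph_idxs.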
@@ -226,21 +257,18 @@ def run_decomposer_for_single_model(
 
 
 def run_decomposer_for_multi_models(
-    framework, tasks_map, decomposed_samples_dir, max_subgraph_size, log_path
+    framework, model_name2record, decomposed_samples_dir, max_subgraph_size, log_path
 ):
-    failed_decomposition = []
+    failed_decomposition_models = []
 
     print(
         f"[Decomposition] max_subgraph_size: {max_subgraph_size}, log_path: {log_path}"
     )
-    for model_name, task_info in tasks_map.items():
-        original_path = task_info["original_path"]
-        split_positions = sorted(list(task_info["split_positions"]))
-
-        method = "fixed-start"
-        if method == "fixed-start":
-            assert len(split_positions) >= 3, f"{split_positions=}"
-            split_positions = [0, split_positions[1]]
+    for model_name, model_record in model_name2record.items():
+        original_path = model_record.original_path
+        split_positions = model_record.get_split_positions(
+            decompose_method="fixed-start"
+        )
 
         rectified_model_path = get_rectfied_model_path(original_path)
         assert os.path.exists(
@@ -255,8 +283,8 @@ def run_decomposer_for_multi_models(
             log_path,
         )
         if not success:
-            failed_decomposition.append(rectified_model_path)
-    return tasks_map, failed_decomposition
+            failed_decomposition_models.append(rectified_model_path)
+    return failed_decomposition_models
 
 
 def run_evaluation(
@@ -314,10 +342,13 @@ def generate_initial_tasks(args):
     initial_failures = get_ranged_incorrect_models(args.tolerance, args.log_file)
 
     tasks_map = {}
-    max_subgraph_size = min(args.max_subgraph_size, kMaxGraphSize // 2)
+    if args.decompose_method == "fixed-start":
+        max_subgraph_size = MAX_GRAPH_SIZE
+    else:
+        max_subgraph_size = min(args.max_subgraph_size, MAX_GRAPH_SIZE)
 
     initial_split_positions = reconstruct_split_positions_for_subgraphs(
-        [0, kMaxGraphSize], 0, max_subgraph_size
+        [0, MAX_GRAPH_SIZE], 0, max_subgraph_size
     )
     for model_path in initial_failures:
         model_name = get_model_name_with_subgraph_tag(model_path)
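Restated as a sketch: "fixed-start" now starts from the full graph budget, while other methods are capped at MAX_GRAPH_SIZE instead of the old MAX_GRAPH_SIZE // 2. pick_initial_size is a hypothetical helper and the requested sizes are made up:

    def pick_initial_size(decompose_method, requested_size):
        if decompose_method == "fixed-start":
            return MAX_GRAPH_SIZE  # 4096
        return min(requested_size, MAX_GRAPH_SIZE)

    assert pick_initial_size("fixed-start", 512) == 4096
    assert pick_initial_size("uniform", 512) == 512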
@@ -327,7 +358,7 @@ def generate_initial_tasks(args):
         }
 
     running_states = {
-        "pass_0": {
+        "initial": {
             "num_incorrect_models": len(initial_failures),
             "incorrect_models": list(sorted(initial_failures)),
         }
@@ -343,7 +374,9 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
     return model_name, subgraph_idx
 
 
-def collect_incorrect_subgraph_idxs(args, target_model_names, incorrect_models):
+def collect_incorrect_subgraph_idxs(
+    decompose_method, target_model_names, incorrect_models, model_name2record
+):
     model_name2subgraph_idxs = {}
     for subgraph_path in sorted(incorrect_models):
         model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
@@ -355,11 +388,17 @@ def collect_incorrect_subgraph_idxs(args, target_model_names, incorrect_models):
             model_name2subgraph_idxs[model_name] = []
         model_name2subgraph_idxs[model_name].append(subgraph_idx)
 
-    if args.method == "fixed-start":
+    if decompose_method == "fixed-start":
         print(model_name2subgraph_idxs)
         for model_name in target_model_names:
             if model_name not in model_name2subgraph_idxs:
-                model_name2subgraph_idxs[model_name] = [1]
+                if (
+                    model_name2record
+                    and len(model_name2record[model_name].uniform_split_positions) > 2
+                ):
+                    model_name2subgraph_idxs[model_name] = [1]
+                else:
+                    model_name2subgraph_idxs[model_name] = []
     else:
         assert len(
             model_name2subgraph_idxs[model_name]
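The fixed-start fallback now consults the record: a model with no reported incorrect subgraphs retries subgraph 1 only while its record still has more than two uniform split positions. A sketch with hypothetical records and paths:

    records = {
        "model_a": ModelRecord("/samples/model_a", [0, 1024, 4096]),
        "model_b": ModelRecord("/samples/model_b", [0, 4096]),
    }
    idxs = collect_incorrect_subgraph_idxs(
        "fixed-start", ["model_a", "model_b"], [], records
    )
    # model_a still has range left to explore; model_b does not.
    assert idxs == {"model_a": [1], "model_b": []}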
@@ -375,15 +414,15 @@ def generate_successor_tasks(args, base_output_dir, current_pass_id):
     prev_config = DecomposeConfig.load(prev_pass_dir)
     max_subgraph_size = prev_config.max_subgraph_size // 2
     incorrect_models = prev_config.get_incorrect_models(current_pass_id)
-    if args.method != "fixed-start" and not incorrect_models:
+    if args.decompose_method != "fixed-start" and not incorrect_models:
         return {}, max_subgraph_size, prev_config.running_states
 
     tasks_map = {}
     prev_tasks_map = prev_config.tasks_map
 
     target_model_names = list(prev_tasks_map.keys())
     model_name2subgraph_idxs = collect_incorrect_subgraph_idxs(
-        args, target_model_names, incorrect_models
+        args.decompose_method, target_model_names, incorrect_models, None
     )
 
     for model_name, subgraph_idxs in model_name2subgraph_idxs.items():
@@ -393,6 +432,8 @@ def generate_successor_tasks(args, base_output_dir, current_pass_id):
         split_positions = reconstruct_split_positions_for_subgraphs(
             prev_split_positions, subgraph_idxs, max_subgraph_size
         )
+        if args.decompose_method == "fixed-start" and len(split_positions) > 3:
+            split_positions = split_positions[0:3]
 
         tasks_map[model_name] = {
             "original_path": pre_task_for_model["original_path"],
@@ -430,58 +471,76 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
     return tasks_map, max_subgraph_size, running_states
 
 
-def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspace):
+def collect_decomposed_subgraphs(model_name2record, decomposed_samples_dir):
+    for root, dirs, files in os.walk(decomposed_samples_dir):
+        if path_utils.is_single_model_dir(root):
+            model_name, _ = extract_model_name_and_subgraph_idx(root)
+            assert model_name in model_name2record
+            model_record = model_name2record[model_name]
+            model_record.subgraph_paths.append(root)
+    return model_name2record
+
+
+def execute_decomposition_phase(
+    max_subgraph_size, model_name2record, framework, workspace
+):
     """Executes the decomposition phase."""
 
-    failed_decomposition = []
-    need_decompose = True if len(tasks_map) > 0 else False
-    method = "fixed-start"
+    failed_decomposition_models = []
+    need_decompose = True if len(model_name2record) > 0 else False
+    decompose_method = "fixed-start"
+    decomposed_samples_dir = os.path.join(
+        workspace, "samples" if framework == "torch" else "paddle_samples"
+    )
 
     while need_decompose:
-        decomposed_samples_dir = os.path.join(
-            workspace, "samples" if framework == "torch" else "paddle_samples"
-        )
         if not os.path.exists(decomposed_samples_dir):
             os.makedirs(decomposed_samples_dir, exist_ok=True)
         print(f"[Decomposition] decomposed_samples_dir: {decomposed_samples_dir}")
 
         log_path = os.path.join(
             workspace, f"log_decompose-max_subgraph_size_{max_subgraph_size}.txt"
         )
-        tasks_map, failed_decomposition = run_decomposer_for_multi_models(
-            framework, tasks_map, decomposed_samples_dir, max_subgraph_size, log_path
+        failed_decomposition_models = run_decomposer_for_multi_models(
+            framework,
+            model_name2record,
+            decomposed_samples_dir,
+            max_subgraph_size,
+            log_path,
         )
         num_decomposed_samples = count_samples(decomposed_samples_dir)
         print(
-            f"[Decomposition] number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
+            f"[Decomposition] number of graphs: {len(model_name2record)} -> {num_decomposed_samples}",
            flush=True,
        )
        if (
-            not failed_decomposition
-            and num_decomposed_samples == len(tasks_map)
+            not failed_decomposition_models
+            and num_decomposed_samples == len(model_name2record)
            and max_subgraph_size > 1
-            and method != "fixed-start"
+            and decompose_method != "fixed-start"
        ):
            need_decompose = True
            shutil.rmtree(decomposed_samples_dir)
            os.makedirs(decomposed_samples_dir, exist_ok=True)
            max_subgraph_size = max(1, max_subgraph_size // 2)
-            for model_name, task_info in tasks_map.items():
-                split_positions = task_info["split_positions"]
-                if not split_positions or len(split_positions) < 2:
+            for model_name, model_record in model_name2record.items():
+                if (
+                    not model_record.uniform_split_positions
+                    or len(model_record.uniform_split_positions) < 2
+                ):
                    continue
-                new_split_positions = reconstruct_split_positions_for_subgraphs(
-                    split_positions, 0, max_subgraph_size
-                )
-                task_info["split_positions"] = new_split_positions
+                model_record.update_for_next_decompose(0, max_subgraph_size)
        else:
            need_decompose = False
        print()
 
-    if failed_decomposition:
-        print(f"[WARN] {len(failed_decomposition)} models failed to decompose.")
+    if failed_decomposition_models:
+        print(f"[WARN] {len(failed_decomposition_models)} models failed to decompose.")
 
-    return tasks_map, failed_decomposition, max_subgraph_size
+    model_name2record = collect_decomposed_subgraphs(
+        model_name2record, decomposed_samples_dir
+    )
+    return model_name2record, max_subgraph_size
 
 
 def count_unique_original_models(incorrect_models):
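Note that decompose_method is hard-coded to "fixed-start" inside execute_decomposition_phase, and the retry branch requires a different method, so the while loop currently runs exactly once. For the other methods the retry branch halves the budget each round; a sketch of that schedule, assuming a hypothetical starting size of 2048:

    size = 2048
    schedule = []
    while size > 1:
        size = max(1, size // 2)
        schedule.append(size)
    # schedule == [1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1]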
@@ -518,8 +577,16 @@ def main(args):
     tasks_map, max_subgraph_size, running_states = prepare_tasks_and_verify(
         args, current_pass_id, base_output_dir
     )
+
+    model_name2record = {}
+    for model_name in tasks_map.keys():
+        model_name2record[model_name] = ModelRecord(
+            original_path=tasks_map[model_name]["original_path"],
+            uniform_split_positions=tasks_map[model_name]["split_positions"],
+        )
+
     decompose_config = DecomposeConfig(
-        method=args.method,
+        decompose_method=args.decompose_method,
         tolerance=args.tolerance,
         max_subgraph_size=max_subgraph_size,
         tasks_map=tasks_map,
@@ -533,14 +600,10 @@ def main(args):
     if task_controller.task_scheduler["run_decomposer"]:
         print("\n--- Phase 1: Decomposition ---", flush=True)
         (
-            tasks_map,
-            failed_decomposition,
+            model_name2record,
             max_subgraph_size,
         ) = execute_decomposition_phase(
-            max_subgraph_size, tasks_map, args.framework, work_dir
-        )
-        decompose_config.update_running_states(
-            current_pass_id, failed_decomposition_models=list(failed_decomposition)
+            max_subgraph_size, model_name2record, args.framework, work_dir
         )
     else:
         print("\n--- Phase 1: Decomposition (skipped) ---", flush=True)
@@ -560,22 +623,26 @@ def main(args):
     print(f"\n--- Phase 3: Analysis (torlance={tolerance}) ---")
     next_pass_incorrect_models = sorted(get_incorrect_models(tolerance, log_path))
     num_original_models = count_unique_original_models(next_pass_incorrect_models)
+
     decompose_config.update_running_states(
-        current_pass_id + 1,
-        num_incorrect_models=num_original_models,
-        incorrect_models=list(next_pass_incorrect_models),
+        current_pass_id,
+        next_pass_incorrect_models,
+        model_name2record,
     )
 
     print(
         f"[Analysis] Found {len(next_pass_incorrect_models)} incorrect subgraphs ({num_original_models} original models)."
     )
     for idx, model_path in enumerate(next_pass_incorrect_models):
         print(f"- [{idx}] {model_path}")
+
     print_summary_and_suggestion(
         args, next_pass_incorrect_models, max_subgraph_size
     )
 
     # --- Step 5: Save States ---
+    for model_name, model_record in model_name2record.items():
+        print(f"- {model_name}: {model_record}")
     decompose_config.save(work_dir)
 
 
@@ -587,7 +654,7 @@ def main(args):
     parser.add_argument(
         "--test-config", type=str, required=True, help="Base64 encoded test config"
     )
-    parser.add_argument("--method", type=str, required=True)
+    parser.add_argument("--decompose-method", type=str, required=True)
     parser.add_argument(
         "--tolerance",
         type=int,
