Skip to content

Commit 946de52

Browse files
committed
Fix model_name and refine some code.
1 parent becb25c commit 946de52

File tree

1 file changed

+91
-76
lines changed

1 file changed

+91
-76
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 91 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import sys
3+
import re
34
import json
45
import base64
56
import shutil
@@ -9,7 +10,9 @@
910
from typing import List, Set, Dict, Any, Union
1011
import graph_net
1112
from graph_net.analysis_util import get_incorrect_models
12-
from graph_net import path_utils, test_compiler_util
13+
from graph_net import path_utils
14+
15+
kMaxGraphSize = 4096
1316

1417

1518
def convert_b64_string_to_json(b64str):
@@ -22,16 +25,16 @@ def __init__(self, args):
2225
self.test_config = convert_b64_string_to_json(args.test_config)
2326
assert "test_module_name" in self.test_config
2427

25-
test_module_name = self.test_config["test_module_name"]
28+
self.test_module_name = self.test_config["test_module_name"]
2629
max_pass_id = self._determine_max_pass_id(self.root_output_dir)
2730
self.current_pass_id = (
28-
max_pass_id if test_module_name == "test_target_device" else max_pass_id + 1
29-
)
30-
print(
31-
f"test_module_name: {test_module_name}, current_pass_id: {self.current_pass_id}"
31+
max_pass_id
32+
if self.test_module_name == "test_target_device"
33+
else max_pass_id + 1
3234
)
3335

34-
self._init_task_scheduler(test_module_name)
36+
self._init_task_scheduler(self.test_module_name)
37+
self._print()
3538

3639
def _determine_max_pass_id(self, output_dir: str) -> int:
3740
"""Scans the output directory to determine the next pass ID."""
@@ -71,6 +74,14 @@ def _init_task_scheduler(self, test_module_name):
7174
"post_analysis": True,
7275
}
7376

77+
def _print(self):
78+
print(
79+
f"[TaskController] test_module_name: {self.test_module_name}, current_pass_id: {self.current_pass_id}",
80+
flush=True,
81+
)
82+
print(f"[TaskController] task_scheduler: {self.task_scheduler}", flush=True)
83+
print()
84+
7485

7586
def get_rectfied_model_path(model_path):
7687
graphnet_root = path_utils.get_graphnet_root()
@@ -90,10 +101,13 @@ def get_decompose_config_path(output_dir: str) -> str:
90101
return os.path.join(output_dir, "decompose_config.json")
91102

92103

93-
def load_decompose_config(pass_id: int, output_dir: str) -> Dict[str, Any]:
def get_decompose_workspace_path(output_dir, pass_id):
    """Return the per-pass workspace directory ``<output_dir>/pass_<pass_id>``."""
    workspace_name = f"pass_{pass_id}"
    return os.path.join(output_dir, workspace_name)
107+
108+
def load_decompose_config(work_dir: str) -> Dict[str, Any]:
94109
"""Loads the configuration file from the previous pass."""
95-
prev_dir = os.path.join(output_dir, f"pass_{pass_id - 1}")
96-
config_path = get_decompose_config_path(prev_dir)
110+
config_path = get_decompose_config_path(work_dir)
97111

98112
if not os.path.exists(config_path):
99113
raise FileNotFoundError(f"Missing configuration file: {config_path}")
@@ -125,9 +139,9 @@ def save_decompose_config(
125139

126140

127141
def get_model_name_with_subgraph_tag(model_path):
128-
model_name = test_compiler_util.get_model_name(model_path)
129-
subgraph_tag = test_compiler_util.get_subgraph_tag(model_path)
130-
return f"{model_name}_{subgraph_tag}" if subgraph_tag else model_name
142+
fields = model_path.rstrip("/").split(os.sep)
143+
pattern = rf"^subgraph(_\d+)?$"
144+
return f"{fields[-2]}_{fields[-1]}" if re.match(pattern, fields[-1]) else fields[-1]
131145

132146

133147
def run_decomposer(
@@ -256,6 +270,34 @@ def calculate_current_subgraph_size(
256270
)
257271

258272

def calculate_split_postions_for_subgraph(subgraph_size):
    """Bisect a failing subgraph interval ``[start, end]`` into split positions.

    Returns ``[start, mid, end]`` when a strict interior midpoint exists,
    otherwise just ``[start, end]``.  An open-ended interval (``end`` is
    ``inf``) is capped at the module-level ``kMaxGraphSize``.
    """
    assert isinstance(subgraph_size, (list, tuple)) and len(subgraph_size) == 2
    start, end = subgraph_size

    # Cap an open-ended interval at the global graph-size ceiling.
    if end == float("inf"):
        end = kMaxGraphSize

    # Half the interval length, falling back to the whole span for
    # degenerate (length < 2) intervals.
    span = end - start
    step = span // 2
    if step < 1:
        step = span
    mid = start + step

    # Emit the midpoint only when it lies strictly inside the interval.
    if start < mid < end:
        return [int(start), int(mid), int(end)]
    return [int(start), int(end)]
300+
259301
def main(args):
260302
task_controller = TaskController(args)
261303
base_output_dir = task_controller.root_output_dir
@@ -267,7 +309,6 @@ def main(args):
267309

268310
tasks_map = {}
269311
active_models_map_for_save = {}
270-
kMaxGraphSize = 4096
271312

272313
# Initialize using the argument passed from bash
273314
max_subgraph_size = args.max_subgraph_size
@@ -287,14 +328,15 @@ def main(args):
287328
"split_positions": set(initial_splits),
288329
}
289330
else:
290-
prev_pass_dir = os.path.join(base_output_dir, f"pass_{current_pass_id - 1}")
331+
prev_pass_dir = get_decompose_workspace_path(
332+
base_output_dir, current_pass_id - 1
333+
)
291334
print(
292-
f"[Init] Resuming from Pass {current_pass_id - 1} (Dir: {prev_pass_dir})..."
335+
f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})..."
293336
)
294337

295-
prev_config = load_decompose_config(current_pass_id, base_output_dir)
296-
prev_map = prev_config.get("active_models_map", {})
297-
338+
prev_config = load_decompose_config(prev_pass_dir)
339+
prev_active_models_map = prev_config.get("active_models_map", {})
298340
prev_used_splits = prev_config.get("split_positions_map", {})
299341
prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
300342

@@ -306,67 +348,37 @@ def main(args):
306348
print(f"[FINISHED] Debugging completed.")
307349
sys.exit(0)
308350

309-
print(f"[Analysis] Refining splits based on failures...")
310-
311-
for sub_path in prev_incorrect_subgraphs:
312-
parts = sub_path.rstrip("/").split("/")
313-
if len(parts) < 2:
314-
continue
315-
316-
subgraph_dirname = parts[-1]
317-
model_name = parts[-2]
318-
319-
if model_name in prev_map:
320-
active_models_map_for_save[model_name] = prev_map[model_name]
321-
if model_name not in tasks_map:
322-
tasks_map[model_name] = {
323-
"original_path": prev_map[model_name],
324-
"split_positions": set(),
325-
}
326-
else:
327-
continue
328-
329-
try:
330-
sub_idx = int(subgraph_dirname.split("_")[-1])
331-
except ValueError:
332-
continue
333-
334-
# 1. Reconstruct previous subgraph size to locate the failing segment
335-
old_split_position = sorted(prev_used_splits.get(model_name, []))
336-
subgraph_size = reconstruct_subgraph_size(old_split_position)
337-
338-
if sub_idx >= len(subgraph_size):
339-
print(
340-
f"[WARN] Index {sub_idx} out of bounds for {model_name} (old split position: {old_split_position})"
341-
)
342-
continue
351+
print(f"[Analysis] Refining splits based on previous incorrect models ...")
343352

344-
# 2. Get the specific failing subgraph size [Start, End]
345-
fail_start, fail_end = subgraph_size[sub_idx]
353+
for subgraph_path in prev_incorrect_subgraphs:
354+
print(f"- subgraph_path: {subgraph_path}")
355+
model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
356+
model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
357+
subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
358+
print(f"- model_name: {model_name}, subgraph_idx: {subgraph_idx}")
346359

347-
# though intervals logic usually handles this via float('inf') replacement if used.
348-
if fail_end == float("inf"):
349-
fail_end = kMaxGraphSize
360+
assert model_name in prev_active_models_map
361+
active_models_map_for_save[model_name] = prev_active_models_map[model_name]
350362

351-
# Dynamic step calculation
352-
subgraph_size_len = fail_end - fail_start
353-
new_step = subgraph_size_len // 2
363+
# Reconstruct previous subgraph size to locate the failing segment
364+
prev_split_positions = sorted(prev_used_splits.get(model_name, []))
365+
subgraph_size = reconstruct_subgraph_size(prev_split_positions)
366+
assert subgraph_idx < len(
367+
subgraph_size
368+
), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
354369

355-
if new_step < 1:
356-
new_step = subgraph_size_len
357-
358-
# 3. Calculate Midpoint
359-
mid_point = fail_start + new_step
360-
361-
# 4. Add split positions
362-
if mid_point > fail_start and mid_point < fail_end:
363-
tasks_map[model_name]["split_positions"].update(
364-
[int(fail_start), int(mid_point), int(fail_end)]
365-
)
370+
split_postions = calculate_split_postions_for_subgraph(
371+
subgraph_size[subgraph_idx]
372+
)
373+
if model_name not in tasks_map:
374+
tasks_map[model_name] = {
375+
"subgraph_path": subgraph_path,
376+
"original_path": prev_active_models_map[model_name],
377+
"subgraph_size": subgraph_size[subgraph_idx],
378+
"split_positions": split_postions,
379+
}
366380
else:
367-
tasks_map[model_name]["split_positions"].update(
368-
[int(fail_start), int(fail_end)]
369-
)
381+
continue
370382

371383
# Recalculate based on current map to ensure log accuracy
372384
real_subgraph_size = calculate_current_subgraph_size(tasks_map, max_subgraph_size)
@@ -381,7 +393,7 @@ def main(args):
381393
sys.exit(0)
382394

383395
# --- Step 2: Prepare Workspace ---
384-
pass_work_dir = os.path.join(base_output_dir, f"pass_{current_pass_id}")
396+
pass_work_dir = get_decompose_workspace_path(base_output_dir, current_pass_id)
385397
if not os.path.exists(pass_work_dir):
386398
os.makedirs(pass_work_dir, exist_ok=True)
387399

@@ -401,6 +413,7 @@ def main(args):
401413
pass_work_dir, "samples" if args.framework == "torch" else "paddle_samples"
402414
)
403415
os.makedirs(decomposed_samples_dir, exist_ok=True)
416+
print(f"decomposed_samples_dir: {decomposed_samples_dir}")
404417

405418
for model_name, task_info in tasks_map.items():
406419
original_path = task_info["original_path"]
@@ -409,6 +422,8 @@ def main(args):
409422
final_used_splits_map[model_name] = split_positions
410423

411424
rectied_model_path = get_rectfied_model_path(original_path)
425+
print(f"original_path: {original_path}")
426+
print(f"rectied_model_path: {rectied_model_path}")
412427
assert os.path.exists(
413428
rectied_model_path
414429
), f"{rectied_model_path} does not exist."

0 commit comments

Comments
 (0)