Skip to content

Commit a0df9d6

Browse files
committed
Optimize the definition of decompose config and fix the config saving of test_target_device.
1 parent bd84a36 commit a0df9d6

File tree

1 file changed

+20
-19
lines changed

1 file changed

+20
-19
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,16 @@ def save_decompose_config(
123123
):
124124
"""Saves the current state to a JSON file."""
125125

126-
active_models_map = {}
127-
split_positions_map = {}
126+
tasks_map_copy = {}
128127
for model_name, task_info in tasks_map.items():
129-
active_models_map[model_name] = task_info["original_path"]
130-
split_positions_map[model_name] = task_info["split_positions"]
128+
tasks_map_copy[model_name] = {}
129+
for key in ["original_path", "split_positions"]:
130+
tasks_map_copy[model_name][key] = task_info[key]
131131

132132
config = {
133133
"max_subgraph_size": max_subgraph_size,
134134
"incorrect_models": list(incorrect_paths),
135-
"active_models_map": active_models_map,
136-
"split_positions_map": split_positions_map,
135+
"tasks_map": tasks_map_copy,
137136
"failed_decomposition_models": list(failed_decomposition_models),
138137
}
139138
config_path = get_decompose_config_path(work_dir)
@@ -283,9 +282,8 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
283282
print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
284283

285284
prev_config = load_decompose_config(prev_pass_dir)
286-
prev_active_models_map = prev_config.get("active_models_map", {})
287-
prev_split_positions_map = prev_config.get("split_positions_map", {})
288285
prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
286+
prev_tasks_map = prev_config.get("tasks_map", {})
289287

290288
# Load previous max size as fallback
291289
prev_max_subgraph_size = prev_config.get("max_subgraph_size")
@@ -294,20 +292,18 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
294292
if not prev_incorrect_subgraphs:
295293
return {}, max_subgraph_size
296294

297-
print("[Analysis] Refining splits based on previous incorrect models ...")
298-
299295
tasks_map = {}
300296
for subgraph_path in prev_incorrect_subgraphs:
301297
# Parse model name and subgraph index
302298
model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
303299
model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
304300
subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
305-
print(f"subgraph_path: {subgraph_path}")
306-
print(f"model_name: {model_name}, subgraph_idx: {subgraph_idx}")
307-
assert model_name in prev_active_models_map
301+
302+
assert model_name in prev_tasks_map
303+
pre_task_for_model = prev_tasks_map[model_name]
308304

309305
# Reconstruct previous subgraph size to locate the failing segment
310-
prev_split_positions = prev_split_positions_map.get(model_name, [])
306+
prev_split_positions = pre_task_for_model.get("split_positions", [])
311307
subgraph_size = reconstruct_subgraph_size(prev_split_positions)
312308
assert subgraph_idx < len(
313309
subgraph_size
@@ -316,7 +312,7 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
316312
if model_name not in tasks_map:
317313
tasks_map[model_name] = {
318314
"subgraph_path": subgraph_path,
319-
"original_path": prev_active_models_map[model_name],
315+
"original_path": pre_task_for_model["original_path"],
320316
"subgraph_size": subgraph_size[subgraph_idx],
321317
"split_positions": set(),
322318
}
@@ -338,7 +334,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_wo
338334
)
339335
if not os.path.exists(decomposed_samples_dir):
340336
os.makedirs(decomposed_samples_dir, exist_ok=True)
341-
print(f"decomposed_samples_dir: {decomposed_samples_dir}")
337+
print(f"- decomposed_samples_dir: {decomposed_samples_dir}")
342338

343339
for model_name, task_info in tasks_map.items():
344340
print(f"[Decomposition] max_subgraph_size: {max_subgraph_size}")
@@ -387,8 +383,8 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_wo
387383
return tasks_map, failed_decomposition, max_subgraph_size
388384

389385

390-
def print_final_summary(next_round_models, max_subgraph_size):
391-
"""Prints the final suggestion/result."""
386+
def print_summary_and_suggestion(next_round_models, max_subgraph_size):
387+
"""Print suggestion/result."""
392388
print("\n" + "=" * 80)
393389
if next_round_models and max_subgraph_size > 1:
394390
print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
@@ -447,6 +443,11 @@ def main(args):
447443
) = execute_decomposition_phase(
448444
max_subgraph_size, tasks_map, args.framework, pass_work_dir
449445
)
446+
else:
447+
config = load_decompose_config(pass_work_dir)
448+
max_subgraph_size = config["max_subgraph_size"]
449+
failed_decomposition = config["failed_decomposition_models"]
450+
tasks_map = config.get("tasks_map", {})
450451

451452
# --- Step 4: Testing ---
452453
pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
@@ -470,7 +471,7 @@ def main(args):
470471
failed_decomposition,
471472
)
472473

473-
print_final_summary(next_round_models, max_subgraph_size)
474+
print_summary_and_suggestion(next_round_models, max_subgraph_size)
474475

475476

476477
if __name__ == "__main__":

0 commit comments

Comments
 (0)