Skip to content

Commit 99cc4d7

Browse files
committed
Optimize the definition of decompose config and fix the config saving of test_target_device.
1 parent bd84a36 commit 99cc4d7

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,16 @@ def save_decompose_config(
123123
):
124124
"""Saves the current state to a JSON file."""
125125

126-
active_models_map = {}
127-
split_positions_map = {}
126+
tasks_map_copy = {}
128127
for model_name, task_info in tasks_map.items():
129-
active_models_map[model_name] = task_info["original_path"]
130-
split_positions_map[model_name] = task_info["split_positions"]
128+
tasks_map_copy[model_name] = {}
129+
for key in ["original_path", "split_positions"]:
130+
tasks_map_copy[model_name][key] = task_info[key]
131131

132132
config = {
133133
"max_subgraph_size": max_subgraph_size,
134134
"incorrect_models": list(incorrect_paths),
135-
"active_models_map": active_models_map,
136-
"split_positions_map": split_positions_map,
135+
"tasks_map": tasks_map_copy,
137136
"failed_decomposition_models": list(failed_decomposition_models),
138137
}
139138
config_path = get_decompose_config_path(work_dir)
@@ -283,9 +282,8 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
283282
print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
284283

285284
prev_config = load_decompose_config(prev_pass_dir)
286-
prev_active_models_map = prev_config.get("active_models_map", {})
287-
prev_split_positions_map = prev_config.get("split_positions_map", {})
288285
prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
286+
prev_tasks_map = prev_config.get("tasks_map", {})
289287

290288
# Load previous max size as fallback
291289
prev_max_subgraph_size = prev_config.get("max_subgraph_size")
@@ -302,12 +300,12 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
302300
model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
303301
model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
304302
subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
305-
print(f"subgraph_path: {subgraph_path}")
306-
print(f"model_name: {model_name}, subgraph_idx: {subgraph_idx}")
307-
assert model_name in prev_active_models_map
303+
304+
assert model_name in prev_tasks_map
305+
pre_task_for_model = prev_tasks_map[model_name]
308306

309307
# Reconstruct previous subgraph size to locate the failing segment
310-
prev_split_positions = prev_split_positions_map.get(model_name, [])
308+
prev_split_positions = pre_task_for_model.get("split_positions", [])
311309
subgraph_size = reconstruct_subgraph_size(prev_split_positions)
312310
assert subgraph_idx < len(
313311
subgraph_size
@@ -316,7 +314,7 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
316314
if model_name not in tasks_map:
317315
tasks_map[model_name] = {
318316
"subgraph_path": subgraph_path,
319-
"original_path": prev_active_models_map[model_name],
317+
"original_path": pre_task_for_model["original_path"],
320318
"subgraph_size": subgraph_size[subgraph_idx],
321319
"split_positions": set(),
322320
}
@@ -447,6 +445,11 @@ def main(args):
447445
) = execute_decomposition_phase(
448446
max_subgraph_size, tasks_map, args.framework, pass_work_dir
449447
)
448+
else:
449+
config = load_decompose_config(pass_work_dir)
450+
max_subgraph_size = config["max_subgraph_size"]
451+
failed_decomposition = config["failed_decomposition_models"]
452+
tasks_map = config.get("tasks_map", {})
450453

451454
# --- Step 4: Testing ---
452455
pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")

0 commit comments

Comments
 (0)