
Commit 410311b

[Feature Enhancement] Support multiple incorrect subgraphs. (#407)
* [Feature Enhancement] Support multiple incorrect subgraphs
* support tolerance range
* fix
1 parent 0eb02b6 commit 410311b

File tree

8 files changed: +887 -32 lines


graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 86 additions & 19 deletions
@@ -18,6 +18,29 @@ def convert_b64_string_to_json(b64str):
     return json.loads(base64.b64decode(b64str).decode("utf-8"))
 
 
+def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
+    if not os.path.exists(log_path):
+        return set()
+
+    t_start = tolerance_args[0]
+    models_start = set(get_incorrect_models(t_start, log_path))
+
+    if len(tolerance_args) == 1:
+        return models_start
+
+    t_end = tolerance_args[1]
+    models_end = set(get_incorrect_models(t_end, log_path))
+
+    print(f"[Filter] Tolerance Range: {t_start} -> {t_end}")
+    print(
+        f"[Filter] Fail({t_start}): {len(models_start)}, Fail({t_end}): {len(models_end)}"
+    )
+
+    diff_set = models_start - models_end
+
+    return diff_set
+
+
 class TaskController:
     def __init__(self, args):
         self.root_output_dir = os.path.abspath(args.output_dir)
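
Note: get_ranged_incorrect_models keeps the models that fail at tolerance_args[0] but no longer fail at tolerance_args[1], i.e. a plain set difference over the per-tolerance failure sets. A minimal sketch of that semantics, with hypothetical model names and tolerance values standing in for real log entries:

    # Hypothetical failure sets, as get_incorrect_models might return them
    # for two tolerance levels read from the same log.
    fail_at_start = {"model_a", "model_b", "model_c"}  # get_incorrect_models(-6, log)
    fail_at_end = {"model_a"}                          # get_incorrect_models(-3, log)

    # get_ranged_incorrect_models([-6, -3], log) returns the difference:
    ranged = fail_at_start - fail_at_end               # {"model_b", "model_c"}
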
@@ -203,10 +226,10 @@ def run_decomposer_for_multi_models(
     )
     for model_name, task_info in tasks_map.items():
         original_path = task_info["original_path"]
-        split_positions = calculate_split_positions_for_subgraph(
-            task_info["subgraph_size"], max_subgraph_size
-        )
-        task_info["split_positions"] = split_positions
+
+        split_positions = task_info["split_positions"]
+        if isinstance(split_positions, set):
+            split_positions = sorted(list(split_positions))
 
         rectified_model_path = get_rectfied_model_path(original_path)
         assert os.path.exists(
@@ -282,19 +305,28 @@ def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
 def generate_initial_tasks(args):
     """Generates tasks for Pass 0 based on the initial log file."""
     print(f"[Init] Pass 0: Reading from log file: {args.log_file}")
-    initial_failures = get_incorrect_models(args.tolerance, args.log_file)
+    initial_failures = get_ranged_incorrect_models(args.tolerance, args.log_file)
 
     tasks_map = {}
+    max_subgraph_size = args.max_subgraph_size
+
     for model_path in initial_failures:
         model_name = get_model_name_with_subgraph_tag(model_path)
+
+        initial_range = [0, kMaxGraphSize]
+        initial_splits = calculate_split_positions_for_subgraph(
+            initial_range, max_subgraph_size
+        )
+
         tasks_map[model_name] = {
             "subgraph_path": model_path,
             "original_path": model_path,
-            "subgraph_size": [0, kMaxGraphSize],
-            "split_positions": set(),
+            "split_positions": set(initial_splits),
         }
 
-    max_subgraph_size = args.max_subgraph_size
+    for task in tasks_map.values():
+        task["split_positions"] = sorted(list(task["split_positions"]))
+
     return tasks_map, max_subgraph_size
 
 
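calculate_split_positions_for_subgraph itself is unchanged by this diff and its body is not shown; judging from how its output is consumed, it plausibly chops a [start, end] range into boundaries spaced max_subgraph_size apart. A hedged sketch of that assumed behavior:

    def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
        # Assumed behavior, mirroring the stepping logic that
        # execute_decomposition_phase applies further down: emit a boundary
        # every max_subgraph_size ops and always include the range end.
        start, end = subgraph_size
        positions = list(range(start, end + 1, max_subgraph_size))
        if positions[-1] != end:
            positions.append(end)
        return positions

    # e.g. calculate_split_positions_for_subgraph([0, 10], 4) -> [0, 4, 8, 10]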

@@ -307,7 +339,6 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
     prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
     prev_tasks_map = prev_config.get("tasks_map", {})
 
-    # Load previous max size as fallback
     prev_max_subgraph_size = prev_config.get("max_subgraph_size")
     max_subgraph_size = prev_max_subgraph_size // 2
 
@@ -324,20 +355,30 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
         assert model_name in prev_tasks_map
         pre_task_for_model = prev_tasks_map[model_name]
 
-        # Reconstruct previous subgraph size to locate the failing segment
         prev_split_positions = pre_task_for_model.get("split_positions", [])
-        subgraph_size = reconstruct_subgraph_size(prev_split_positions)
+        subgraph_ranges = reconstruct_subgraph_size(prev_split_positions)
+
         assert subgraph_idx < len(
-            subgraph_size
+            subgraph_ranges
        ), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
 
+        current_fail_range = subgraph_ranges[subgraph_idx]
+
+        new_splits = calculate_split_positions_for_subgraph(
+            current_fail_range, max_subgraph_size
+        )
+
         if model_name not in tasks_map:
             tasks_map[model_name] = {
                 "subgraph_path": subgraph_path,
                 "original_path": pre_task_for_model["original_path"],
-                "subgraph_size": subgraph_size[subgraph_idx],
-                "split_positions": set(),
+                "split_positions": set(new_splits),
             }
+        else:
+            tasks_map[model_name]["split_positions"].update(new_splits)
+
+    for task in tasks_map.values():
+        task["split_positions"] = sorted(list(task["split_positions"]))
 
     return tasks_map, max_subgraph_size
 
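reconstruct_subgraph_size is also defined outside this hunk; from its use here, consecutive split positions appear to delimit one subgraph range each. A hedged sketch of that assumed mapping:

    def reconstruct_subgraph_size(split_positions):
        # Assumed behavior: each adjacent pair of sorted split positions
        # bounds one previously evaluated subgraph.
        positions = sorted(split_positions)
        return [[positions[i], positions[i + 1]] for i in range(len(positions) - 1)]

    # e.g. reconstruct_subgraph_size([0, 4, 8, 10]) -> [[0, 4], [4, 8], [8, 10]],
    # so subgraph_idx == 1 selects the failing range [4, 8] for re-splitting.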

@@ -402,11 +443,23 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
         need_decompose = True
         shutil.rmtree(decomposed_samples_dir)
         os.makedirs(decomposed_samples_dir, exist_ok=True)
+        max_subgraph_size = max(1, max_subgraph_size // 2)
         for model_name, task_info in tasks_map.items():
-            task_info["subgraph_size"][1] = (
-                task_info["subgraph_size"][0] + max_subgraph_size
+            splits = task_info["split_positions"]
+            if not splits or len(splits) < 2:
+                continue
+            if isinstance(splits, set):
+                splits = sorted(list(splits))
+            start_pos = splits[0]
+            first_segment_end = splits[1]
+            new_splits = list(
+                range(start_pos, first_segment_end + 1, max_subgraph_size)
             )
-            max_subgraph_size = max(1, max_subgraph_size // 2)
+
+            if new_splits[-1] != first_segment_end:
+                new_splits.append(first_segment_end)
+
+            task_info["split_positions"] = sorted(list(set(new_splits)))
     else:
         need_decompose = False
     print()
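
On this retry path the loop now halves max_subgraph_size once up front and re-splits only the first segment of each task's previous split list (positions past splits[1] are dropped). A worked example of the values, using illustrative numbers:

    splits = [0, 1024, 2048]          # previous split_positions
    max_subgraph_size = 512           # after max(1, 1024 // 2)
    start_pos, first_segment_end = splits[0], splits[1]

    new_splits = list(range(0, 1024 + 1, 512))   # [0, 512, 1024]
    # 1024 is already the last element, so nothing is appended and
    # task_info["split_positions"] becomes [0, 512, 1024].
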
@@ -474,8 +527,18 @@ def main(args):
     next_round_models = set()
     if task_controller.task_scheduler["post_analysis"]:
         print("\n--- Phase 3: Analysis ---")
-        next_round_models = get_incorrect_models(args.tolerance, pass_log_path)
+        analysis_tolerance = (
+            args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance
+        )
+        next_round_models = get_incorrect_models(analysis_tolerance, pass_log_path)
+
         print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
+        if len(next_round_models) > 0:
+            print("[DEBUG] List of detected incorrect models:")
+            for idx, model_path in enumerate(sorted(list(next_round_models))):
+                print(f" [{idx}] {model_path}")
+        else:
+            print("[DEBUG] No incorrect models detected.")
         print_summary_and_suggestion(next_round_models, max_subgraph_size)
 
     # --- Step 5: Save States ---
@@ -497,7 +560,11 @@ def main(args):
         "--test-config", type=str, required=True, help="Base64 encoded test config"
     )
     parser.add_argument(
-        "--tolerance", type=int, required=True, help="Tolerance level range [-10, 5)"
+        "--tolerance",
+        type=int,
+        nargs="+",
+        required=True,
+        help="Tolerance level range [-10, 5)",
     )
     parser.add_argument("--max-subgraph-size", type=int, default=4096)
     args = parser.parse_args()
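
With nargs="+", --tolerance now accepts either a single level or a start/end pair. A small parsing sketch (the tolerance values are illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tolerance", type=int, nargs="+", required=True)

    # Single tolerance, as before:
    print(parser.parse_args(["--tolerance", "-6"]).tolerance)        # [-6]
    # Range: keep models failing at -6 but passing at -3:
    print(parser.parse_args(["--tolerance", "-6", "-3"]).tolerance)  # [-6, -3]

Because nargs="+" yields a list even for a single value, main() unwraps args.tolerance[0] before calling get_incorrect_models in the analysis phase.
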
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+{
+    "framework": "torch",
+    "num_devices_required": 1,
+    "num_nodes_required": 1,
+    "dynamic": false,
+    "model_name": "error_model"
+}

graph_net/test/error_model/input_meta.py

Whitespace-only changes.

graph_net/test/error_model/input_tensor_constraints.py

Whitespace-only changes.
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+import torch
+
+from torch import device
+
+
+class GraphModule(torch.nn.Module):
+    def forward(
+        self,
+        add_22,
+        extended_attention_mask_2,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_,
+    ):
+        hidden_states_66 = torch.nn.functional.layer_norm(
+            add_22,
+            (32,),
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_,
+            1e-12,
+        )
+        add_22 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_ = (None)
+        linear_44 = torch.nn.functional.linear(
+            hidden_states_66,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_ = (None)
+        view_16 = linear_44.view(2, -1, 4, 8)
+        linear_44 = None
+        query_layer_4 = view_16.transpose(1, 2)
+        view_16 = None
+        linear_45 = torch.nn.functional.linear(
+            hidden_states_66,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_ = (None)
+        view_17 = linear_45.view(2, -1, 4, 8)
+        linear_45 = None
+        key_layer_4 = view_17.transpose(1, 2)
+        view_17 = None
+        linear_46 = torch.nn.functional.linear(
+            hidden_states_66,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_ = (None)
+        view_18 = linear_46.view(2, -1, 4, 8)
+        linear_46 = None
+        value_layer_4 = view_18.transpose(1, 2)
+        view_18 = None
+        transpose_25 = key_layer_4.transpose(-1, -2)
+        key_layer_4 = None
+        attention_scores_22 = torch.matmul(query_layer_4, transpose_25)
+        query_layer_4 = transpose_25 = None
+        attention_scores_23 = attention_scores_22 / 2.8284271247461903
+        attention_scores_22 = None
+        eps = torch.tensor(1e-8, device=attention_scores_23.device)
+        nan_val = eps / (eps - eps)
+        attention_scores_23 = attention_scores_23 + nan_val
+        nan_val = None
+        to_8 = extended_attention_mask_2.to(device(type="cuda", index=0))
+        extended_attention_mask_2 = None
+        attention_scores_24 = attention_scores_23 + to_8
+        attention_scores_23 = to_8 = None
+        _log_api_usage_once_4 = torch._C._log_api_usage_once("python.nn_module")
+        _log_api_usage_once_4 = None
+        attention_probs_14 = torch.nn.functional.softmax(
+            attention_scores_24, -1, _stacklevel=5
+        )
+        attention_scores_24 = None
+        attention_probs_dropped_4 = torch.nn.functional.dropout(
+            attention_probs_14, 0.0, False, False
+        )
+        attention_probs_14 = None
+        context_layer_22 = torch.matmul(attention_probs_dropped_4, value_layer_4)
+        attention_probs_dropped_4 = value_layer_4 = None
+        permute_14 = context_layer_22.permute(0, 2, 1, 3)
+        context_layer_22 = None
+        context_layer_23 = permute_14.contiguous()
+        permute_14 = None
+        context_layer_24 = context_layer_23.view(2, 14, 32)
+        context_layer_23 = None
+        hidden_states_67 = torch.nn.functional.linear(
+            context_layer_24,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_,
+        )
+        context_layer_24 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_ = (None)
+        hidden_states_68 = torch.nn.functional.dropout(
+            hidden_states_67, 0.0, False, False
+        )
+        hidden_states_67 = None
+        add_24 = hidden_states_68 + hidden_states_66
+        hidden_states_68 = hidden_states_66 = None
+        hidden_states_69 = torch.nn.functional.layer_norm(
+            add_24,
+            (32,),
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_,
+            1e-12,
+        )
+        add_24 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_ = (None)
+        hidden_states_70 = torch.nn.functional.linear(
+            hidden_states_69,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_ = (None)
+        hidden_states_71 = torch.nn.functional.gelu(hidden_states_70)
+        hidden_states_70 = None
+        hidden_states_72 = torch.nn.functional.linear(
+            hidden_states_71,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_,
+        )
+        hidden_states_71 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_ = (None)
+        hidden_states_73 = torch.nn.functional.dropout(
+            hidden_states_72, 0.0, False, False
+        )
+        hidden_states_72 = None
+        nan_val = torch.tensor(0.0, device=hidden_states_73.device) / torch.tensor(
+            0.0, device=hidden_states_73.device
+        )
+        hidden_states_73 = hidden_states_73 + nan_val
+        nan_val = None
+        add_25 = hidden_states_73 + hidden_states_69
+        hidden_states_73 = hidden_states_69 = None
+        hidden_states_74 = torch.nn.functional.layer_norm(
+            add_25,
+            (32,),
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_,
+            1e-12,
+        )
+        add_25 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_ = (None)
+        return (hidden_states_74,)
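
The new error_model exists to give the pipeline a known-bad subgraph: both injected nan_val expressions produce non-finite values. Note that eps / (eps - eps) is a positive number divided by zero, which yields inf rather than nan despite the variable name, while 0.0 / 0.0 does yield nan; either way the outputs stop matching the reference. A quick check of both expressions:

    import torch

    eps = torch.tensor(1e-8)
    print(eps / (eps - eps))                        # tensor(inf): positive / 0.0
    print(torch.tensor(0.0) / torch.tensor(0.0))    # tensor(nan): 0.0 / 0.0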
