Skip to content

Commit 6d91582

Browse files
committed
[Feature Enhancement] Support multiple incorrect subgraphs
1 parent 0aa1827 commit 6d91582

File tree

7 files changed

+830
-23
lines changed

7 files changed

+830
-23
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,18 @@ def run_decomposer_for_multi_models(
201201
print(
202202
f"[Decomposition] max_subgraph_size: {max_subgraph_size}, log_path: {log_path}"
203203
)
204+
204205
for model_name, task_info in tasks_map.items():
205206
original_path = task_info["original_path"]
206-
split_positions = calculate_split_positions_for_subgraph(
207-
task_info["subgraph_size"], max_subgraph_size
208-
)
207+
split_positions = []
208+
ranges = task_info["subgraph_sizes"]
209+
210+
for rng in ranges:
211+
splits = calculate_split_positions_for_subgraph(rng, max_subgraph_size)
212+
split_positions.extend(splits)
213+
214+
# Deduplicate and sort
215+
split_positions = sorted(list(set(split_positions)))
209216
task_info["split_positions"] = split_positions
210217

211218
rectified_model_path = get_rectfied_model_path(original_path)
@@ -222,6 +229,7 @@ def run_decomposer_for_multi_models(
222229
)
223230
if not success:
224231
failed_decomposition.append(rectified_model_path)
232+
225233
return tasks_map, failed_decomposition
226234

227235

@@ -290,7 +298,7 @@ def generate_initial_tasks(args):
290298
tasks_map[model_name] = {
291299
"subgraph_path": model_path,
292300
"original_path": model_path,
293-
"subgraph_size": [0, kMaxGraphSize],
301+
"subgraph_sizes": [[0, kMaxGraphSize]],
294302
"split_positions": set(),
295303
}
296304

@@ -307,7 +315,6 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
307315
prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
308316
prev_tasks_map = prev_config.get("tasks_map", {})
309317

310-
# Load previous max size as fallback
311318
prev_max_subgraph_size = prev_config.get("max_subgraph_size")
312319
max_subgraph_size = prev_max_subgraph_size // 2
313320

@@ -324,20 +331,24 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
324331
assert model_name in prev_tasks_map
325332
pre_task_for_model = prev_tasks_map[model_name]
326333

327-
# Reconstruct previous subgraph size to locate the failing segment
328334
prev_split_positions = pre_task_for_model.get("split_positions", [])
329-
subgraph_size = reconstruct_subgraph_size(prev_split_positions)
335+
subgraph_sizes = reconstruct_subgraph_size(prev_split_positions)
336+
330337
assert subgraph_idx < len(
331-
subgraph_size
338+
subgraph_sizes
332339
), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
333340

341+
current_fail_range = subgraph_sizes[subgraph_idx]
342+
334343
if model_name not in tasks_map:
335344
tasks_map[model_name] = {
336345
"subgraph_path": subgraph_path,
337346
"original_path": pre_task_for_model["original_path"],
338-
"subgraph_size": subgraph_size[subgraph_idx],
347+
"subgraph_sizes": [current_fail_range],
339348
"split_positions": set(),
340349
}
350+
else:
351+
tasks_map[model_name]["subgraph_sizes"].append(current_fail_range)
341352

342353
return tasks_map, max_subgraph_size
343354

@@ -403,9 +414,11 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
403414
shutil.rmtree(decomposed_samples_dir)
404415
os.makedirs(decomposed_samples_dir, exist_ok=True)
405416
for model_name, task_info in tasks_map.items():
406-
task_info["subgraph_size"][1] = (
407-
task_info["subgraph_size"][0] + max_subgraph_size
408-
)
417+
for i in range(len(task_info["subgraph_sizes"])):
418+
# Attempt to expand the end position for retry
419+
task_info["subgraph_sizes"][i][1] = (
420+
task_info["subgraph_sizes"][i][0] + max_subgraph_size
421+
)
409422
max_subgraph_size = max(1, max_subgraph_size // 2)
410423
else:
411424
need_decompose = False
@@ -476,6 +489,12 @@ def main(args):
476489
print("\n--- Phase 3: Analysis ---")
477490
next_round_models = get_incorrect_models(args.tolerance, pass_log_path)
478491
print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
492+
if len(next_round_models) > 0:
493+
print("[DEBUG] List of detected incorrect models:")
494+
for idx, model_path in enumerate(sorted(list(next_round_models))):
495+
print(f" [{idx}] {model_path}")
496+
else:
497+
print("[DEBUG] No incorrect models detected.")
479498
print_summary_and_suggestion(next_round_models, max_subgraph_size)
480499

481500
# --- Step 5: Save States ---
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"framework": "torch",
3+
"num_devices_required": 1,
4+
"num_nodes_required": 1,
5+
"dynamic": false,
6+
"model_name": "error_model"
7+
}

graph_net/test/error_model/input_meta.py

Whitespace-only changes.

graph_net/test/error_model/input_tensor_constraints.py

Whitespace-only changes.
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import torch
2+
3+
from torch import device
4+
5+
6+
class GraphModule(torch.nn.Module):
    """Machine-generated (torch.fx-style) subgraph used as an intentionally
    incorrect "error_model" test fixture.

    The graph covers the tail of a BERT-like text encoder: the output
    LayerNorm of encoder layer 3, followed by the full self-attention and
    feed-forward stack of encoder layer 4.  Two non-finite-value injections
    (divisions by zero) are deliberately spliced into the graph so that
    downstream accuracy checks flag this subgraph as numerically incorrect.

    NOTE(review): generated code — statement order and the ``x = None``
    reference-freeing assignments are part of the fixture and must not be
    "cleaned up".  Shapes are hard-coded by the generator: batch=2, seq=14,
    hidden=32, heads=4, head_dim=8 (see the ``view`` calls below).
    """

    def forward(
        self,
        add_22,
        extended_attention_mask_2,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_,
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_,
    ):
        # Layer 3 output LayerNorm, normalized over the last dim (size 32).
        hidden_states_66 = torch.nn.functional.layer_norm(
            add_22,
            (32,),
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_,
            1e-12,
        )
        # Generated-code idiom: drop references as soon as they are dead.
        add_22 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_ = (None)
        # --- Layer 4 self-attention: query projection -> (2, heads=4, seq, 8).
        linear_44 = torch.nn.functional.linear(
            hidden_states_66,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_,
        )
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_ = (None)
        view_16 = linear_44.view(2, -1, 4, 8)
        linear_44 = None
        query_layer_4 = view_16.transpose(1, 2)
        view_16 = None
        # Key projection, reshaped/transposed the same way.
        linear_45 = torch.nn.functional.linear(
            hidden_states_66,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_,
        )
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_ = (None)
        view_17 = linear_45.view(2, -1, 4, 8)
        linear_45 = None
        key_layer_4 = view_17.transpose(1, 2)
        view_17 = None
        # Value projection.
        linear_46 = torch.nn.functional.linear(
            hidden_states_66,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_,
        )
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_ = (None)
        view_18 = linear_46.view(2, -1, 4, 8)
        linear_46 = None
        value_layer_4 = view_18.transpose(1, 2)
        view_18 = None
        # Scaled dot-product attention scores: Q @ K^T / sqrt(head_dim=8).
        transpose_25 = key_layer_4.transpose(-1, -2)
        key_layer_4 = None
        attention_scores_22 = torch.matmul(query_layer_4, transpose_25)
        query_layer_4 = transpose_25 = None
        attention_scores_23 = attention_scores_22 / 2.8284271247461903
        attention_scores_22 = None
        # INTENTIONAL FAULT #1: eps / (eps - eps) is a divide-by-zero
        # (1e-8 / 0.0 -> +inf under IEEE float semantics, despite the
        # variable name); adding it poisons every attention score.
        eps = torch.tensor(1e-8, device=attention_scores_23.device)
        nan_val = eps / (eps - eps)
        attention_scores_23 = attention_scores_23 + nan_val
        nan_val = None
        # NOTE(review): hard-coded transfer to cuda:0 — this graph cannot run
        # on a CPU-only host; presumably the evaluation harness guarantees a
        # GPU (see graph_net_json: num_devices_required = 1).
        to_8 = extended_attention_mask_2.to(device(type="cuda", index=0))
        extended_attention_mask_2 = None
        attention_scores_24 = attention_scores_23 + to_8
        attention_scores_23 = to_8 = None
        # Internal PyTorch telemetry hook captured by the tracer; no effect on
        # the computation.
        _log_api_usage_once_4 = torch._C._log_api_usage_once("python.nn_module")
        _log_api_usage_once_4 = None
        attention_probs_14 = torch.nn.functional.softmax(
            attention_scores_24, -1, _stacklevel=5
        )
        attention_scores_24 = None
        # Dropout with p=0.0 and training=False: an identity op kept so the
        # graph matches the traced module structure.
        attention_probs_dropped_4 = torch.nn.functional.dropout(
            attention_probs_14, 0.0, False, False
        )
        attention_probs_14 = None
        # Attention-weighted values, then merge heads back to (2, 14, 32).
        context_layer_22 = torch.matmul(attention_probs_dropped_4, value_layer_4)
        attention_probs_dropped_4 = value_layer_4 = None
        permute_14 = context_layer_22.permute(0, 2, 1, 3)
        context_layer_22 = None
        context_layer_23 = permute_14.contiguous()
        permute_14 = None
        context_layer_24 = context_layer_23.view(2, 14, 32)
        context_layer_23 = None
        # Attention output projection + residual + LayerNorm.
        hidden_states_67 = torch.nn.functional.linear(
            context_layer_24,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_,
        )
        context_layer_24 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_ = (None)
        hidden_states_68 = torch.nn.functional.dropout(
            hidden_states_67, 0.0, False, False
        )
        hidden_states_67 = None
        add_24 = hidden_states_68 + hidden_states_66
        hidden_states_68 = hidden_states_66 = None
        hidden_states_69 = torch.nn.functional.layer_norm(
            add_24,
            (32,),
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_,
            1e-12,
        )
        add_24 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_ = (None)
        # Feed-forward: intermediate dense -> GELU -> output dense.
        hidden_states_70 = torch.nn.functional.linear(
            hidden_states_69,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_,
        )
        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_ = (None)
        hidden_states_71 = torch.nn.functional.gelu(hidden_states_70)
        hidden_states_70 = None
        hidden_states_72 = torch.nn.functional.linear(
            hidden_states_71,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_,
        )
        hidden_states_71 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_ = (None)
        hidden_states_73 = torch.nn.functional.dropout(
            hidden_states_72, 0.0, False, False
        )
        hidden_states_72 = None
        # INTENTIONAL FAULT #2: 0.0 / 0.0 is NaN; adding it guarantees the
        # final output is numerically incorrect regardless of inputs.
        nan_val = torch.tensor(0.0, device=hidden_states_73.device) / torch.tensor(
            0.0, device=hidden_states_73.device
        )
        hidden_states_73 = hidden_states_73 + nan_val
        nan_val = None
        # Residual + final LayerNorm of encoder layer 4.
        add_25 = hidden_states_73 + hidden_states_69
        hidden_states_73 = hidden_states_69 = None
        hidden_states_74 = torch.nn.functional.layer_norm(
            add_25,
            (32,),
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_,
            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_,
            1e-12,
        )
        add_25 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_ = (None)
        # fx convention: outputs are returned as a tuple.
        return (hidden_states_74,)

0 commit comments

Comments
 (0)