
Commit 1eb203c

Merge branch 'develop' into opt_saved_results
2 parents 7d9581f + 410311b commit 1eb203c

17 files changed: +1149 -42 lines

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 79 additions & 23 deletions
@@ -27,6 +27,28 @@ def get_pass_name(pass_id):
     return f"pass_{pass_id}"
 
 
+def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
+    if not os.path.exists(log_path):
+        return set()
+
+    t_start = tolerance_args[0]
+    models_start = set(get_incorrect_models(t_start, log_path))
+
+    if len(tolerance_args) == 1:
+        return models_start
+
+    t_end = tolerance_args[1]
+    models_end = set(get_incorrect_models(t_end, log_path))
+
+    print(f"[Filter] Tolerance Range: {t_start} -> {t_end}")
+    print(
+        f"[Filter] Fail({t_start}): {len(models_start)}, Fail({t_end}): {len(models_end)}"
+    )
+
+    diff_set = models_start - models_end
+    return diff_set
+
+
 class TaskController:
     def __init__(self, args):
         self.root_output_dir = os.path.abspath(args.output_dir)
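
For context: the new helper treats the tolerance argument as a range and keeps only the models that fail at t_start but no longer fail at t_end. A minimal sketch of the set semantics, with a hypothetical stand-in for get_incorrect_models (illustration only, not the repository's implementation):

    # Hypothetical log contents keyed by tolerance level, for illustration only.
    def get_incorrect_models(tolerance, log_path):
        fake_log = {3: {"model_a", "model_b", "model_c"}, 5: {"model_c"}}
        return fake_log.get(tolerance, set())

    models_start = set(get_incorrect_models(3, "eval.log"))  # fail at t_start=3
    models_end = set(get_incorrect_models(5, "eval.log"))    # still fail at t_end=5
    print(sorted(models_start - models_end))                 # ['model_a', 'model_b']
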
@@ -198,10 +220,10 @@ def run_decomposer_for_multi_models(
     )
     for model_name, task_info in tasks_map.items():
         original_path = task_info["original_path"]
-        split_positions = calculate_split_positions_for_subgraph(
-            task_info["subgraph_size"], max_subgraph_size
-        )
-        task_info["split_positions"] = split_positions
+
+        split_positions = task_info["split_positions"]
+        if isinstance(split_positions, set):
+            split_positions = sorted(list(split_positions))
 
         rectified_model_path = get_rectfied_model_path(original_path)
         assert os.path.exists(
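
The decomposer loop now consumes split positions precomputed by the task generators rather than recomputing them per model; the isinstance check presumably normalizes the legacy set form (earlier passes stored "split_positions": set()) into a sorted list. A one-line sketch of that normalization:

    split_positions = {16, 0, 8}                      # legacy set form
    if isinstance(split_positions, set):
        split_positions = sorted(list(split_positions))
    print(split_positions)                            # [0, 8, 16]
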
@@ -269,28 +291,32 @@ def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
     start_pos, end_pos = subgraph_size
     end_pos = kMaxGraphSize if end_pos == float("inf") else end_pos
 
-    split_positions = list(range(start_pos, end_pos + 1, max_subgraph_size))
-    deduplicated_splits = list(dict.fromkeys(split_positions))
+    split_positions = set(range(start_pos, end_pos + 1, max_subgraph_size))
+    deduplicated_splits = list(sorted(split_positions))
     return deduplicated_splits
 
 
 def generate_initial_tasks(args):
     """Generates tasks for Pass 0 based on the initial log file."""
     print(f"[Init] Pass 0: Reading from log file: {args.log_file}")
-    initial_failures = get_incorrect_models(args.tolerance, args.log_file)
-    t1_incorrect_models = get_incorrect_models(1, args.log_file)
-    initial_failures = initial_failures - t1_incorrect_models
+    initial_failures = get_ranged_incorrect_models(args.tolerance, args.log_file)
 
     tasks_map = {}
+    max_subgraph_size = args.max_subgraph_size
+
     for model_path in initial_failures:
         model_name = get_model_name_with_subgraph_tag(model_path)
+
+        initial_range = [0, kMaxGraphSize]
+        initial_splits = calculate_split_positions_for_subgraph(
+            initial_range, max_subgraph_size
+        )
+
        tasks_map[model_name] = {
            "original_path": model_path,
-            "subgraph_size": [0, kMaxGraphSize],
-            "split_positions": set(),
+            "split_positions": list(sorted(initial_splits)),
        }
 
-    max_subgraph_size = args.max_subgraph_size
     running_states = {
         "pass_0": {
             "num_incorrect_models": len(initial_failures),
@@ -322,20 +348,26 @@ def generate_refined_tasks(base_output_dir, current_pass_id)
         assert model_name in prev_tasks_map
         pre_task_for_model = prev_tasks_map[model_name]
 
-        # Reconstruct previous subgraph size to locate the failing segment
         prev_split_positions = pre_task_for_model.get("split_positions", [])
-        subgraph_size = reconstruct_subgraph_size(prev_split_positions)
+        subgraph_ranges = reconstruct_subgraph_size(prev_split_positions)
+
         assert subgraph_idx < len(
-            subgraph_size
+            subgraph_ranges
        ), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
 
+        current_fail_range = subgraph_ranges[subgraph_idx]
+
+        new_splits = calculate_split_positions_for_subgraph(
+            current_fail_range, max_subgraph_size
+        )
+
        if model_name not in tasks_map:
            tasks_map[model_name] = {
                "original_path": pre_task_for_model["original_path"],
-                "subgraph_size": subgraph_size[subgraph_idx],
-                "split_positions": set(),
+                "split_positions": list(sorted(new_splits)),
            }
-
+        else:
+            tasks_map[model_name]["split_positions"] = list(sorted(new_splits))
    return tasks_map, max_subgraph_size, prev_config.running_states
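
generate_refined_tasks now derives the next pass's splits directly from the previous pass's split positions: reconstruct_subgraph_size (not shown in this diff) is assumed to turn consecutive split positions into candidate [start, end] ranges, and the range at the failing subgraph_idx is re-split at the finer granularity. A hypothetical reconstruction, for illustration only:

    # Assumed contract of reconstruct_subgraph_size: consecutive positions
    # become [start, end] ranges. Illustration only, not the repository's code.
    def reconstruct_subgraph_size(split_positions):
        positions = sorted(split_positions)
        return [[a, b] for a, b in zip(positions, positions[1:])]

    print(reconstruct_subgraph_size([0, 8, 16]))  # [[0, 8], [8, 16]]
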
@@ -399,11 +431,23 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
        need_decompose = True
        shutil.rmtree(decomposed_samples_dir)
        os.makedirs(decomposed_samples_dir, exist_ok=True)
+        max_subgraph_size = max(1, max_subgraph_size // 2)
        for model_name, task_info in tasks_map.items():
-            task_info["subgraph_size"][1] = (
-                task_info["subgraph_size"][0] + max_subgraph_size
+            splits = task_info["split_positions"]
+            if not splits or len(splits) < 2:
+                continue
+            if isinstance(splits, set):
+                splits = sorted(list(splits))
+            start_pos = splits[0]
+            first_segment_end = splits[1]
+            new_splits = list(
+                range(start_pos, first_segment_end + 1, max_subgraph_size)
            )
-            max_subgraph_size = max(1, max_subgraph_size // 2)
+
+            if new_splits[-1] != first_segment_end:
+                new_splits.append(first_segment_end)
+
+            task_info["split_positions"] = sorted(list(set(new_splits)))
    else:
        need_decompose = False
        print()
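
In effect, each decomposition retry now halves max_subgraph_size once per pass (rather than once per model, as the old loop did) and re-splits only the first segment of each task's previous split list. A worked sketch of one iteration, for a task whose previous splits were [0, 8, 16] with max_subgraph_size = 8:

    max_subgraph_size = max(1, 8 // 2)                    # -> 4
    splits = [0, 8, 16]
    start_pos, first_segment_end = splits[0], splits[1]   # 0, 8
    new_splits = list(range(0, 8 + 1, 4))                 # [0, 4, 8]
    # 8 is already the last element, so no append is needed
    print(sorted(set(new_splits)))                        # [0, 4, 8]
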
@@ -473,12 +517,20 @@ def main(args):
    next_round_models = set()
    if task_controller.task_scheduler["post_analysis"]:
        print("\n--- Phase 3: Analysis ---")
-        next_round_models = sorted(get_incorrect_models(args.tolerance, pass_log_path))
+        tolerance = (
+            args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance
+        )
+        next_round_models = sorted(get_incorrect_models(tolerance, pass_log_path))
        print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
        running_states[f"pass_{current_pass_id + 1}"] = {
            "num_incorrect_models": len(next_round_models),
            "incorrect_models": list(next_round_models),
        }
+
+        print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
+        for idx, model_path in enumerate(next_round_models):
+            print(f"- [{idx}] {model_path}")
+
    print_summary_and_suggestion(next_round_models, max_subgraph_size)
 
    # --- Step 5: Save States ---
@@ -500,7 +552,11 @@ def main(args):
        "--test-config", type=str, required=True, help="Base64 encoded test config"
    )
    parser.add_argument(
-        "--tolerance", type=int, required=True, help="Tolerance level range [-10, 5)"
+        "--tolerance",
+        type=int,
+        nargs="+",
+        required=True,
+        help="Tolerance level range [-10, 5)",
    )
    parser.add_argument("--max-subgraph-size", type=int, default=4096)
    args = parser.parse_args()
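
Because nargs="+" makes argparse collect --tolerance into a list, the flag now accepts either a single level or a start/end pair, which get_ranged_incorrect_models consumes as a range; the isinstance checks elsewhere in the file normalize the list back to a scalar where only one level is meaningful. Hypothetical invocations (other required arguments elided):

    # Single tolerance: same filtering as the old scalar flag.
    python graph_net/subgraph_decompose_and_evaluation_step.py ... --tolerance 3

    # Tolerance range: keep models that fail at 3 but no longer fail at 5.
    python graph_net/subgraph_decompose_and_evaluation_step.py ... --tolerance 3 5
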

graph_net/test/dimension_generalization_test.sh

File mode changed: 100644 → 100755
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+{
+    "framework": "torch",
+    "num_devices_required": 1,
+    "num_nodes_required": 1,
+    "dynamic": false,
+    "model_name": "error_model"
+}

graph_net/test/error_model/input_meta.py

Whitespace-only changes.

graph_net/test/error_model/input_tensor_constraints.py

Whitespace-only changes.
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+import torch
+
+from torch import device
+
+
+class GraphModule(torch.nn.Module):
+    def forward(
+        self,
+        add_22,
+        extended_attention_mask_2,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_,
+    ):
+        hidden_states_66 = torch.nn.functional.layer_norm(
+            add_22,
+            (32,),
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_,
+            1e-12,
+        )
+        add_22 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_ = (None)
+        linear_44 = torch.nn.functional.linear(
+            hidden_states_66,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_ = (None)
+        view_16 = linear_44.view(2, -1, 4, 8)
+        linear_44 = None
+        query_layer_4 = view_16.transpose(1, 2)
+        view_16 = None
+        linear_45 = torch.nn.functional.linear(
+            hidden_states_66,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_ = (None)
+        view_17 = linear_45.view(2, -1, 4, 8)
+        linear_45 = None
+        key_layer_4 = view_17.transpose(1, 2)
+        view_17 = None
+        linear_46 = torch.nn.functional.linear(
+            hidden_states_66,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_ = (None)
+        view_18 = linear_46.view(2, -1, 4, 8)
+        linear_46 = None
+        value_layer_4 = view_18.transpose(1, 2)
+        view_18 = None
+        transpose_25 = key_layer_4.transpose(-1, -2)
+        key_layer_4 = None
+        attention_scores_22 = torch.matmul(query_layer_4, transpose_25)
+        query_layer_4 = transpose_25 = None
+        attention_scores_23 = attention_scores_22 / 2.8284271247461903
+        attention_scores_22 = None
+        eps = torch.tensor(1e-8, device=attention_scores_23.device)
+        nan_val = eps / (eps - eps)
+        attention_scores_23 = attention_scores_23 + nan_val
+        nan_val = None
+        to_8 = extended_attention_mask_2.to(device(type="cuda", index=0))
+        extended_attention_mask_2 = None
+        attention_scores_24 = attention_scores_23 + to_8
+        attention_scores_23 = to_8 = None
+        _log_api_usage_once_4 = torch._C._log_api_usage_once("python.nn_module")
+        _log_api_usage_once_4 = None
+        attention_probs_14 = torch.nn.functional.softmax(
+            attention_scores_24, -1, _stacklevel=5
+        )
+        attention_scores_24 = None
+        attention_probs_dropped_4 = torch.nn.functional.dropout(
+            attention_probs_14, 0.0, False, False
+        )
+        attention_probs_14 = None
+        context_layer_22 = torch.matmul(attention_probs_dropped_4, value_layer_4)
+        attention_probs_dropped_4 = value_layer_4 = None
+        permute_14 = context_layer_22.permute(0, 2, 1, 3)
+        context_layer_22 = None
+        context_layer_23 = permute_14.contiguous()
+        permute_14 = None
+        context_layer_24 = context_layer_23.view(2, 14, 32)
+        context_layer_23 = None
+        hidden_states_67 = torch.nn.functional.linear(
+            context_layer_24,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_,
+        )
+        context_layer_24 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_ = (None)
+        hidden_states_68 = torch.nn.functional.dropout(
+            hidden_states_67, 0.0, False, False
+        )
+        hidden_states_67 = None
+        add_24 = hidden_states_68 + hidden_states_66
+        hidden_states_68 = hidden_states_66 = None
+        hidden_states_69 = torch.nn.functional.layer_norm(
+            add_24,
+            (32,),
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_,
+            1e-12,
+        )
+        add_24 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_ = (None)
+        hidden_states_70 = torch.nn.functional.linear(
+            hidden_states_69,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_ = (None)
+        hidden_states_71 = torch.nn.functional.gelu(hidden_states_70)
+        hidden_states_70 = None
+        hidden_states_72 = torch.nn.functional.linear(
+            hidden_states_71,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_,
+        )
+        hidden_states_71 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_ = (None)
+        hidden_states_73 = torch.nn.functional.dropout(
+            hidden_states_72, 0.0, False, False
+        )
+        hidden_states_72 = None
+        nan_val = torch.tensor(0.0, device=hidden_states_73.device) / torch.tensor(
+            0.0, device=hidden_states_73.device
+        )
+        hidden_states_73 = hidden_states_73 + nan_val
+        nan_val = None
+        add_25 = hidden_states_73 + hidden_states_69
+        hidden_states_73 = hidden_states_69 = None
+        hidden_states_74 = torch.nn.functional.layer_norm(
+            add_25,
+            (32,),
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_,
+            1e-12,
+        )
+        add_25 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_ = (None)
+        return (hidden_states_74,)
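
This generated module is evidently the error_model fixture: the nan_val lines inject non-finite values into otherwise ordinary attention/FFN math. A standalone sketch of the two injection idioms used above; note that under IEEE-754 semantics eps / (eps - eps) with a positive eps evaluates to +inf rather than NaN (despite the variable name), while 0.0 / 0.0 does produce NaN:

    import torch

    eps = torch.tensor(1e-8)
    inf_val = eps / (eps - eps)                      # 1e-8 / 0.0 -> tensor(inf)
    nan_val = torch.tensor(0.0) / torch.tensor(0.0)  # 0.0 / 0.0  -> tensor(nan)
    print(inf_val, nan_val)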
