Skip to content

Commit 81dde99

Browse files
committed
support tolerance range
1 parent 6d91582 commit 81dde99

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,30 @@ def convert_b64_string_to_json(b64str):
1818
return json.loads(base64.b64decode(b64str).decode("utf-8"))
1919

2020

21+
def get_filtered_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
22+
if not os.path.exists(log_path):
23+
return set()
24+
25+
t_start = tolerance_args[0]
26+
models_start = set(get_incorrect_models(t_start, log_path))
27+
28+
if len(tolerance_args) == 1:
29+
return models_start
30+
31+
t_end = tolerance_args[1]
32+
models_end = set(get_incorrect_models(t_end, log_path))
33+
34+
print(f"[Filter] Tolerance Range: {t_start} -> {t_end}")
35+
print(
36+
f"[Filter] Fail({t_start}): {len(models_start)}, Fail({t_end}): {len(models_end)}"
37+
)
38+
39+
diff_set = models_start - models_end
40+
print(f"[Filter] Result (Difference): {len(diff_set)}")
41+
42+
return diff_set
43+
44+
2145
class TaskController:
2246
def __init__(self, args):
2347
self.root_output_dir = os.path.abspath(args.output_dir)
@@ -290,7 +314,7 @@ def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
290314
def generate_initial_tasks(args):
291315
"""Generates tasks for Pass 0 based on the initial log file."""
292316
print(f"[Init] Pass 0: Reading from log file: {args.log_file}")
293-
initial_failures = get_incorrect_models(args.tolerance, args.log_file)
317+
initial_failures = get_filtered_incorrect_models(args.tolerance, args.log_file)
294318

295319
tasks_map = {}
296320
for model_path in initial_failures:
@@ -487,7 +511,7 @@ def main(args):
487511
next_round_models = set()
488512
if task_controller.task_scheduler["post_analysis"]:
489513
print("\n--- Phase 3: Analysis ---")
490-
next_round_models = get_incorrect_models(args.tolerance, pass_log_path)
514+
next_round_models = get_filtered_incorrect_models(args.tolerance, pass_log_path)
491515
print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
492516
if len(next_round_models) > 0:
493517
print("[DEBUG] List of detected incorrect models:")
@@ -516,7 +540,11 @@ def main(args):
516540
"--test-config", type=str, required=True, help="Base64 encoded test config"
517541
)
518542
parser.add_argument(
519-
"--tolerance", type=int, required=True, help="Tolerance level range [-10, 5)"
543+
"--tolerance",
544+
type=int,
545+
nargs="+",
546+
required=True,
547+
help="Tolerance level range [-10, 5)",
520548
)
521549
parser.add_argument("--max-subgraph-size", type=int, default=4096)
522550
args = parser.parse_args()

graph_net/test/subgraph_decompose_and_evaluation_step_test.sh

100644100755
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(
55
FRAMEWORK="torch"
66
LOG_FILE="$GRAPH_NET_ROOT/test/log_file_for_subgraph_decompose_and_evaluation_step.log"
77
OUTPUT_DIR="/tmp/decompose_and_evaluation_workspace"
8-
TOLERANCE=0
8+
TOLERANCE="0 2"
99
INITIAL_MAX_SIZE=2048
1010

1111
test_compiler_config_str=$(cat <<EOF
@@ -73,7 +73,7 @@ python3 -m graph_net.subgraph_decompose_and_evaluation_step \
7373
--output-dir="$OUTPUT_DIR" \
7474
--framework="${FRAMEWORK}" \
7575
--test-config="$TEST_CONFIG_B64" \
76-
--tolerance="$TOLERANCE" \
76+
--tolerance $TOLERANCE \
7777
--max-subgraph-size="$INITIAL_MAX_SIZE"
7878

7979
if [ $? -ne 0 ]; then

0 commit comments

Comments
 (0)