Skip to content

Commit ef7d4b6

Browse files
committed
return set[str]
1 parent 8c8070b commit ef7d4b6

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

graph_net/analysis_util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ def check_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]:
620620
return is_correct, None if is_correct else "accuracy"
621621

622622

623-
def get_incorrect_models(tolerance, log_file_path) -> list:
623+
def get_incorrect_models(tolerance, log_file_path) -> set[str]:
624624
"""
625625
Filters and returns models with accuracy issues based on given tolerance threshold.
626626
@@ -632,12 +632,12 @@ def get_incorrect_models(tolerance, log_file_path) -> list:
632632
log_file_path (str): Path to the log file containing model test results
633633
634634
Returns:
635-
failed_models[str]: names of models failing accuracy check, empty list if none found
635+
failed_models(str): names of models failing accuracy check, empty set if none found
636636
"""
637-
failed_models = []
637+
failed_models = set()
638638
datalist = parse_logs_to_data(log_file_path)
639639
for i in datalist:
640640
iscorrect, err = check_sample_correctness(i, tolerance)
641641
if not iscorrect:
642-
failed_models.append(i.get("model_path"))
642+
failed_models.add(i.get("model_path"))
643643
return failed_models

graph_net/test/get_incorrect_models_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")
99
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
1010

1111
TOLERANCE_LIST=(-2 -1 0 1 2)
12-
LOG_FILE_PATH="/work/.BCloud/log_20251013_175058_torch_inductor_full.log"
12+
LOG_FILE_PATH="your/log/file/path"
1313

1414
python3 - <<END
1515
from graph_net import analysis_util
1616
17-
result = list(analysis_util.get_incorrect_models($TOLERANCE_LIST, '$LOG_FILE_PATH'))
17+
result = analysis_util.get_incorrect_models($TOLERANCE_LIST, '$LOG_FILE_PATH')
1818
1919
for item in result:
2020
print(item)

0 commit comments

Comments
 (0)