Skip to content

Commit 8abea5d

Browse files
authored
add get_incorrect_models (#375)
* 1119 * 1120 * 1120.2 * model_path * remove unnecessary files and pre-committed * remove unnecessary files and pre-committed * 1121 remove unnecessary files * modify rev version * modify rev version * modify rev version * accuracy issues targeted * test script and modify feature * return set[str] * add logfile for test
1 parent 4e84a5c commit 8abea5d

File tree

5 files changed

+1337
-1
lines changed

5 files changed

+1337
-1
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ repos:
33
rev: 23.1.0
44
hooks:
55
- id: black
6-
language_version: python3
6+
language_version: python3

graph_net/analysis_util.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def parse_logs_to_data(log_file: str) -> list:
121121
current_run_key = processing_match.group(1).strip()
122122
# Initialize a nested dictionary structure for this new run
123123
all_runs_data[current_run_key] = {
124+
"model_path": line.split()[-1],
124125
"configuration": {},
125126
"correctness": {},
126127
"performance": {
@@ -617,3 +618,26 @@ def check_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]:
617618
)
618619

619620
return is_correct, None if is_correct else "accuracy"
621+
622+
623+
def get_incorrect_models(tolerance, log_file_path) -> set[str]:
624+
"""
625+
Filters and returns models with accuracy issues based on given tolerance threshold.
626+
627+
Parses model data from log file and checks each model's accuracy against the specified
628+
tolerance threshold. Returns paths of all models that fail to meet the accuracy requirements.
629+
630+
Args:
631+
tolerance (float): Accuracy tolerance threshold for model validation
632+
log_file_path (str): Path to the log file containing model test results
633+
634+
Returns:
635+
failed_models(str): names of models failing accuracy check, empty set if none found
636+
"""
637+
failed_models = set()
638+
datalist = parse_logs_to_data(log_file_path)
639+
for i in datalist:
640+
iscorrect, err = check_sample_correctness(i, tolerance)
641+
if not iscorrect:
642+
failed_models.add(i.get("model_path"))
643+
return failed_models
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#!/bin/bash
2+
3+
4+
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
5+
GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR")
6+
PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")
7+
8+
# 将项目根目录加入Python路径
9+
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
10+
11+
TOLERANCE_LIST=(-2 -1 0 1 2)
12+
LOG_FILE_PATH="log_file_for_test.txt"
13+
14+
python3 - <<END
15+
from graph_net import analysis_util
16+
17+
result = analysis_util.get_incorrect_models($TOLERANCE_LIST, '$LOG_FILE_PATH')
18+
19+
for item in result:
20+
print(item)
21+
END

0 commit comments

Comments
 (0)