Skip to content
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ repos:
rev: 23.1.0
hooks:
- id: black
language_version: python3
language_version: python3
24 changes: 24 additions & 0 deletions graph_net/analysis_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def parse_logs_to_data(log_file: str) -> list:
current_run_key = processing_match.group(1).strip()
# Initialize a nested dictionary structure for this new run
all_runs_data[current_run_key] = {
"model_path": line.split()[-1],
"configuration": {},
"correctness": {},
"performance": {
Expand Down Expand Up @@ -617,3 +618,26 @@ def check_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]:
)

return is_correct, None if is_correct else "accuracy"


def get_incorrect_models(tolerance, log_file_path) -> set[str]:
"""
Filters and returns models with accuracy issues based on given tolerance threshold.

Parses model data from log file and checks each model's accuracy against the specified
tolerance threshold. Returns paths of all models that fail to meet the accuracy requirements.

Args:
tolerance (float): Accuracy tolerance threshold for model validation
log_file_path (str): Path to the log file containing model test results

Returns:
failed_models(str): names of models failing accuracy check, empty set if none found
"""
failed_models = set()
datalist = parse_logs_to_data(log_file_path)
for i in datalist:
iscorrect, err = check_sample_correctness(i, tolerance)
if not iscorrect:
failed_models.add(i.get("model_path"))
return failed_models
21 changes: 21 additions & 0 deletions graph_net/test/get_incorrect_models_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/bin/bash


SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR")
PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")

# 将项目根目录加入Python路径
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"

TOLERANCE_LIST=(-2 -1 0 1 2)
LOG_FILE_PATH="log_file_for_test.txt"

python3 - <<END
from graph_net import analysis_util

result = analysis_util.get_incorrect_models($TOLERANCE_LIST, '$LOG_FILE_PATH')

for item in result:
print(item)
END
Loading