|
5 | 5 | from graph_net.config.datatype_tolerance_config import get_precision |
6 | 6 | from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation |
7 | 7 | from graph_net.verify_aggregated_params import determine_tolerances |
| 8 | +from graph_net.positive_tolerance_interpretation_manager import ( |
| 9 | + get_positive_tolerance_interpretation, |
| 10 | +) |
8 | 11 |
|
9 | 12 |
|
10 | 13 | def detect_sample_status(log_text: str) -> str: |
@@ -430,7 +433,10 @@ def check_sample_correctness(sample: dict, tolerance: int) -> tuple[bool, str]: |
430 | 433 |
|
431 | 434 |
|
432 | 435 | def get_incorrect_models( |
433 | | - tolerance: int, log_file_path: str, type: str = "ESt" |
| 436 | + tolerance: int, |
| 437 | + log_file_path: str, |
| 438 | + type: str = "ESt", |
| 439 | + positive_tolerance_interpretation_type: str = "default", |
434 | 440 | ) -> set[str]: |
435 | 441 | """ |
436 | 442 | Filters and returns models with accuracy issues based on given tolerance threshold. |
@@ -459,9 +465,15 @@ def get_incorrect_models( |
459 | 465 | is_correct_at_t1[idx] = is_correct |
460 | 466 | fail_type_at_t1[idx] = fail_type if fail_type is not None else "correct" |
461 | 467 |
|
| 468 | + positive_tolerance_interpretation = get_positive_tolerance_interpretation( |
| 469 | + positive_tolerance_interpretation_type |
| 470 | + ) |
| 471 | + |
462 | 472 | for idx, sample in enumerate(samples): |
463 | 473 | if not is_correct_at_t1[idx]: |
464 | | - current_correctness = fake_perf_degrad(tolerance, fail_type_at_t1[idx]) |
| 474 | + current_correctness = fake_perf_degrad( |
| 475 | + tolerance, fail_type_at_t1[idx], positive_tolerance_interpretation |
| 476 | + ) |
465 | 477 | failed_models.add( |
466 | 478 | sample.get("model_path") |
467 | 479 | ) if current_correctness != "correct" else None |
|
0 commit comments