Skip to content

Commit 69d92a5

Browse files
committed
Fix error.
1 parent 6c46a71 commit 69d92a5

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

graph_net/analysis_util.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from graph_net.config.datatype_tolerance_config import get_precision
66
from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation
77
from graph_net.verify_aggregated_params import determine_tolerances
8+
from graph_net.positive_tolerance_interpretation_manager import (
9+
get_positive_tolerance_interpretation,
10+
)
811

912

1013
def detect_sample_status(log_text: str) -> str:
@@ -430,7 +433,10 @@ def check_sample_correctness(sample: dict, tolerance: int) -> tuple[bool, str]:
430433

431434

432435
def get_incorrect_models(
433-
tolerance: int, log_file_path: str, type: str = "ESt"
436+
tolerance: int,
437+
log_file_path: str,
438+
type: str = "ESt",
439+
positive_tolerance_interpretation_type: str = "default",
434440
) -> set[str]:
435441
"""
436442
Filters and returns models with accuracy issues based on given tolerance threshold.
@@ -459,9 +465,15 @@ def get_incorrect_models(
459465
is_correct_at_t1[idx] = is_correct
460466
fail_type_at_t1[idx] = fail_type if fail_type is not None else "correct"
461467

468+
positive_tolerance_interpretation = get_positive_tolerance_interpretation(
469+
positive_tolerance_interpretation_type
470+
)
471+
462472
for idx, sample in enumerate(samples):
463473
if not is_correct_at_t1[idx]:
464-
current_correctness = fake_perf_degrad(tolerance, fail_type_at_t1[idx])
474+
current_correctness = fake_perf_degrad(
475+
tolerance, fail_type_at_t1[idx], positive_tolerance_interpretation
476+
)
465477
failed_models.add(
466478
sample.get("model_path")
467479
) if current_correctness != "correct" else None

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class ModelRecord:
126126
original_path: str
127127
uniform_split_positions: List[int] = field(default_factory=list)
128128
subgraph_paths: List[str] = field(default_factory=list)
129-
incorrect_subgraph_idxs: List[int] = field(default_factory=list)
129+
incorrect_subgraph_idxs: List[int] = None
130130

131131
def get_split_positions(self, decompose_method):
132132
if decompose_method == "fixed-start":
@@ -466,7 +466,9 @@ def generate_initial_tasks(args):
466466
)
467467
decompose_config.update_running_state(
468468
pass_id=-1,
469-
running_state=RunningState(incorrect_models_from_log=initial_incorrect_models),
469+
running_state=RunningState(
470+
incorrect_models_from_log=list(sorted(initial_incorrect_models))
471+
),
470472
)
471473
decompose_config.update_running_state(
472474
pass_id=0, running_state=RunningState(model_name2record=model_name2record)

0 commit comments

Comments
 (0)