Skip to content

Commit 4f0b203

Browse files
committed
Fix error and set random seed in run_model.
1 parent 6c46a71 commit 4f0b203

File tree

3 files changed

+40
-10
lines changed

3 files changed

+40
-10
lines changed

graph_net/analysis_util.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from graph_net.config.datatype_tolerance_config import get_precision
66
from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation
77
from graph_net.verify_aggregated_params import determine_tolerances
8+
from graph_net.positive_tolerance_interpretation_manager import (
9+
get_positive_tolerance_interpretation,
10+
)
811

912

1013
def detect_sample_status(log_text: str) -> str:
@@ -430,7 +433,10 @@ def check_sample_correctness(sample: dict, tolerance: int) -> tuple[bool, str]:
430433

431434

432435
def get_incorrect_models(
433-
tolerance: int, log_file_path: str, type: str = "ESt"
436+
tolerance: int,
437+
log_file_path: str,
438+
type: str = "ESt",
439+
positive_tolerance_interpretation_type: str = "default",
434440
) -> set[str]:
435441
"""
436442
Filters and returns models with accuracy issues based on given tolerance threshold.
@@ -459,9 +465,15 @@ def get_incorrect_models(
459465
is_correct_at_t1[idx] = is_correct
460466
fail_type_at_t1[idx] = fail_type if fail_type is not None else "correct"
461467

468+
positive_tolerance_interpretation = get_positive_tolerance_interpretation(
469+
positive_tolerance_interpretation_type
470+
)
471+
462472
for idx, sample in enumerate(samples):
463473
if not is_correct_at_t1[idx]:
464-
current_correctness = fake_perf_degrad(tolerance, fail_type_at_t1[idx])
474+
current_correctness = fake_perf_degrad(
475+
tolerance, fail_type_at_t1[idx], positive_tolerance_interpretation
476+
)
465477
failed_models.add(
466478
sample.get("model_path")
467479
) if current_correctness != "correct" else None

graph_net/paddle/run_model.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import json
33
import base64
44
import argparse
5+
import numpy as np
6+
import random
57

68
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
79

@@ -10,6 +12,12 @@
1012
from graph_net.paddle import utils
1113

1214

15+
def set_seed(random_seed):
16+
paddle.seed(random_seed)
17+
random.seed(random_seed)
18+
np.random.seed(random_seed)
19+
20+
1321
def load_class_from_file(file_path: str, class_name: str):
1422
print(f"Load {class_name} from {file_path}")
1523
module = imp_util.load_module(file_path, "unnamed")
@@ -23,12 +31,17 @@ def get_input_dict(model_path):
2331
params = inputs_params["weight_info"]
2432
inputs = inputs_params["input_info"]
2533

26-
state_dict = {}
27-
for k, v in params.items():
28-
state_dict[k] = paddle.nn.parameter.Parameter(utils.replay_tensor(v), name=k)
29-
for k, v in inputs.items():
30-
state_dict[k] = utils.replay_tensor(v)
31-
return state_dict
34+
input_dict = {}
35+
for name, meta in params.items():
36+
original_name = (
37+
meta["original_name"] if meta.get("original_name", None) else name
38+
)
39+
input_dict[name] = paddle.nn.parameter.Parameter(
40+
utils.replay_tensor(meta), name=original_name
41+
)
42+
for name, meta in inputs.items():
43+
input_dict[name] = utils.replay_tensor(meta)
44+
return input_dict
3245

3346

3447
def _convert_to_dict(config_str):
@@ -53,6 +66,9 @@ def _get_decorator(args):
5366

5467

5568
def main(args):
69+
initalize_seed = 123
70+
set_seed(random_seed=initalize_seed)
71+
5672
model_path = args.model_path
5773
model_class = load_class_from_file(
5874
f"{model_path}/model.py", class_name="GraphModule"

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class ModelRecord:
126126
original_path: str
127127
uniform_split_positions: List[int] = field(default_factory=list)
128128
subgraph_paths: List[str] = field(default_factory=list)
129-
incorrect_subgraph_idxs: List[int] = field(default_factory=list)
129+
incorrect_subgraph_idxs: List[int] = None
130130

131131
def get_split_positions(self, decompose_method):
132132
if decompose_method == "fixed-start":
@@ -466,7 +466,9 @@ def generate_initial_tasks(args):
466466
)
467467
decompose_config.update_running_state(
468468
pass_id=-1,
469-
running_state=RunningState(incorrect_models_from_log=initial_incorrect_models),
469+
running_state=RunningState(
470+
incorrect_models_from_log=list(sorted(initial_incorrect_models))
471+
),
470472
)
471473
decompose_config.update_running_state(
472474
pass_id=0, running_state=RunningState(model_name2record=model_name2record)

0 commit comments

Comments
 (0)