diff --git a/bin/config.py b/bin/config.py index 6e7d70b8..28241b42 100644 --- a/bin/config.py +++ b/bin/config.py @@ -129,6 +129,7 @@ def __init__(self, source: str | Path, **kwargs: Any) -> None: def warn(msg: Any) -> None: global n_warn + # `config` is imported before `util`, so we cannot use a `PrintBar` or `eprint` here. print(f"{Fore.YELLOW}WARNING: {msg}{Style.RESET_ALL}", file=sys.stderr) n_warn += 1 @@ -271,15 +272,12 @@ def get_arg(key: str, default: T, constraint: Optional[str] = None) -> T: print(key, type(kwargs[key])) warn(f"unknown key in {source}: '{key}'") - def update(self, args: "ARGS", replace: bool = False) -> None: + def add_if_not_set(self, args: "ARGS") -> None: for key in args._set: - if key not in self._set or replace: + if key not in self._set: setattr(self, key, getattr(args, key)) self._set.add(key) - def mark_set(self, *keys: str) -> None: - self._set.update(list(keys)) - def copy(self) -> "ARGS": res = copy.copy(self) res._set = copy.copy(res._set) diff --git a/bin/tools.py b/bin/tools.py index 5fd62dba..3efe1504 100755 --- a/bin/tools.py +++ b/bin/tools.py @@ -1058,7 +1058,7 @@ def build_parser() -> SuppressingParser: return parser -def find_personal_config() -> Optional[Path]: +def find_home_config_dir() -> Optional[Path]: if is_windows(): app_data = os.getenv("AppData") return Path(app_data) if app_data else None @@ -1071,14 +1071,14 @@ def find_personal_config() -> Optional[Path]: def read_personal_config(problem_dir: Optional[Path]) -> None: - home_config = find_personal_config() + home_config_dir = find_home_config_dir() # possible config files, sorted by priority config_files = [] if problem_dir: config_files.append(problem_dir / ".bapctools.yaml") config_files.append(Path.cwd() / ".bapctools.yaml") - if home_config: - config_files.append(home_config / "bapctools" / "config.yaml") + if home_config_dir: + config_files.append(home_config_dir / "bapctools" / "config.yaml") for config_file in config_files: if not config_file.is_file(): @@ -1091,8 +1091,7 @@ def read_personal_config(problem_dir: Optional[Path]) -> None: warn(f"invalid data in {config_data}. SKIPPED.") continue - tmp = config.ARGS(config_file, **config_data) - config.args.update(tmp) + config.args.add_if_not_set(config.ARGS(config_file, **config_data)) def run_parsed_arguments(args: argparse.Namespace, personal_config: bool = True) -> None: diff --git a/test/test_default_output_validator.py b/test/test_default_output_validator.py index a904ad8c..ed47bdbe 100644 --- a/test/test_default_output_validator.py +++ b/test/test_default_output_validator.py @@ -15,9 +15,7 @@ # Note: the python version isn't tested by default, because it's quite slow. DEFAULT_OUTPUT_VALIDATOR = ["default_output_validator.cpp"] -config.args.verbose = 2 -config.args.error = True -config.args.mark_set("verbose", "error") +config.args.add_if_not_set(config.ARGS("test_default_output_validator.py", verbose=2, error=True)) # return list of (flags, ans, out, expected result) diff --git a/test/test_problem_yaml.py b/test/test_problem_yaml.py index 8107d2ea..31da03e9 100644 --- a/test/test_problem_yaml.py +++ b/test/test_problem_yaml.py @@ -9,9 +9,7 @@ RUN_DIR = Path.cwd().absolute() -config.args.verbose = 2 -config.args.error = True -config.args.mark_set("verbose", "error") +config.args.add_if_not_set(config.ARGS("test_problem_yaml.py", verbose=2, error=True)) # return list of {yaml: {...}, ...} documents