diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 5edff57a0..fe789e621 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -8,6 +8,7 @@ from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions from codeflash.cli_cmds.console import logger from codeflash.code_utils import env_utils +from codeflash.code_utils.code_utils import exit_with_message from codeflash.code_utils.config_parser import parse_config_file from codeflash.version import __version__ as version @@ -42,7 +43,7 @@ def parse_args() -> Namespace: ) parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest") parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.") - parser.add_argument("--replay-test", type=str, help="Path to replay test to optimize functions from") + parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from") parser.add_argument( "--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally." ) @@ -83,25 +84,22 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace: sys.exit() if not check_running_in_git_repo(module_root=args.module_root): if not confirm_proceeding_with_no_git_repo(): - logger.critical("No git repository detected and user aborted run. Exiting...") - sys.exit(1) + exit_with_message("No git repository detected and user aborted run. Exiting...", error_on_exit=True) args.no_pr = True if args.function and not args.file: - logger.error("If you specify a --function, you must specify the --file it is in") - sys.exit(1) + exit_with_message("If you specify a --function, you must specify the --file it is in", error_on_exit=True) if args.file: if not Path(args.file).exists(): - logger.error(f"File {args.file} does not exist") - sys.exit(1) + exit_with_message(f"File {args.file} does not exist", error_on_exit=True) args.file = Path(args.file).resolve() if not args.no_pr: owner, repo = get_repo_owner_and_name() require_github_app_or_exit(owner, repo) if args.replay_test: - if not Path(args.replay_test).is_file(): - logger.error(f"Replay test file {args.replay_test} does not exist") - sys.exit(1) - args.replay_test = Path(args.replay_test).resolve() + for test_path in args.replay_test: + if not Path(test_path).is_file(): + exit_with_message(f"Replay test file {test_path} does not exist", error_on_exit=True) + args.replay_test = [Path(replay_test).resolve() for replay_test in args.replay_test] return args @@ -110,8 +108,7 @@ def process_pyproject_config(args: Namespace) -> Namespace: try: pyproject_config, pyproject_file_path = parse_config_file(args.config_file) except ValueError as e: - logger.error(e) - sys.exit(1) + exit_with_message(f"Error parsing config file: {e}", error_on_exit=True) supported_keys = [ "module_root", "tests_root", @@ -206,8 +203,7 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace: ) apologize_and_exit() if not args.no_pr and not check_and_push_branch(git_repo): - logger.critical("❌ Branch is not pushed. Exiting...") - sys.exit(1) + exit_with_message("Branch is not pushed...", error_on_exit=True) owner, repo = get_repo_owner_and_name(git_repo) if not args.no_pr: require_github_app_or_exit(owner, repo) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 22cd817d4..82a5b9791 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -34,7 +34,7 @@ def custom_addopts() -> None: # Backup original addopts original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "") # nothing to do if no addopts present - if original_addopts != "": + if original_addopts != "" and isinstance(original_addopts, list): original_addopts = [x.strip() for x in original_addopts] non_blacklist_plugin_args = re.sub(r"-n(?: +|=)\S+", "", " ".join(original_addopts)).split(" ") non_blacklist_plugin_args = [x for x in non_blacklist_plugin_args if x != ""] diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 31f930e09..792a9fcff 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -150,7 +150,7 @@ def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: def get_functions_to_optimize( optimize_all: str | None, - replay_test: str | None, + replay_test: list[Path] | None, file: Path | None, only_get_this_function: str | None, test_cfg: TestConfig, @@ -169,7 +169,7 @@ def get_functions_to_optimize( logger.info("Finding all functions in the module '%s'…", optimize_all) console.rule() functions = get_all_files_and_functions(Path(optimize_all)) - elif replay_test is not None: + elif replay_test: functions = get_all_replay_test_functions( replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root ) @@ -271,9 +271,9 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt def get_all_replay_test_functions( - replay_test: Path, test_cfg: TestConfig, project_root_path: Path + replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path ) -> dict[Path, list[FunctionToOptimize]]: - function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test]) + function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test) # Get the absolute file paths for each function, excluding class name if present filtered_valid_functions = defaultdict(list) file_to_functions_map = defaultdict(list) diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 9fa1f3290..fa8fdc88a 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -379,7 +379,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911 del arguments_copy["self"] local_vars = pickle.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL) sys.setrecursionlimit(original_recursion_limit) - except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError): + except Exception: # we retry with dill if pickle fails. It's slower but more comprehensive try: sys.setrecursionlimit(10000) # Ensure limit is high for dill too @@ -390,7 +390,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911 ) sys.setrecursionlimit(original_recursion_limit) - except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError): + except Exception: self.function_count[function_qualified_name] -= 1 return