Skip to content

Commit 55a48eb

Browse files
Merge pull request #316 from codeflash-ai/can-specify-multiple-replay-tests
Replay tests and tracer improvments
2 parents 6e2d21f + 38b69ca commit 55a48eb

File tree

4 files changed

+18
-22
lines changed

4 files changed

+18
-22
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
99
from codeflash.cli_cmds.console import logger
1010
from codeflash.code_utils import env_utils
11+
from codeflash.code_utils.code_utils import exit_with_message
1112
from codeflash.code_utils.config_parser import parse_config_file
1213
from codeflash.version import __version__ as version
1314

@@ -42,7 +43,7 @@ def parse_args() -> Namespace:
4243
)
4344
parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest")
4445
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
45-
parser.add_argument("--replay-test", type=str, help="Path to replay test to optimize functions from")
46+
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
4647
parser.add_argument(
4748
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
4849
)
@@ -83,25 +84,22 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
8384
sys.exit()
8485
if not check_running_in_git_repo(module_root=args.module_root):
8586
if not confirm_proceeding_with_no_git_repo():
86-
logger.critical("No git repository detected and user aborted run. Exiting...")
87-
sys.exit(1)
87+
exit_with_message("No git repository detected and user aborted run. Exiting...", error_on_exit=True)
8888
args.no_pr = True
8989
if args.function and not args.file:
90-
logger.error("If you specify a --function, you must specify the --file it is in")
91-
sys.exit(1)
90+
exit_with_message("If you specify a --function, you must specify the --file it is in", error_on_exit=True)
9291
if args.file:
9392
if not Path(args.file).exists():
94-
logger.error(f"File {args.file} does not exist")
95-
sys.exit(1)
93+
exit_with_message(f"File {args.file} does not exist", error_on_exit=True)
9694
args.file = Path(args.file).resolve()
9795
if not args.no_pr:
9896
owner, repo = get_repo_owner_and_name()
9997
require_github_app_or_exit(owner, repo)
10098
if args.replay_test:
101-
if not Path(args.replay_test).is_file():
102-
logger.error(f"Replay test file {args.replay_test} does not exist")
103-
sys.exit(1)
104-
args.replay_test = Path(args.replay_test).resolve()
99+
for test_path in args.replay_test:
100+
if not Path(test_path).is_file():
101+
exit_with_message(f"Replay test file {test_path} does not exist", error_on_exit=True)
102+
args.replay_test = [Path(replay_test).resolve() for replay_test in args.replay_test]
105103

106104
return args
107105

@@ -110,8 +108,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
110108
try:
111109
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
112110
except ValueError as e:
113-
logger.error(e)
114-
sys.exit(1)
111+
exit_with_message(f"Error parsing config file: {e}", error_on_exit=True)
115112
supported_keys = [
116113
"module_root",
117114
"tests_root",
@@ -206,8 +203,7 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
206203
)
207204
apologize_and_exit()
208205
if not args.no_pr and not check_and_push_branch(git_repo):
209-
logger.critical("❌ Branch is not pushed. Exiting...")
210-
sys.exit(1)
206+
exit_with_message("Branch is not pushed...", error_on_exit=True)
211207
owner, repo = get_repo_owner_and_name(git_repo)
212208
if not args.no_pr:
213209
require_github_app_or_exit(owner, repo)

codeflash/code_utils/code_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def custom_addopts() -> None:
3434
# Backup original addopts
3535
original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
3636
# nothing to do if no addopts present
37-
if original_addopts != "":
37+
if original_addopts != "" and isinstance(original_addopts, list):
3838
original_addopts = [x.strip() for x in original_addopts]
3939
non_blacklist_plugin_args = re.sub(r"-n(?: +|=)\S+", "", " ".join(original_addopts)).split(" ")
4040
non_blacklist_plugin_args = [x for x in non_blacklist_plugin_args if x != ""]

codeflash/discovery/functions_to_optimize.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
150150

151151
def get_functions_to_optimize(
152152
optimize_all: str | None,
153-
replay_test: str | None,
153+
replay_test: list[Path] | None,
154154
file: Path | None,
155155
only_get_this_function: str | None,
156156
test_cfg: TestConfig,
@@ -169,7 +169,7 @@ def get_functions_to_optimize(
169169
logger.info("Finding all functions in the module '%s'…", optimize_all)
170170
console.rule()
171171
functions = get_all_files_and_functions(Path(optimize_all))
172-
elif replay_test is not None:
172+
elif replay_test:
173173
functions = get_all_replay_test_functions(
174174
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
175175
)
@@ -271,9 +271,9 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
271271

272272

273273
def get_all_replay_test_functions(
274-
replay_test: Path, test_cfg: TestConfig, project_root_path: Path
274+
replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path
275275
) -> dict[Path, list[FunctionToOptimize]]:
276-
function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
276+
function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test)
277277
# Get the absolute file paths for each function, excluding class name if present
278278
filtered_valid_functions = defaultdict(list)
279279
file_to_functions_map = defaultdict(list)

codeflash/tracer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
379379
del arguments_copy["self"]
380380
local_vars = pickle.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL)
381381
sys.setrecursionlimit(original_recursion_limit)
382-
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
382+
except Exception:
383383
# we retry with dill if pickle fails. It's slower but more comprehensive
384384
try:
385385
sys.setrecursionlimit(10000) # Ensure limit is high for dill too
@@ -390,7 +390,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
390390
)
391391
sys.setrecursionlimit(original_recursion_limit)
392392

393-
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
393+
except Exception:
394394
self.function_count[function_qualified_name] -= 1
395395
return
396396

0 commit comments

Comments
 (0)