Skip to content

Commit a457aad

Browse files
committed
multiple replay tests specified
1 parent 36ce827 commit a457aad

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def parse_args() -> Namespace:
4242
)
4343
parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest")
4444
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
45-
parser.add_argument("--replay-test", type=str, help="Path to replay test to optimize functions from")
45+
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
4646
parser.add_argument(
4747
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
4848
)
@@ -98,10 +98,11 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
9898
owner, repo = get_repo_owner_and_name()
9999
require_github_app_or_exit(owner, repo)
100100
if args.replay_test:
101-
if not Path(args.replay_test).is_file():
102-
logger.error(f"Replay test file {args.replay_test} does not exist")
103-
sys.exit(1)
104-
args.replay_test = Path(args.replay_test).resolve()
101+
for test_path in args.replay_test:
102+
if not Path(test_path).is_file():
103+
logger.error(f"Replay test file {test_path} does not exist")
104+
sys.exit(1)
105+
args.replay_test = [Path(replay_test).resolve() for replay_test in args.replay_test]
105106

106107
return args
107108

codeflash/discovery/functions_to_optimize.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
149149

150150
def get_functions_to_optimize(
151151
optimize_all: str | None,
152-
replay_test: str | None,
152+
replay_test: list[Path] | None,
153153
file: Path | None,
154154
only_get_this_function: str | None,
155155
test_cfg: TestConfig,
@@ -168,7 +168,7 @@ def get_functions_to_optimize(
168168
logger.info("Finding all functions in the module '%s'…", optimize_all)
169169
console.rule()
170170
functions = get_all_files_and_functions(Path(optimize_all))
171-
elif replay_test is not None:
171+
elif replay_test:
172172
functions = get_all_replay_test_functions(
173173
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
174174
)
@@ -268,9 +268,9 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
268268

269269

270270
def get_all_replay_test_functions(
271-
replay_test: Path, test_cfg: TestConfig, project_root_path: Path
271+
replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path
272272
) -> dict[Path, list[FunctionToOptimize]]:
273-
function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
273+
function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test)
274274
# Get the absolute file paths for each function, excluding class name if present
275275
filtered_valid_functions = defaultdict(list)
276276
file_to_functions_map = defaultdict(list)

0 commit comments

Comments
 (0)