Commit 28fd746: "cmd init changes" (parent b70c4c9)

File tree: 2 files changed (+94, -19 lines)


codeflash/cli_cmds/cmd_init.py (79 additions, 6 deletions)
@@ -50,6 +50,7 @@
 class SetupInfo:
     module_root: str
     tests_root: str
+    benchmarks_root: str | None
     test_framework: str
     ignore_paths: list[str]
     formatter: str
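
The new field uses PEP 604 union syntax (str | None), which parses natively only on Python 3.10+ unless from __future__ import annotations is in effect. A minimal runnable sketch of the updated structure, assuming SetupInfo is a dataclass (its decorator sits above the visible hunk) and using hypothetical values:

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class SetupInfo:
    module_root: str
    tests_root: str
    benchmarks_root: str | None  # None when the user opts out of benchmarks
    test_framework: str
    ignore_paths: list[str]
    formatter: str


info = SetupInfo("src", "tests", None, "pytest", [], "black")  # hypothetical values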
@@ -126,8 +127,7 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
     run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)


 def should_modify_pyproject_toml() -> bool:
-    """
-    Check if the current directory contains a valid pyproject.toml file with codeflash config
+    """Check if the current directory contains a valid pyproject.toml file with codeflash config
     If it does, ask the user if they want to re-configure it.
     """
     from rich.prompt import Confirm
@@ -136,7 +136,7 @@ def should_modify_pyproject_toml() -> bool:
         return True
     try:
         config, config_file_path = parse_config_file(pyproject_toml_path)
-    except Exception as e:
+    except Exception:
        return True

     if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
@@ -145,7 +145,7 @@ def should_modify_pyproject_toml() -> bool:
         return True

     create_toml = Confirm.ask(
-        f"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True
+        "✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True
     )
     return create_toml

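The three hunks above are lint cleanups: the docstring summary moves onto the opening quotes line (pydocstyle D212), the unused as e binding is dropped (Ruff F841), and the f prefix comes off a string with no placeholders (Ruff F541). For illustration:

# Flagged by Ruff F541: an f-string with no placeholders is just a plain string.
prompt = f"Do you want to re-configure it?"  # before
prompt = "Do you want to re-configure it?"   # after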
@@ -245,6 +245,66 @@ def collect_setup_info() -> SetupInfo:

     ph("cli-test-framework-provided", {"test_framework": test_framework})

+    # Get benchmarks root directory
+    default_benchmarks_subdir = "benchmarks"
+    create_benchmarks_option = f"okay, create a {default_benchmarks_subdir}{os.path.sep} directory for me!"
+    no_benchmarks_option = "I don't need benchmarks"
+
+    # Check if benchmarks directory exists inside tests directory
+    tests_subdirs = []
+    if tests_root.exists():
+        tests_subdirs = [d.name for d in tests_root.iterdir() if d.is_dir() and not d.name.startswith(".")]
+
+    benchmarks_options = []
+    if default_benchmarks_subdir in tests_subdirs:
+        benchmarks_options.append(default_benchmarks_subdir)
+    benchmarks_options.extend([d for d in tests_subdirs if d != default_benchmarks_subdir])
+    benchmarks_options.append(create_benchmarks_option)
+    benchmarks_options.append(custom_dir_option)
+    benchmarks_options.append(no_benchmarks_option)
+
+    benchmarks_answer = inquirer_wrapper(
+        inquirer.list_input,
+        message="Where are your benchmarks located? (benchmarks must be a sub directory of your tests root directory)",
+        choices=benchmarks_options,
+        default=(
+            default_benchmarks_subdir if default_benchmarks_subdir in benchmarks_options else benchmarks_options[0]),
+    )
+
+    if benchmarks_answer == create_benchmarks_option:
+        benchmarks_root = tests_root / default_benchmarks_subdir
+        benchmarks_root.mkdir(exist_ok=True)
+        click.echo(f"✅ Created directory {benchmarks_root}{os.path.sep}{LF}")
+    elif benchmarks_answer == custom_dir_option:
+        custom_benchmarks_answer = inquirer_wrapper_path(
+            "path",
+            message=f"Enter the path to your benchmarks directory inside {tests_root}{os.path.sep} ",
+            path_type=inquirer.Path.DIRECTORY,
+        )
+        if custom_benchmarks_answer:
+            benchmarks_root = tests_root / Path(custom_benchmarks_answer["path"])
+        else:
+            apologize_and_exit()
+    elif benchmarks_answer == no_benchmarks_option:
+        benchmarks_root = None
+    else:
+        benchmarks_root = tests_root / Path(cast(str, benchmarks_answer))
+
+    # TODO: Implement other benchmark framework options
+    # if benchmarks_root:
+    #     benchmarks_root = benchmarks_root.relative_to(curdir)
+    #
+    #     # Ask about benchmark framework
+    #     benchmark_framework_options = ["pytest-benchmark", "asv (Airspeed Velocity)", "custom/other"]
+    #     benchmark_framework = inquirer_wrapper(
+    #         inquirer.list_input,
+    #         message="Which benchmark framework do you use?",
+    #         choices=benchmark_framework_options,
+    #         default=benchmark_framework_options[0],
+    #         carousel=True,
+    #     )
+
+
     formatter = inquirer_wrapper(
         inquirer.list_input,
         message="Which code formatter do you use?",
@@ -280,6 +340,7 @@ def collect_setup_info() -> SetupInfo:
     return SetupInfo(
         module_root=str(module_root),
         tests_root=str(tests_root),
+        benchmarks_root = str(benchmarks_root) if benchmarks_root else None,
         test_framework=cast(str, test_framework),
         ignore_paths=ignore_paths,
         formatter=cast(str, formatter),
@@ -438,11 +499,19 @@ def install_github_actions(override_formatter_check: bool = False) -> None:
         return
     workflows_path.mkdir(parents=True, exist_ok=True)
     from importlib.resources import files
+    benchmark_mode = False
+    if "benchmarks_root" in config:
+        benchmark_mode = inquirer_wrapper(
+            inquirer.confirm,
+            message="⚡️It looks like you've configured a benchmarks_root in your config. Would you like to run the Github action in benchmark mode? "
+            " This will show the impact of Codeflash's suggested optimizations on your benchmarks",
+            default=True,
+        )

     optimize_yml_content = (
         files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8")
     )
-    materialized_optimize_yml_content = customize_codeflash_yaml_content(optimize_yml_content, config, git_root)
+    materialized_optimize_yml_content = customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode)
     with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file:
         optimize_yml_file.write(materialized_optimize_yml_content)
     click.echo(f"{LF}✅ Created GitHub action workflow at {optimize_yaml_path}{LF}")
@@ -557,7 +626,7 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str:


 def customize_codeflash_yaml_content(
-    optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path
+    optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
 ) -> str:
     module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
     optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
@@ -588,6 +657,9 @@ def customize_codeflash_yaml_content(

     # Add codeflash command
     codeflash_cmd = get_codeflash_github_action_command(dep_manager)
+
+    if benchmark_mode:
+        codeflash_cmd += " --benchmark"
     return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)

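The workflow template is materialized with plain str.replace. A sketch of the benchmark-mode substitution, with an illustrative command string rather than the real output of get_codeflash_github_action_command():

optimize_yml_content = "run: {{ codeflash_command }}"
codeflash_cmd = "poetry run codeflash"  # hypothetical dependency-manager command

benchmark_mode = True
if benchmark_mode:
    codeflash_cmd += " --benchmark"  # appended exactly as in the diff

print(optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd))
# run: poetry run codeflash --benchmark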
@@ -609,6 +681,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
     codeflash_section["module-root"] = setup_info.module_root
     codeflash_section["tests-root"] = setup_info.tests_root
     codeflash_section["test-framework"] = setup_info.test_framework
+    codeflash_section["benchmarks-root"] = setup_info.benchmarks_root if setup_info.benchmarks_root else ""
     codeflash_section["ignore-paths"] = setup_info.ignore_paths
     if setup_info.git_remote not in ["", "origin"]:
         codeflash_section["git-remote"] = setup_info.git_remote

codeflash/models/models.py (15 additions, 13 deletions)
@@ -537,20 +537,22 @@ def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
         return tree

     def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
-
-        usable_runtime = defaultdict(list)
         for result in self.test_results:
-            if result.did_pass:
-                if not result.runtime:
-                    msg = (
-                        f"Ignoring test case that passed but had no runtime -> {result.id}, "
-                        f"Loop # {result.loop_index}, Test Type: {result.test_type}, "
-                        f"Verification Type: {result.verification_type}"
-                    )
-                    logger.debug(msg)
-                else:
-                    usable_runtime[result.id].append(result.runtime)
-        return usable_runtime
+            if result.did_pass and not result.runtime:
+                msg = (
+                    f"Ignoring test case that passed but had no runtime -> {result.id}, "
+                    f"Loop # {result.loop_index}, Test Type: {result.test_type}, "
+                    f"Verification Type: {result.verification_type}"
+                )
+                logger.debug(msg)
+
+        usable_runtimes = [
+            (result.id, result.runtime) for result in self.test_results if result.did_pass and result.runtime
+        ]
+        return {
+            usable_id: [runtime[1] for runtime in usable_runtimes if runtime[0] == usable_id]
+            for usable_id in {runtime[0] for runtime in usable_runtimes}
+        }

     def total_passed_runtime(self) -> int:
         """Calculate the sum of runtimes of all test cases that passed.
