5050class SetupInfo :
5151 module_root : str
5252 tests_root : str
53+ benchmarks_root : str | None
5354 test_framework : str
5455 ignore_paths : list [str ]
5556 formatter : str
@@ -126,8 +127,7 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
126127 run_end_to_end_test (args , bubble_sort_path , bubble_sort_test_path )
127128
128129def should_modify_pyproject_toml () -> bool :
129- """
130- Check if the current directory contains a valid pyproject.toml file with codeflash config
130+ """Check if the current directory contains a valid pyproject.toml file with codeflash config
131131 If it does, ask the user if they want to re-configure it.
132132 """
133133 from rich .prompt import Confirm
@@ -136,7 +136,7 @@ def should_modify_pyproject_toml() -> bool:
136136 return True
137137 try :
138138 config , config_file_path = parse_config_file (pyproject_toml_path )
139- except Exception as e :
139+ except Exception :
140140 return True
141141
142142 if "module_root" not in config or config ["module_root" ] is None or not Path (config ["module_root" ]).is_dir ():
@@ -145,7 +145,7 @@ def should_modify_pyproject_toml() -> bool:
145145 return True
146146
147147 create_toml = Confirm .ask (
148- f "✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?" , default = False , show_default = True
148+ "✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?" , default = False , show_default = True
149149 )
150150 return create_toml
151151
@@ -245,6 +245,66 @@ def collect_setup_info() -> SetupInfo:
245245
246246 ph ("cli-test-framework-provided" , {"test_framework" : test_framework })
247247
248+ # Get benchmarks root directory
249+ default_benchmarks_subdir = "benchmarks"
250+ create_benchmarks_option = f"okay, create a { default_benchmarks_subdir } { os .path .sep } directory for me!"
251+ no_benchmarks_option = "I don't need benchmarks"
252+
253+ # Check if benchmarks directory exists inside tests directory
254+ tests_subdirs = []
255+ if tests_root .exists ():
256+ tests_subdirs = [d .name for d in tests_root .iterdir () if d .is_dir () and not d .name .startswith ("." )]
257+
258+ benchmarks_options = []
259+ if default_benchmarks_subdir in tests_subdirs :
260+ benchmarks_options .append (default_benchmarks_subdir )
261+ benchmarks_options .extend ([d for d in tests_subdirs if d != default_benchmarks_subdir ])
262+ benchmarks_options .append (create_benchmarks_option )
263+ benchmarks_options .append (custom_dir_option )
264+ benchmarks_options .append (no_benchmarks_option )
265+
266+ benchmarks_answer = inquirer_wrapper (
267+ inquirer .list_input ,
268+ message = "Where are your benchmarks located? (benchmarks must be a sub directory of your tests root directory)" ,
269+ choices = benchmarks_options ,
270+ default = (
271+ default_benchmarks_subdir if default_benchmarks_subdir in benchmarks_options else benchmarks_options [0 ]),
272+ )
273+
274+ if benchmarks_answer == create_benchmarks_option :
275+ benchmarks_root = tests_root / default_benchmarks_subdir
276+ benchmarks_root .mkdir (exist_ok = True )
277+ click .echo (f"✅ Created directory { benchmarks_root } { os .path .sep } { LF } " )
278+ elif benchmarks_answer == custom_dir_option :
279+ custom_benchmarks_answer = inquirer_wrapper_path (
280+ "path" ,
281+ message = f"Enter the path to your benchmarks directory inside { tests_root } { os .path .sep } " ,
282+ path_type = inquirer .Path .DIRECTORY ,
283+ )
284+ if custom_benchmarks_answer :
285+ benchmarks_root = tests_root / Path (custom_benchmarks_answer ["path" ])
286+ else :
287+ apologize_and_exit ()
288+ elif benchmarks_answer == no_benchmarks_option :
289+ benchmarks_root = None
290+ else :
291+ benchmarks_root = tests_root / Path (cast (str , benchmarks_answer ))
292+
293+ # TODO: Implement other benchmark framework options
294+ # if benchmarks_root:
295+ # benchmarks_root = benchmarks_root.relative_to(curdir)
296+ #
297+ # # Ask about benchmark framework
298+ # benchmark_framework_options = ["pytest-benchmark", "asv (Airspeed Velocity)", "custom/other"]
299+ # benchmark_framework = inquirer_wrapper(
300+ # inquirer.list_input,
301+ # message="Which benchmark framework do you use?",
302+ # choices=benchmark_framework_options,
303+ # default=benchmark_framework_options[0],
304+ # carousel=True,
305+ # )
306+
307+
248308 formatter = inquirer_wrapper (
249309 inquirer .list_input ,
250310 message = "Which code formatter do you use?" ,
@@ -280,6 +340,7 @@ def collect_setup_info() -> SetupInfo:
280340 return SetupInfo (
281341 module_root = str (module_root ),
282342 tests_root = str (tests_root ),
343+ benchmarks_root = str (benchmarks_root ) if benchmarks_root else None ,
283344 test_framework = cast (str , test_framework ),
284345 ignore_paths = ignore_paths ,
285346 formatter = cast (str , formatter ),
@@ -438,11 +499,19 @@ def install_github_actions(override_formatter_check: bool = False) -> None:
438499 return
439500 workflows_path .mkdir (parents = True , exist_ok = True )
440501 from importlib .resources import files
502+ benchmark_mode = False
503+ if "benchmarks_root" in config :
504+ benchmark_mode = inquirer_wrapper (
505+ inquirer .confirm ,
506+ message = "⚡️It looks like you've configured a benchmarks_root in your config. Would you like to run the Github action in benchmark mode? "
507+ " This will show the impact of Codeflash's suggested optimizations on your benchmarks" ,
508+ default = True ,
509+ )
441510
442511 optimize_yml_content = (
443512 files ("codeflash" ).joinpath ("cli_cmds" , "workflows" , "codeflash-optimize.yaml" ).read_text (encoding = "utf-8" )
444513 )
445- materialized_optimize_yml_content = customize_codeflash_yaml_content (optimize_yml_content , config , git_root )
514+ materialized_optimize_yml_content = customize_codeflash_yaml_content (optimize_yml_content , config , git_root , benchmark_mode )
446515 with optimize_yaml_path .open ("w" , encoding = "utf8" ) as optimize_yml_file :
447516 optimize_yml_file .write (materialized_optimize_yml_content )
448517 click .echo (f"{ LF } ✅ Created GitHub action workflow at { optimize_yaml_path } { LF } " )
@@ -557,7 +626,7 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str:
557626
558627
559628def customize_codeflash_yaml_content (
560- optimize_yml_content : str , config : tuple [dict [str , Any ], Path ], git_root : Path
629+ optimize_yml_content : str , config : tuple [dict [str , Any ], Path ], git_root : Path , benchmark_mode : bool = False
561630) -> str :
562631 module_path = str (Path (config ["module_root" ]).relative_to (git_root ) / "**" )
563632 optimize_yml_content = optimize_yml_content .replace ("{{ codeflash_module_path }}" , module_path )
@@ -588,6 +657,9 @@ def customize_codeflash_yaml_content(
588657
589658 # Add codeflash command
590659 codeflash_cmd = get_codeflash_github_action_command (dep_manager )
660+
661+ if benchmark_mode :
662+ codeflash_cmd += " --benchmark"
591663 return optimize_yml_content .replace ("{{ codeflash_command }}" , codeflash_cmd )
592664
593665
@@ -609,6 +681,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
609681 codeflash_section ["module-root" ] = setup_info .module_root
610682 codeflash_section ["tests-root" ] = setup_info .tests_root
611683 codeflash_section ["test-framework" ] = setup_info .test_framework
684+ codeflash_section ["benchmarks-root" ] = setup_info .benchmarks_root if setup_info .benchmarks_root else ""
612685 codeflash_section ["ignore-paths" ] = setup_info .ignore_paths
613686 if setup_info .git_remote not in ["" , "origin" ]:
614687 codeflash_section ["git-remote" ] = setup_info .git_remote
0 commit comments