1616import tomlkit
1717from git import InvalidGitRepositoryError , Repo
1818from pydantic .dataclasses import dataclass
19+ from rich .prompt import Confirm
1920
2021from codeflash .api .cfapi import is_github_app_installed_on_repo
2122from codeflash .cli_cmds .cli_common import apologize_and_exit , inquirer_wrapper , inquirer_wrapper_path
4546 f"{ LF } "
4647)
4748
49+
4850@dataclass (frozen = True )
4951class SetupInfo :
5052 module_root : str
@@ -70,7 +72,6 @@ def init_codeflash() -> None:
7072 did_add_new_key = prompt_api_key ()
7173
7274 if should_modify_pyproject_toml ():
73-
7475 setup_info : SetupInfo = collect_setup_info ()
7576
7677 configure_pyproject_toml (setup_info )
@@ -83,7 +84,6 @@ def init_codeflash() -> None:
8384 if "setup_info" in locals ():
8485 module_string = f" you selected ({ setup_info .module_root } )"
8586
86-
8787 click .echo (
8888 f"{ LF } "
8989 f"⚡️ Codeflash is now set up! You can now run:{ LF } "
@@ -125,11 +125,13 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
125125 bubble_sort_path , bubble_sort_test_path = create_bubble_sort_file_and_test (args )
126126 run_end_to_end_test (args , bubble_sort_path , bubble_sort_test_path )
127127
128+
128129def should_modify_pyproject_toml () -> bool :
129130 """Check if the current directory contains a valid pyproject.toml file with codeflash config
130131 If it does, ask the user if they want to re-configure it.
131132 """
132133 from rich .prompt import Confirm
134+
133135 pyproject_toml_path = Path .cwd () / "pyproject.toml"
134136 if not pyproject_toml_path .exists ():
135137 return True
@@ -144,7 +146,9 @@ def should_modify_pyproject_toml() -> bool:
144146 return True
145147
146148 create_toml = Confirm .ask (
147- "✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?" , default = False , show_default = True
149+ "✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?" ,
150+ default = False ,
151+ show_default = True ,
148152 )
149153 return create_toml
150154
@@ -160,7 +164,18 @@ def collect_setup_info() -> SetupInfo:
160164 # Check for the existence of pyproject.toml or setup.py
161165 project_name = check_for_toml_or_setup_file ()
162166
163- ignore_subdirs = ["venv" , "node_modules" , "dist" , "build" , "build_temp" , "build_scripts" , "env" , "logs" , "tmp" , "__pycache__" ]
167+ ignore_subdirs = [
168+ "venv" ,
169+ "node_modules" ,
170+ "dist" ,
171+ "build" ,
172+ "build_temp" ,
173+ "build_scripts" ,
174+ "env" ,
175+ "logs" ,
176+ "tmp" ,
177+ "__pycache__" ,
178+ ]
164179 valid_subdirs = [
165180 d for d in next (os .walk ("." ))[1 ] if not d .startswith ("." ) and not d .startswith ("__" ) and d not in ignore_subdirs
166181 ]
@@ -225,7 +240,7 @@ def collect_setup_info() -> SetupInfo:
225240 else :
226241 apologize_and_exit ()
227242 else :
228- tests_root = Path (curdir ) / Path (cast (str , tests_root_answer ))
243+ tests_root = Path (curdir ) / Path (cast (" str" , tests_root_answer ))
229244 tests_root = tests_root .relative_to (curdir )
230245 ph ("cli-tests-root-provided" )
231246
@@ -262,13 +277,13 @@ def collect_setup_info() -> SetupInfo:
262277 benchmarks_options .append (create_benchmarks_option )
263278 benchmarks_options .append (custom_dir_option )
264279
265-
266280 benchmarks_answer = inquirer_wrapper (
267281 inquirer .list_input ,
268282 message = "Where are your performance benchmarks located? (benchmarks must be a sub directory of your tests root directory)" ,
269283 choices = benchmarks_options ,
270284 default = (
271- default_benchmarks_subdir if default_benchmarks_subdir in benchmarks_options else benchmarks_options [0 ]),
285+ default_benchmarks_subdir if default_benchmarks_subdir in benchmarks_options else benchmarks_options [0 ]
286+ ),
272287 )
273288
274289 if benchmarks_answer == create_benchmarks_option :
@@ -288,7 +303,7 @@ def collect_setup_info() -> SetupInfo:
288303 elif benchmarks_answer == no_benchmarks_option :
289304 benchmarks_root = None
290305 else :
291- benchmarks_root = tests_root / Path (cast (str , benchmarks_answer ))
306+ benchmarks_root = tests_root / Path (cast (" str" , benchmarks_answer ))
292307
293308 # TODO: Implement other benchmark framework options
294309 # if benchmarks_root:
@@ -304,7 +319,6 @@ def collect_setup_info() -> SetupInfo:
304319 # carousel=True,
305320 # )
306321
307-
308322 formatter = inquirer_wrapper (
309323 inquirer .list_input ,
310324 message = "Which code formatter do you use?" ,
@@ -340,10 +354,10 @@ def collect_setup_info() -> SetupInfo:
340354 return SetupInfo (
341355 module_root = str (module_root ),
342356 tests_root = str (tests_root ),
343- benchmarks_root = str (benchmarks_root ) if benchmarks_root else None ,
344- test_framework = cast (str , test_framework ),
357+ benchmarks_root = str (benchmarks_root ) if benchmarks_root else None ,
358+ test_framework = cast (" str" , test_framework ),
345359 ignore_paths = ignore_paths ,
346- formatter = cast (str , formatter ),
360+ formatter = cast (" str" , formatter ),
347361 git_remote = str (git_remote ),
348362 )
349363
@@ -453,7 +467,7 @@ def check_for_toml_or_setup_file() -> str | None:
453467 click .echo ("⏩️ Skipping pyproject.toml creation." )
454468 apologize_and_exit ()
455469 click .echo ()
456- return cast (str , project_name )
470+ return cast (" str" , project_name )
457471
458472
459473def install_github_actions (override_formatter_check : bool = False ) -> None :
@@ -499,19 +513,22 @@ def install_github_actions(override_formatter_check: bool = False) -> None:
499513 return
500514 workflows_path .mkdir (parents = True , exist_ok = True )
501515 from importlib .resources import files
516+
502517 benchmark_mode = False
503518 if "benchmarks_root" in config :
504519 benchmark_mode = inquirer_wrapper (
505520 inquirer .confirm ,
506521 message = "⚡️It looks like you've configured a benchmarks_root in your config. Would you like to run the Github action in benchmark mode? "
507- " This will show the impact of Codeflash's suggested optimizations on your benchmarks" ,
522+ " This will show the impact of Codeflash's suggested optimizations on your benchmarks" ,
508523 default = True ,
509524 )
510525
511526 optimize_yml_content = (
512527 files ("codeflash" ).joinpath ("cli_cmds" , "workflows" , "codeflash-optimize.yaml" ).read_text (encoding = "utf-8" )
513528 )
514- materialized_optimize_yml_content = customize_codeflash_yaml_content (optimize_yml_content , config , git_root , benchmark_mode )
529+ materialized_optimize_yml_content = customize_codeflash_yaml_content (
530+ optimize_yml_content , config , git_root , benchmark_mode
531+ )
515532 with optimize_yaml_path .open ("w" , encoding = "utf8" ) as optimize_yml_file :
516533 optimize_yml_file .write (materialized_optimize_yml_content )
517534 click .echo (f"{ LF } ✅ Created GitHub action workflow at { optimize_yaml_path } { LF } " )
@@ -941,12 +958,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
941958
942959def ask_for_telemetry () -> bool :
943960 """Prompt the user to enable or disable telemetry."""
944- from rich .prompt import Confirm
945-
946- enable_telemetry = Confirm .ask (
961+ return Confirm .ask (
947962 "⚡️ Would you like to enable telemetry to help us improve the Codeflash experience?" ,
948963 default = True ,
949964 show_default = True ,
950965 )
951-
952- return enable_telemetry
0 commit comments