Skip to content

Commit 62bea9e

Browse files
committed
gate async behind --async
1 parent 87ddb98 commit 62bea9e

File tree

4 files changed

+69
-1
lines changed

4 files changed

+69
-1
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ def parse_args() -> Namespace:
9696
)
9797
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
9898
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
99+
parser.add_argument(
100+
"--async", default=False, action="store_true", help="Enable optimization of async functions. By default, async functions are excluded from optimization."
101+
)
99102

100103
args, unknown_args = parser.parse_known_args()
101104
sys.argv[:] = [sys.argv[0], *unknown_args]

codeflash/discovery/functions_to_optimize.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def get_functions_to_optimize(
179179
project_root: Path,
180180
module_root: Path,
181181
previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
182+
enable_async: bool = False,
182183
) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
183184
assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
184185
"Only one of optimize_all, replay_test, or file should be provided"
@@ -232,7 +233,7 @@ def get_functions_to_optimize(
232233
ph("cli-optimizing-git-diff")
233234
functions = get_functions_within_git_diff(uncommitted_changes=False)
234235
filtered_modified_functions, functions_count = filter_functions(
235-
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
236+
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions, enable_async=enable_async
236237
)
237238

238239
logger.info(f"!lsp|Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
@@ -568,6 +569,7 @@ def filter_functions(
568569
module_root: Path,
569570
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
570571
disable_logs: bool = False, # noqa: FBT001, FBT002
572+
enable_async: bool = False,
571573
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
572574
filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
573575
blocklist_funcs = get_blocklisted_functions()
@@ -587,6 +589,7 @@ def filter_functions(
587589
submodule_ignored_paths_count: int = 0
588590
blocklist_funcs_removed_count: int = 0
589591
previous_checkpoint_functions_removed_count: int = 0
592+
async_functions_removed_count: int = 0
590593
tests_root_str = str(tests_root)
591594
module_root_str = str(module_root)
592595

@@ -642,6 +645,15 @@ def filter_functions(
642645
functions_tmp.append(function)
643646
_functions = functions_tmp
644647

648+
if not enable_async:
649+
functions_tmp = []
650+
for function in _functions:
651+
if function.is_async:
652+
async_functions_removed_count += 1
653+
continue
654+
functions_tmp.append(function)
655+
_functions = functions_tmp
656+
645657
filtered_modified_functions[file_path] = _functions
646658
functions_count += len(_functions)
647659

@@ -655,6 +667,7 @@ def filter_functions(
655667
"Files from ignored submodules": (submodule_ignored_paths_count, "bright_black"),
656668
"Blocklisted functions removed": (blocklist_funcs_removed_count, "bright_red"),
657669
"Functions skipped from checkpoint": (previous_checkpoint_functions_removed_count, "green"),
670+
"Async functions removed": (async_functions_removed_count, "bright_magenta"),
658671
}
659672
tree = Tree(Text("Ignored functions and files", style="bold"))
660673
for label, (count, color) in log_info.items():

codeflash/optimization/optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]
134134
project_root=self.args.project_root,
135135
module_root=self.args.module_root,
136136
previous_checkpoint_functions=self.args.previous_checkpoint_functions,
137+
enable_async=getattr(self.args, "async", False),
137138
)
138139

139140
def create_function_optimizer(

tests/test_async_function_discovery.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def sync_method(self):
236236
ignore_paths=[],
237237
project_root=file_path.parent,
238238
module_root=file_path.parent,
239+
enable_async=True,
239240
)
240241

241242
assert functions_count == 4
@@ -249,6 +250,56 @@ def sync_method(self):
249250
assert "async_func_two" not in function_names
250251

251252

253+
def test_no_async_functions_finding(temp_dir):
254+
mixed_code = """
255+
async def async_func_one():
256+
return await operation_one()
257+
258+
def sync_func_one():
259+
return operation_one()
260+
261+
async def async_func_two():
262+
print("no return")
263+
264+
class MixedClass:
265+
async def async_method(self):
266+
return await self.operation()
267+
268+
def sync_method(self):
269+
return self.operation()
270+
"""
271+
272+
file_path = temp_dir / "test_file.py"
273+
file_path.write_text(mixed_code)
274+
275+
test_config = TestConfig(
276+
tests_root="tests",
277+
project_root_path=".",
278+
test_framework="pytest",
279+
tests_project_rootdir=Path()
280+
)
281+
282+
functions, functions_count, _ = get_functions_to_optimize(
283+
optimize_all=None,
284+
replay_test=None,
285+
file=file_path,
286+
only_get_this_function=None,
287+
test_cfg=test_config,
288+
ignore_paths=[],
289+
project_root=file_path.parent,
290+
module_root=file_path.parent,
291+
enable_async=False,
292+
)
293+
294+
assert functions_count == 2
295+
296+
function_names = [fn.function_name for fn in functions[file_path]]
297+
assert "sync_func_one" in function_names
298+
assert "sync_method" in function_names
299+
assert "async_func_one" not in function_names
300+
assert "async_method" not in function_names
301+
302+
252303
def test_async_function_parents(temp_dir):
253304
complex_structure = """
254305
class OuterClass:

0 commit comments

Comments
 (0)