Skip to content

Commit 0a296e2

Browse files
committed
more unit tests
1 parent 8039ec4 commit 0a296e2

File tree

3 files changed

+283
-1
lines changed

3 files changed

+283
-1
lines changed

codeflash/code_utils/code_replacer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
265265
self.new_functions.append(node)
266266
return False
267267

268+
268269
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
269270
if self.current_class:
270271
return False # If already in a class, do not recurse deeper
@@ -315,6 +316,7 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
315316

316317
return updated_node
317318

319+
318320
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
319321
if self.current_class:
320322
return False # If already in a class, do not recurse deeper

tests/test_code_context_extractor.py

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from codeflash.models.models import FunctionParent
1313
from codeflash.optimization.optimizer import Optimizer
1414
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
15-
from codeflash.code_utils.code_extractor import add_global_assignments
15+
from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector
1616

1717

1818
class HelperClass:
@@ -2482,3 +2482,148 @@ def test_circular_deps():
24822482
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
24832483

24842484
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
2485+
def test_global_assignment_collector_with_async_function():
2486+
"""Test GlobalAssignmentCollector correctly identifies global assignments outside async functions."""
2487+
import libcst as cst
2488+
2489+
source_code = """
2490+
# Global assignment
2491+
GLOBAL_VAR = "global_value"
2492+
OTHER_GLOBAL = 42
2493+
2494+
async def async_function():
2495+
# This should not be collected (inside async function)
2496+
local_var = "local_value"
2497+
INNER_ASSIGNMENT = "should_not_be_global"
2498+
return local_var
2499+
2500+
# Another global assignment
2501+
ANOTHER_GLOBAL = "another_global"
2502+
"""
2503+
2504+
tree = cst.parse_module(source_code)
2505+
collector = GlobalAssignmentCollector()
2506+
tree.visit(collector)
2507+
2508+
# Should collect global assignments but not the ones inside async function
2509+
assert len(collector.assignments) == 3
2510+
assert "GLOBAL_VAR" in collector.assignments
2511+
assert "OTHER_GLOBAL" in collector.assignments
2512+
assert "ANOTHER_GLOBAL" in collector.assignments
2513+
2514+
# Should not collect assignments from inside async function
2515+
assert "local_var" not in collector.assignments
2516+
assert "INNER_ASSIGNMENT" not in collector.assignments
2517+
2518+
# Verify assignment order
2519+
expected_order = ["GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"]
2520+
assert collector.assignment_order == expected_order
2521+
2522+
2523+
def test_global_assignment_collector_nested_async_functions():
2524+
"""Test GlobalAssignmentCollector handles nested async functions correctly."""
2525+
import libcst as cst
2526+
2527+
source_code = """
2528+
# Global assignment
2529+
CONFIG = {"key": "value"}
2530+
2531+
def sync_function():
2532+
# Inside sync function - should not be collected
2533+
sync_local = "sync"
2534+
2535+
async def nested_async():
2536+
# Inside nested async function - should not be collected
2537+
nested_var = "nested"
2538+
return nested_var
2539+
2540+
return sync_local
2541+
2542+
async def async_function():
2543+
# Inside async function - should not be collected
2544+
async_local = "async"
2545+
2546+
def nested_sync():
2547+
# Inside nested function - should not be collected
2548+
deeply_nested = "deep"
2549+
return deeply_nested
2550+
2551+
return async_local
2552+
2553+
# Another global assignment
2554+
FINAL_GLOBAL = "final"
2555+
"""
2556+
2557+
tree = cst.parse_module(source_code)
2558+
collector = GlobalAssignmentCollector()
2559+
tree.visit(collector)
2560+
2561+
# Should only collect global-level assignments
2562+
assert len(collector.assignments) == 2
2563+
assert "CONFIG" in collector.assignments
2564+
assert "FINAL_GLOBAL" in collector.assignments
2565+
2566+
# Should not collect any assignments from inside functions
2567+
assert "sync_local" not in collector.assignments
2568+
assert "nested_var" not in collector.assignments
2569+
assert "async_local" not in collector.assignments
2570+
assert "deeply_nested" not in collector.assignments
2571+
2572+
2573+
def test_global_assignment_collector_mixed_async_sync_with_classes():
2574+
"""Test GlobalAssignmentCollector with async functions, sync functions, and classes."""
2575+
import libcst as cst
2576+
2577+
source_code = """
2578+
# Global assignments
2579+
GLOBAL_CONSTANT = "constant"
2580+
2581+
class TestClass:
2582+
# Class-level assignment - should not be collected
2583+
class_var = "class_value"
2584+
2585+
def sync_method(self):
2586+
# Method assignment - should not be collected
2587+
method_var = "method"
2588+
return method_var
2589+
2590+
async def async_method(self):
2591+
# Async method assignment - should not be collected
2592+
async_method_var = "async_method"
2593+
return async_method_var
2594+
2595+
def sync_function():
2596+
# Function assignment - should not be collected
2597+
func_var = "function"
2598+
return func_var
2599+
2600+
async def async_function():
2601+
# Async function assignment - should not be collected
2602+
async_func_var = "async_function"
2603+
return async_func_var
2604+
2605+
# More global assignments
2606+
ANOTHER_CONSTANT = 100
2607+
FINAL_ASSIGNMENT = {"data": "value"}
2608+
"""
2609+
2610+
tree = cst.parse_module(source_code)
2611+
collector = GlobalAssignmentCollector()
2612+
tree.visit(collector)
2613+
2614+
# Should only collect global-level assignments
2615+
assert len(collector.assignments) == 3
2616+
assert "GLOBAL_CONSTANT" in collector.assignments
2617+
assert "ANOTHER_CONSTANT" in collector.assignments
2618+
assert "FINAL_ASSIGNMENT" in collector.assignments
2619+
2620+
# Should not collect assignments from inside any scoped blocks
2621+
assert "class_var" not in collector.assignments
2622+
assert "method_var" not in collector.assignments
2623+
assert "async_method_var" not in collector.assignments
2624+
assert "func_var" not in collector.assignments
2625+
assert "async_func_var" not in collector.assignments
2626+
2627+
# Verify correct order
2628+
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
2629+
assert collector.assignment_order == expected_order

tests/test_code_replacement.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
is_zero_diff,
1313
replace_functions_and_add_imports,
1414
replace_functions_in_file,
15+
OptimFunctionCollector,
1516
)
1617
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1718
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
@@ -3453,3 +3454,137 @@ def hydrate_input_text_actions_with_field_names(
34533454
main_file.unlink(missing_ok=True)
34543455

34553456
assert new_code == expected
3457+
3458+
3459+
# OptimFunctionCollector async function tests
3460+
def test_optim_function_collector_with_async_functions():
3461+
"""Test OptimFunctionCollector correctly collects async functions."""
3462+
import libcst as cst
3463+
3464+
source_code = """
3465+
def sync_function():
3466+
return "sync"
3467+
3468+
async def async_function():
3469+
return "async"
3470+
3471+
class TestClass:
3472+
def sync_method(self):
3473+
return "sync_method"
3474+
3475+
async def async_method(self):
3476+
return "async_method"
3477+
"""
3478+
3479+
tree = cst.parse_module(source_code)
3480+
collector = OptimFunctionCollector(
3481+
function_names={(None, "sync_function"), (None, "async_function"), ("TestClass", "sync_method"), ("TestClass", "async_method")},
3482+
preexisting_objects=None
3483+
)
3484+
tree.visit(collector)
3485+
3486+
# Should collect both sync and async functions
3487+
assert len(collector.modified_functions) == 4
3488+
assert (None, "sync_function") in collector.modified_functions
3489+
assert (None, "async_function") in collector.modified_functions
3490+
assert ("TestClass", "sync_method") in collector.modified_functions
3491+
assert ("TestClass", "async_method") in collector.modified_functions
3492+
3493+
3494+
def test_optim_function_collector_new_async_functions():
3495+
"""Test OptimFunctionCollector identifies new async functions not in preexisting objects."""
3496+
import libcst as cst
3497+
3498+
source_code = """
3499+
def existing_function():
3500+
return "existing"
3501+
3502+
async def new_async_function():
3503+
return "new_async"
3504+
3505+
def new_sync_function():
3506+
return "new_sync"
3507+
3508+
class ExistingClass:
3509+
async def new_class_async_method(self):
3510+
return "new_class_async"
3511+
"""
3512+
3513+
# Only existing_function is in preexisting objects
3514+
preexisting_objects = {("existing_function", ())}
3515+
3516+
tree = cst.parse_module(source_code)
3517+
collector = OptimFunctionCollector(
3518+
function_names=set(), # Not looking for specific functions
3519+
preexisting_objects=preexisting_objects
3520+
)
3521+
tree.visit(collector)
3522+
3523+
# Should identify new functions (both sync and async)
3524+
assert len(collector.new_functions) == 2
3525+
function_names = [func.name.value for func in collector.new_functions]
3526+
assert "new_async_function" in function_names
3527+
assert "new_sync_function" in function_names
3528+
3529+
# Should identify new class methods
3530+
assert "ExistingClass" in collector.new_class_functions
3531+
assert len(collector.new_class_functions["ExistingClass"]) == 1
3532+
assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"
3533+
3534+
3535+
def test_optim_function_collector_mixed_scenarios():
3536+
"""Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
3537+
import libcst as cst
3538+
3539+
source_code = """
3540+
# Global functions
3541+
def global_sync():
3542+
pass
3543+
3544+
async def global_async():
3545+
pass
3546+
3547+
class ParentClass:
3548+
def __init__(self):
3549+
pass
3550+
3551+
def sync_method(self):
3552+
pass
3553+
3554+
async def async_method(self):
3555+
pass
3556+
3557+
class ChildClass:
3558+
async def child_async_method(self):
3559+
pass
3560+
3561+
def child_sync_method(self):
3562+
pass
3563+
"""
3564+
3565+
# Looking for specific functions
3566+
function_names = {
3567+
(None, "global_sync"),
3568+
(None, "global_async"),
3569+
("ParentClass", "sync_method"),
3570+
("ParentClass", "async_method"),
3571+
("ChildClass", "child_async_method")
3572+
}
3573+
3574+
tree = cst.parse_module(source_code)
3575+
collector = OptimFunctionCollector(
3576+
function_names=function_names,
3577+
preexisting_objects=None
3578+
)
3579+
tree.visit(collector)
3580+
3581+
# Should collect all specified functions (mix of sync and async)
3582+
assert len(collector.modified_functions) == 5
3583+
assert (None, "global_sync") in collector.modified_functions
3584+
assert (None, "global_async") in collector.modified_functions
3585+
assert ("ParentClass", "sync_method") in collector.modified_functions
3586+
assert ("ParentClass", "async_method") in collector.modified_functions
3587+
assert ("ChildClass", "child_async_method") in collector.modified_functions
3588+
3589+
# Should collect __init__ method
3590+
assert "ParentClass" in collector.modified_init_functions

0 commit comments

Comments
 (0)