|
12 | 12 | from codeflash.models.models import FunctionParent |
13 | 13 | from codeflash.optimization.optimizer import Optimizer |
14 | 14 | from codeflash.code_utils.code_replacer import replace_functions_and_add_imports |
15 | | -from codeflash.code_utils.code_extractor import add_global_assignments |
| 15 | +from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector |
16 | 16 |
|
17 | 17 |
|
18 | 18 | class HelperClass: |
@@ -2482,3 +2482,148 @@ def test_circular_deps(): |
2482 | 2482 | assert "import ApiClient" not in new_code, "Error: Circular dependency found" |
2483 | 2483 |
|
2484 | 2484 | assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist" |
| 2485 | +def test_global_assignment_collector_with_async_function(): |
| 2486 | + """Test GlobalAssignmentCollector correctly identifies global assignments outside async functions.""" |
| 2487 | + import libcst as cst |
| 2488 | + |
| 2489 | + source_code = """ |
| 2490 | +# Global assignment |
| 2491 | +GLOBAL_VAR = "global_value" |
| 2492 | +OTHER_GLOBAL = 42 |
| 2493 | +
|
| 2494 | +async def async_function(): |
| 2495 | + # This should not be collected (inside async function) |
| 2496 | + local_var = "local_value" |
| 2497 | + INNER_ASSIGNMENT = "should_not_be_global" |
| 2498 | + return local_var |
| 2499 | +
|
| 2500 | +# Another global assignment |
| 2501 | +ANOTHER_GLOBAL = "another_global" |
| 2502 | +""" |
| 2503 | + |
| 2504 | + tree = cst.parse_module(source_code) |
| 2505 | + collector = GlobalAssignmentCollector() |
| 2506 | + tree.visit(collector) |
| 2507 | + |
| 2508 | + # Should collect global assignments but not the ones inside async function |
| 2509 | + assert len(collector.assignments) == 3 |
| 2510 | + assert "GLOBAL_VAR" in collector.assignments |
| 2511 | + assert "OTHER_GLOBAL" in collector.assignments |
| 2512 | + assert "ANOTHER_GLOBAL" in collector.assignments |
| 2513 | + |
| 2514 | + # Should not collect assignments from inside async function |
| 2515 | + assert "local_var" not in collector.assignments |
| 2516 | + assert "INNER_ASSIGNMENT" not in collector.assignments |
| 2517 | + |
| 2518 | + # Verify assignment order |
| 2519 | + expected_order = ["GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"] |
| 2520 | + assert collector.assignment_order == expected_order |
| 2521 | + |
| 2522 | + |
| 2523 | +def test_global_assignment_collector_nested_async_functions(): |
| 2524 | + """Test GlobalAssignmentCollector handles nested async functions correctly.""" |
| 2525 | + import libcst as cst |
| 2526 | + |
| 2527 | + source_code = """ |
| 2528 | +# Global assignment |
| 2529 | +CONFIG = {"key": "value"} |
| 2530 | +
|
| 2531 | +def sync_function(): |
| 2532 | + # Inside sync function - should not be collected |
| 2533 | + sync_local = "sync" |
| 2534 | + |
| 2535 | + async def nested_async(): |
| 2536 | + # Inside nested async function - should not be collected |
| 2537 | + nested_var = "nested" |
| 2538 | + return nested_var |
| 2539 | + |
| 2540 | + return sync_local |
| 2541 | +
|
| 2542 | +async def async_function(): |
| 2543 | + # Inside async function - should not be collected |
| 2544 | + async_local = "async" |
| 2545 | + |
| 2546 | + def nested_sync(): |
| 2547 | + # Inside nested function - should not be collected |
| 2548 | + deeply_nested = "deep" |
| 2549 | + return deeply_nested |
| 2550 | + |
| 2551 | + return async_local |
| 2552 | +
|
| 2553 | +# Another global assignment |
| 2554 | +FINAL_GLOBAL = "final" |
| 2555 | +""" |
| 2556 | + |
| 2557 | + tree = cst.parse_module(source_code) |
| 2558 | + collector = GlobalAssignmentCollector() |
| 2559 | + tree.visit(collector) |
| 2560 | + |
| 2561 | + # Should only collect global-level assignments |
| 2562 | + assert len(collector.assignments) == 2 |
| 2563 | + assert "CONFIG" in collector.assignments |
| 2564 | + assert "FINAL_GLOBAL" in collector.assignments |
| 2565 | + |
| 2566 | + # Should not collect any assignments from inside functions |
| 2567 | + assert "sync_local" not in collector.assignments |
| 2568 | + assert "nested_var" not in collector.assignments |
| 2569 | + assert "async_local" not in collector.assignments |
| 2570 | + assert "deeply_nested" not in collector.assignments |
| 2571 | + |
| 2572 | + |
| 2573 | +def test_global_assignment_collector_mixed_async_sync_with_classes(): |
| 2574 | + """Test GlobalAssignmentCollector with async functions, sync functions, and classes.""" |
| 2575 | + import libcst as cst |
| 2576 | + |
| 2577 | + source_code = """ |
| 2578 | +# Global assignments |
| 2579 | +GLOBAL_CONSTANT = "constant" |
| 2580 | +
|
| 2581 | +class TestClass: |
| 2582 | + # Class-level assignment - should not be collected |
| 2583 | + class_var = "class_value" |
| 2584 | + |
| 2585 | + def sync_method(self): |
| 2586 | + # Method assignment - should not be collected |
| 2587 | + method_var = "method" |
| 2588 | + return method_var |
| 2589 | + |
| 2590 | + async def async_method(self): |
| 2591 | + # Async method assignment - should not be collected |
| 2592 | + async_method_var = "async_method" |
| 2593 | + return async_method_var |
| 2594 | +
|
| 2595 | +def sync_function(): |
| 2596 | + # Function assignment - should not be collected |
| 2597 | + func_var = "function" |
| 2598 | + return func_var |
| 2599 | +
|
| 2600 | +async def async_function(): |
| 2601 | + # Async function assignment - should not be collected |
| 2602 | + async_func_var = "async_function" |
| 2603 | + return async_func_var |
| 2604 | +
|
| 2605 | +# More global assignments |
| 2606 | +ANOTHER_CONSTANT = 100 |
| 2607 | +FINAL_ASSIGNMENT = {"data": "value"} |
| 2608 | +""" |
| 2609 | + |
| 2610 | + tree = cst.parse_module(source_code) |
| 2611 | + collector = GlobalAssignmentCollector() |
| 2612 | + tree.visit(collector) |
| 2613 | + |
| 2614 | + # Should only collect global-level assignments |
| 2615 | + assert len(collector.assignments) == 3 |
| 2616 | + assert "GLOBAL_CONSTANT" in collector.assignments |
| 2617 | + assert "ANOTHER_CONSTANT" in collector.assignments |
| 2618 | + assert "FINAL_ASSIGNMENT" in collector.assignments |
| 2619 | + |
| 2620 | + # Should not collect assignments from inside any scoped blocks |
| 2621 | + assert "class_var" not in collector.assignments |
| 2622 | + assert "method_var" not in collector.assignments |
| 2623 | + assert "async_method_var" not in collector.assignments |
| 2624 | + assert "func_var" not in collector.assignments |
| 2625 | + assert "async_func_var" not in collector.assignments |
| 2626 | + |
| 2627 | + # Verify correct order |
| 2628 | + expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"] |
| 2629 | + assert collector.assignment_order == expected_order |
0 commit comments