@KRRT7 KRRT7 commented Sep 26, 2025

User description

less messy


PR Type

Enhancement, Tests


Description

  • Add comprehensive async test suite

  • Enable async optimization pipeline end-to-end

  • Throughput-aware performance selection logic

  • Fix coverage utils empty-data handling


Diagram Walkthrough

flowchart LR
  instr["Async instrumentation decorators"] -- applied in baseline/candidate --> opt["FunctionOptimizer async paths"]
  opt -- run tests/benchmarks --> parse["Parse test results + throughput"]
  parse -- metrics --> critic["Critic: speedup + throughput gain"]
  critic -- selects best --> explain["Explanation with async throughput"]

File Walkthrough

Relevant files
Tests
3 files
test_async_run_and_parse_tests.py
End-to-end async run/parse tests and behaviors                     
+1039/-0
test_instrument_async_tests.py
Async decorator injection and profiling tests                       
+793/-0 
test_unused_helper_revert.py
Async helpers detection and revert scenarios                         
+556/-96
Enhancement
2 files
function_optimizer.py
Async-aware pipeline, benchmarking, throughput, explanations
+173/-29
critic.py
Add throughput gain and async-aware speedup logic               
+53/-10 
Bug fix
1 file
coverage_utils.py
Fix empty coverage return type and checks                               
+3/-3     
Additional files
29 files
e2e-async.yaml +69/-0   
async_bubble_sort.py +43/-0   
main.py +16/-0   
pyproject.toml +6/-0     
__init__.py [link]   
aiservice.py +13/-0   
code_utils.py +1/-9     
codeflash_wrap_decorator.py +167/-0 
config_consts.py +1/-0     
coverage_utils.py +3/-1     
edit_generated_tests.py +15/-5   
instrument_existing_tests.py +370/-2 
static_analysis.py +9/-3     
code_context_extractor.py +6/-2     
functions_to_optimize.py +35/-3   
models.py +4/-0     
optimizer.py +1/-1     
explanation.py +46/-4   
parse_test_output.py +24/-0   
pytest_plugin.py +23/-0   
pyproject.toml +1/-0     
end_to_end_test_async.py +27/-0   
test_add_runtime_comments.py +207/-1 
test_async_function_discovery.py +286/-0 
test_async_wrapper_sqlite_validation.py +285/-0 
test_code_context_extractor.py +155/-5 
test_code_replacement.py +162/-144
test_code_utils.py +115/-19
test_critic.py +163/-1 

- Add is_async parameter to AiServiceClient.optimize_python_code method
- Pass is_async flag to both optimization and test generation endpoints
- Update optimization pipeline to pass is_async from FunctionToOptimize
- Remove async function blocking check in function optimizer
- Remove unused has_any_async_functions import

This enables async functions to be properly processed by the optimization
pipeline and sent to the AI service with the correct async context.
- Added comprehensive async test instrumentation (AsyncCallInstrumenter class)
- Implemented async decorator functions (add_async_decorator_to_function, instrument_source_module_with_async_decorators)
- Added async wrapper decorators (codeflash_behavior_async, codeflash_performance_async)
- Updated edit_generated_tests.py to handle AsyncFunctionDef nodes in test parsing
- Updated coverage_utils.py to include async functions in coverage analysis
- Add async throughput fields to Explanation dataclass
- Implement throughput-based performance improvement calculation
- Add MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD configuration constant
- Update explanation logic to prefer throughput metrics for async functions
- Restore LSP compatibility with conditional test result display
In my case, I've started using symlinks more often, and the current implementation causes issues; we need to resolve the symlinked path too.

github-actions bot commented Sep 26, 2025

PR Reviewer Guide 🔍

(Review updated until commit 6795a5c)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 5 🔵🔵🔵🔵🔵
🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Possible Issue

The change to use Path.resolve() before relative_to() in module_name_from_file_path can alter behavior when symlinks or different drive letters are involved, and may break expectations for paths within a non-resolved project root. Verify this does not mis-compute module names in workspace layouts or Windows drive mismatch scenarios.

def module_name_from_file_path(file_path: Path, project_root_path: Path, *, traverse_up: bool = False) -> str:
    try:
        relative_path = file_path.resolve().relative_to(project_root_path.resolve())
        return relative_path.with_suffix("").as_posix().replace("/", ".")
    except ValueError:
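
To make the symlink concern concrete, here is a minimal sketch of how resolving both paths changes the outcome. Only the two-line body of `module_name` mirrors the snippet quoted above; the helper name, the temporary directory layout, and `demo` are assumptions made for illustration. Creating symlinks may require extra privileges on Windows.

```python
import os
import tempfile
from pathlib import Path


def module_name(file_path: Path, project_root: Path) -> str:
    # Same two steps as the quoted snippet: resolve both sides, then relativize.
    relative_path = file_path.resolve().relative_to(project_root.resolve())
    return relative_path.with_suffix("").as_posix().replace("/", ".")


def demo() -> tuple[str, bool]:
    """Return (module name via resolve(), whether unresolved relative_to worked)."""
    with tempfile.TemporaryDirectory() as tmp:
        real_root = Path(tmp) / "real_project"
        (real_root / "pkg").mkdir(parents=True)
        target = real_root / "pkg" / "mod.py"
        target.write_text("x = 1\n", encoding="utf8")

        # Hypothetical layout: the project root is reached through a symlink.
        link_root = Path(tmp) / "linked_project"
        os.symlink(real_root, link_root, target_is_directory=True)

        try:
            target.relative_to(link_root)  # purely lexical comparison
            unresolved_ok = True
        except ValueError:
            unresolved_ok = False

        return module_name(target, link_root), unresolved_ok
```

In this layout the lexical `relative_to` fails, while the resolved version yields `pkg.mod`. The reverse risk raised by the reviewer (a resolved path escaping an intentionally unresolved root) is not covered by this sketch.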
Behavior Change

The previous hard block on async functions was removed; async instrumentation is now auto-applied in multiple places (baseline and candidates). Confirm that source rewriting and restoration are always exception-safe and atomic across all try/finally blocks, especially when coverage/benchmarking toggles are used, to avoid leaving instrumented code in the repo.

test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)

if self.function_to_optimize.is_async:
    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    success, instrumented_source = instrument_source_module_with_async_decorators(
        self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
    )
    if success and instrumented_source:
        with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
            f.write(instrumented_source)
        logger.debug(f"Applied async instrumentation to {self.function_to_optimize.file_path}")

# Instrument codeflash capture
with progress_bar("Running tests to establish original code behavior..."):
    try:
        instrument_codeflash_capture(
            self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
        )
        total_looping_time = TOTAL_LOOPING_TIME_EFFECTIVE
        behavioral_results, coverage_results = self.run_and_parse_tests(
            testing_type=TestingMode.BEHAVIOR,
            test_env=test_env,
            test_files=self.test_files,
            optimization_iteration=0,
            testing_time=total_looping_time,
            enable_coverage=test_framework == "pytest",
            code_context=code_context,
        )
    finally:
        # Remove codeflash capture
        self.write_code_and_helpers(
            self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
        )
if not behavioral_results:
    logger.warning(
        f"force_lsp|Couldn't run any tests for original function {self.function_to_optimize.function_name}. SKIPPING OPTIMIZING THIS FUNCTION."
    )
    console.rule()
    return Failure("Failed to establish a baseline for the original code - behavioral tests failed.")
if not coverage_critic(coverage_results, self.args.test_framework):
    return Failure(
        f"Test coverage is {coverage_results.coverage}%, which is below the required threshold of {COVERAGE_THRESHOLD}%."
    )

if test_framework == "pytest":
    with progress_bar("Running line profiling to identify performance bottlenecks..."):
        line_profile_results = self.line_profiler_step(
            code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
        )
    console.rule()
    with progress_bar("Running performance benchmarks..."):
        if self.function_to_optimize.is_async:
            from codeflash.code_utils.instrument_existing_tests import (
                instrument_source_module_with_async_decorators,
            )

            success, instrumented_source = instrument_source_module_with_async_decorators(
                self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
            )
            if success and instrumented_source:
                with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
                    f.write(instrumented_source)
                logger.debug(
                    f"Applied async performance instrumentation to {self.function_to_optimize.file_path}"
                )

        try:
            benchmarking_results, _ = self.run_and_parse_tests(
                testing_type=TestingMode.PERFORMANCE,
                test_env=test_env,
                test_files=self.test_files,
                optimization_iteration=0,
                testing_time=total_looping_time,
                enable_coverage=False,
                code_context=code_context,
            )
        finally:
            if self.function_to_optimize.is_async:
                self.write_code_and_helpers(
                    self.function_to_optimize_source_code,
                    original_helper_code,
                    self.function_to_optimize.file_path,
                )
else:
Result Parsing

run_and_parse_tests now sets results.perf_stdout only for PERFORMANCE and returns early; ensure all callers expecting perf_stdout or coverage_results still receive them consistently and that line-profiler branch still populates coverage where needed.

        coverage_database_file=coverage_database_file,
        coverage_config_file=coverage_config_file,
    )
    if testing_type == TestingMode.PERFORMANCE:
        results.perf_stdout = run_result.stdout
    return results, coverage_results
results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
return results, coverage_results


github-actions bot commented Sep 26, 2025

PR Code Suggestions ✨

Latest suggestions up to 6795a5c
Explore these optional code suggestions:

Category    Suggestion    Impact
Possible issue
Restore source after async instrumentation

Instrumenting the source file in-place to add async decorators permanently mutates
the codebase before establishing the baseline. Wrap this mutation in a try/finally
and restore the original source immediately after the run to avoid leaking
instrumentation across subsequent steps. Persist the original code text beforehand
and write it back in the finally block.

codeflash/optimization/function_optimizer.py [1528-1541]

+original_source = None
 if self.function_to_optimize.is_async:
     from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
-
     success, instrumented_source = instrument_source_module_with_async_decorators(
         self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
     )
     if success and instrumented_source:
+        original_source = self.function_to_optimize.file_path.read_text(encoding="utf8")
         with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
             f.write(instrumented_source)
         logger.debug(f"Applied async instrumentation to {self.function_to_optimize.file_path}")
+try:
+    # ... run behavioral tests as before ...
+    pass
+finally:
+    if original_source is not None:
+        with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
+            f.write(original_source)
Suggestion importance[1-10]: 7

__

Why: The PR adds in-place async instrumentation during baseline establishment without immediate restoration; wrapping with try/finally to restore the original file prevents leaking decorators. This is accurate and reduces side effects, though the PR later restores around benchmarking; adding it here still improves safety.

Medium
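
The snapshot-and-restore pattern described in this suggestion could be factored into a small context manager. This is a sketch under assumptions, not code from the PR: `temporarily_rewritten` and `demo` are hypothetical names, and the restore is exception-safe but not atomic (a process killed mid-write is not handled).

```python
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator


@contextmanager
def temporarily_rewritten(path: Path, new_source: str) -> Iterator[None]:
    """Write new_source to path, then always restore the original text."""
    original = path.read_text(encoding="utf8")  # snapshot before mutating
    try:
        path.write_text(new_source, encoding="utf8")
        yield
    finally:
        # Runs on normal exit and when the body raises.
        path.write_text(original, encoding="utf8")


def demo() -> tuple[str, str]:
    """Return (text seen during the run, text after a failing run)."""
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "module.py"
        target.write_text("async def f():\n    return 1\n", encoding="utf8")
        seen = ""
        try:
            with temporarily_rewritten(target, "# instrumented\n"):
                seen = target.read_text(encoding="utf8")
                raise RuntimeError("simulated test-run failure")
        except RuntimeError:
            pass
        return seen, target.read_text(encoding="utf8")
```

Keeping the snapshot local to the context manager means only the file that was actually mutated is written back, which is also the point of the next suggestion.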
Limit restoration to mutated file

Restoring helper modules on every candidate benchmarking run may clobber concurrent
state and is broader than needed. Save and restore only the target file you mutated,
using a local copy of its pre-instrumented contents in a try/finally, to avoid
unintended helper rewrites and race conditions.

codeflash/optimization/function_optimizer.py [1765-1799]

+original_source = None
 if test_framework == "pytest":
-    # For async functions, instrument at definition site for performance benchmarking
     if self.function_to_optimize.is_async:
-        from codeflash.code_utils.instrument_existing_tests import (
-            instrument_source_module_with_async_decorators,
-        )
-
+        from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
         success, instrumented_source = instrument_source_module_with_async_decorators(
             self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
         )
         if success and instrumented_source:
+            original_source = self.function_to_optimize.file_path.read_text(encoding="utf8")
             with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
                 f.write(instrumented_source)
             logger.debug(
                 f"Applied async performance instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
             )
-
     try:
         candidate_benchmarking_results, _ = self.run_and_parse_tests(
             testing_type=TestingMode.PERFORMANCE,
             test_env=test_env,
             test_files=self.test_files,
             optimization_iteration=optimization_candidate_index,
             testing_time=total_looping_time,
             enable_coverage=False,
         )
     finally:
-        # Restore original source if we instrumented it
-        if self.function_to_optimize.is_async:
-            self.write_code_and_helpers(
-                candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
-            )
+        if original_source is not None:
+            with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
+                f.write(original_source)
Suggestion importance[1-10]: 6

__

Why: The current code restores all helpers via write_code_and_helpers after candidate perf runs; restoring only the mutated source via local snapshot narrows blast radius and avoids unnecessary writes. It's a reasonable maintainability improvement; impact is moderate.

Low
Avoid KeyError on env lookups

Accessing environment variables with direct indexing will raise KeyError and crash
tests if the plugin hasn't set them yet. Fallback to safe defaults and validate with
a clear error only when truly unavailable. This prevents unexpected crashes in
environments that don't use the plugin.

codeflash/code_utils/codeflash_wrap_decorator.py [33-43]

 def extract_test_context_from_env() -> tuple[str, str | None, str]:
-    test_module = os.environ["CODEFLASH_TEST_MODULE"]
+    test_module = os.environ.get("CODEFLASH_TEST_MODULE")
     test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
-    test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
+    test_function = os.environ.get("CODEFLASH_TEST_FUNCTION")
 
     if test_module and test_function:
         return (test_module, test_class if test_class else None, test_function)
 
     raise RuntimeError(
-        "Test context environment variables not set - ensure tests are run through codeflash test runner"
+        "Test context environment variables not set - ensure tests are run through the Codeflash test runner or pytest plugin"
     )
Suggestion importance[1-10]: 6

__

Why: The change replaces direct env indexing with .get to avoid KeyError and improves the error message; it aligns with the new pytest plugin that sets these vars but adds robustness. It's a minor reliability improvement, not critical.

Low
Correct test parsing branches

The nested performance check is redundant and the non-performance branch parses a
line profiler file even for behavioral runs. Split branches clearly: PERFORMANCE
uses parse_test_results; LINE_PROFILE uses parse_line_profile_results; BEHAVIOR uses
parse_test_results. This prevents mis-parsing and missing perf stdout in some paths.

codeflash/optimization/function_optimizer.py [1950-1957]

 if testing_type == TestingMode.PERFORMANCE:
     results, coverage_results = parse_test_results(
         test_output=run_result.stdout,
         coverage_database_file=coverage_database_file,
         coverage_config_file=coverage_config_file,
     )
-    if testing_type == TestingMode.PERFORMANCE:
-        results.perf_stdout = run_result.stdout
+    results.perf_stdout = run_result.stdout
     return results, coverage_results
-results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
-return results, coverage_results
+elif testing_type == TestingMode.LINE_PROFILE:
+    results, coverage_results = parse_line_profile_results(
+        line_profiler_output_file=line_profiler_output_file
+    )
+    return results, coverage_results
+else:
+    # Default to behavioral parsing
+    results, coverage_results = parse_test_results(
+        test_output=run_result.stdout,
+        coverage_database_file=coverage_database_file,
+        coverage_config_file=coverage_config_file,
+    )
+    return results, coverage_results
Suggestion importance[1-10]: 5

__

Why: The suggestion clarifies branching for PERFORMANCE vs LINE_PROFILE vs BEHAVIOR; the existing PR already attaches perf_stdout under PERFORMANCE and otherwise parses line-profile output. While the refactor may improve clarity, it's not strictly required for correctness given current call sites.

Low
General
Harden empty DB detection

Calling stat() on a non-existent path after exists() can still raise on race
conditions. Swap the checks to avoid TOCTOU issues and guard with try/except to
ensure a safe fallback without crashing coverage loading.

codeflash/verification/coverage_utils.py [41-46]

-if not database_path.exists() or not database_path.stat().st_size:
-    logger.debug(f"Coverage database {database_path} is empty or does not exist")
-    sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist")
+try:
+    if not database_path.exists() or database_path.stat().st_size == 0:
+        logger.debug(f"Coverage database {database_path} is empty or does not exist")
+        sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist")
+        return CoverageData.create_empty(source_code_path, function_name, code_context)
+    cov.load()
+except FileNotFoundError:
+    logger.debug(f"Coverage database {database_path} not found during load")
+    sentry_sdk.capture_message(f"Coverage database {database_path} not found during load")
     return CoverageData.create_empty(source_code_path, function_name, code_context)
-cov.load()
Suggestion importance[1-10]: 7

__

Why: Wrapping existence/size checks with try/except for FileNotFoundError reduces TOCTOU issues and safely falls back to empty coverage. It matches the new code location and improves robustness without altering behavior.

Medium
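
One way to sidestep the check-then-use race entirely is to make a single `stat()` call and treat `FileNotFoundError` as "missing". The sketch below is an assumption-laden illustration, not the code in `coverage_utils.py`: `is_missing_or_empty` and `demo` are hypothetical names, and the file names are made up.

```python
import tempfile
from pathlib import Path


def is_missing_or_empty(path: Path) -> bool:
    """Single stat() call: no window between an exists() check and stat()."""
    try:
        return path.stat().st_size == 0
    except FileNotFoundError:
        return True


def demo() -> tuple[bool, bool, bool]:
    """Check a missing file, an empty file, and a non-empty file."""
    with tempfile.TemporaryDirectory() as tmp:
        missing = Path(tmp) / "missing.sqlite"
        empty = Path(tmp) / "empty.sqlite"
        empty.touch()
        populated = Path(tmp) / "populated.sqlite"
        populated.write_bytes(b"data")
        return (
            is_missing_or_empty(missing),
            is_missing_or_empty(empty),
            is_missing_or_empty(populated),
        )
```

A file can still disappear between this check and a later load, so the caller would likely still want the `except FileNotFoundError` fallback around the load shown in the diff.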
Fix timeout decorator detection/addition

The timeout decorator is constructed as a Call but checked as a Name equality, so
duplicates can be added and unparsing may produce invalid code. Make the check
robust by matching both Name and Attribute/Call shapes, and add the decorator as a
Decorator node consistently.

codeflash/code_utils/instrument_existing_tests.py [318-336]

 if self.test_framework == "unittest":
-    timeout_decorator = ast.Call(
-        func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
-        args=[ast.Constant(value=15)],
-        keywords=[],
-    )
+    decorator_name = "timeout_decorator.timeout"
     for item in node.body:
-        if (
-            isinstance(item, ast.FunctionDef)
-            and item.name.startswith("test_")
-            and not any(
-                isinstance(d, ast.Call)
-                and isinstance(d.func, ast.Name)
-                and d.func.id == "timeout_decorator.timeout"
+        if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name.startswith("test_"):
+            has_timeout = any(
+                (isinstance(d, ast.Call) and (
+                    (isinstance(d.func, ast.Name) and d.func.id == "timeout_decorator")
+                    or (isinstance(d.func, ast.Attribute) and getattr(d.func.value, "id", None) == "timeout_decorator" and d.func.attr == ast.Name(id="timeout", ctx=ast.Load()))
+                ))
                 for d in item.decorator_list
             )
-        ):
-            item.decorator_list.append(timeout_decorator)
+            if not has_timeout:
+                item.decorator_list.append(
+                    ast.Call(
+                        func=ast.Attribute(value=ast.Name(id="timeout_decorator", ctx=ast.Load()), attr="timeout", ctx=ast.Load()),
+                        args=[ast.Constant(value=15)],
+                        keywords=[],
+                    )
+                )
 return self.generic_visit(node)
Suggestion importance[1-10]: 3

__

Why: The goal (robustly detect/add timeout decorator) is reasonable, but the improved code mixes Name/Attribute checks incorrectly (e.g., comparing attr to an AST node) and changes shapes inconsistently; risk of introducing bugs. Limited accuracy lowers the score.

Low
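
For reference, a decorator parsed from source as `@timeout_decorator.timeout(15)` has an `ast.Attribute` as its `func`, so the attribute name is compared as a plain string. The following is a minimal sketch of a consistent detect-and-add pair; the helper names `is_timeout_decorator`, `make_timeout_decorator`, and `add_timeout` are assumptions for illustration and do not come from the PR.

```python
import ast


def is_timeout_decorator(dec: ast.expr) -> bool:
    """True for a decorator shaped like `timeout_decorator.timeout(...)`."""
    return (
        isinstance(dec, ast.Call)
        and isinstance(dec.func, ast.Attribute)
        and isinstance(dec.func.value, ast.Name)
        and dec.func.value.id == "timeout_decorator"
        and dec.func.attr == "timeout"  # attr is a str, not an AST node
    )


def make_timeout_decorator(seconds: int) -> ast.Call:
    """Build the same shape that the parser produces for the decorator."""
    return ast.Call(
        func=ast.Attribute(
            value=ast.Name(id="timeout_decorator", ctx=ast.Load()),
            attr="timeout",
            ctx=ast.Load(),
        ),
        args=[ast.Constant(value=seconds)],
        keywords=[],
    )


def add_timeout(source: str, seconds: int = 15) -> str:
    """Add the decorator to test_* methods in classes that lack it."""
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        for item in node.body:
            if (
                isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
                and item.name.startswith("test_")
                and not any(is_timeout_decorator(d) for d in item.decorator_list)
            ):
                item.decorator_list.append(make_timeout_decorator(seconds))
    return ast.unparse(tree)


SAMPLE = '''
class T:
    def test_a(self):
        pass

    @timeout_decorator.timeout(5)
    async def test_b(self):
        pass
'''
RESULT = add_timeout(SAMPLE)
```

Because the builder and the detector agree on the node shape, running `add_timeout` on its own output adds nothing further.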

Previous suggestions

Suggestions up to commit 4325416
Category    Suggestion    Impact
Possible issue
Avoid blocking event loop

Replace the blocking time.sleep with asyncio.sleep to avoid blocking the event loop
in an async function. This prevents starvation and ensures accurate async throughput
measurements.

code_to_optimize/code_directories/async_e2e/main.py [1-16]

-import time
 import asyncio
 
 
 async def retry_with_backoff(func, max_retries=3):
     if max_retries < 1:
         raise ValueError("max_retries must be at least 1")
     last_exception = None
     for attempt in range(max_retries):
         try:
             return await func()
         except Exception as e:
             last_exception = e
             if attempt < max_retries - 1:
-                time.sleep(0.0001 * attempt)
+                await asyncio.sleep(0.0001 * attempt)
     raise last_exception
Suggestion importance[1-10]: 9

__

Why: Replacing time.sleep with await asyncio.sleep in an async function prevents blocking the event loop and aligns with async best practices; the improved code correctly reflects the change and removes the unused time import.

High
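
The difference between the two sleeps can be observed without timing anything, by recording the order in which two tasks run. This is an illustrative sketch only; `worker`, `bystander`, `run`, and `demo` are hypothetical names and are unrelated to `retry_with_backoff`.

```python
import asyncio
import time


async def worker(log: list[str], *, blocking: bool) -> None:
    log.append("worker:start")
    if blocking:
        time.sleep(0.01)  # blocks the whole event loop
    else:
        await asyncio.sleep(0.01)  # yields control to other tasks
    log.append("worker:end")


async def bystander(log: list[str]) -> None:
    log.append("bystander:ran")


async def run(*, blocking: bool) -> list[str]:
    log: list[str] = []
    await asyncio.gather(worker(log, blocking=blocking), bystander(log))
    return log


def demo() -> tuple[list[str], list[str]]:
    """Return the event order with time.sleep, then with asyncio.sleep."""
    return asyncio.run(run(blocking=True)), asyncio.run(run(blocking=False))
```

With `time.sleep`, the bystander only runs after the worker has finished; with `asyncio.sleep`, it runs while the worker is suspended. That interleaving is what an async throughput measurement depends on.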
Safely add required imports without duplication

The injected os.environ[...] assignments rely on os being imported, but adding the
import unconditionally at the module top may duplicate existing imports or break
module docstrings/comments ordering. Safely ensure os is imported only if missing,
and preserve module header nodes. Also avoid returning success when no env
assignment was added.

codeflash/code_utils/instrument_existing_tests.py [489-523]

 def inject_async_profiling_into_existing_test(
     test_path: Path,
     call_positions: list[CodePosition],
     function_to_optimize: FunctionToOptimize,
     tests_project_root: Path,
     test_framework: str,
     mode: TestingMode = TestingMode.BEHAVIOR,
 ) -> tuple[bool, str | None]:
     """Inject profiling for async function calls by setting environment variables before each call."""
     with test_path.open(encoding="utf8") as f:
         test_code = f.read()
-    
     try:
         tree = ast.parse(test_code)
     except SyntaxError:
         logger.exception(f"Syntax error in code in file - {test_path}")
         return False, None
 
     test_module_path = module_name_from_file_path(test_path, tests_project_root)
-
     import_visitor = FunctionImportedAsVisitor(function_to_optimize)
     import_visitor.visit(tree)
     func = import_visitor.imported_as
 
     async_instrumenter = AsyncCallInstrumenter(func, test_module_path, test_framework, call_positions, mode=mode)
     tree = async_instrumenter.visit(tree)
 
     if not async_instrumenter.did_instrument:
         return False, None
 
-    # Add necessary imports
-    new_imports = [ast.Import(names=[ast.alias(name="os")])]
+    # Ensure required imports exist without duplicating or disturbing module header
+    has_os_import = any(
+        isinstance(n, (ast.Import, ast.ImportFrom))
+        and any(getattr(alias, "name", None) == "os" or getattr(alias, "module", None) == "os" for alias in getattr(n, "names", []))
+        for n in tree.body
+        if isinstance(n, (ast.Import, ast.ImportFrom))
+    )
+    new_body = list(tree.body)
+    insert_at = 0
+    # Keep module docstring at the top if present
+    if new_body and isinstance(new_body[0], ast.Expr) and isinstance(getattr(new_body[0], "value", None), ast.Constant) and isinstance(new_body[0].value.value, str):
+        insert_at = 1
+
+    if not has_os_import:
+        new_body.insert(insert_at, ast.Import(names=[ast.alias(name="os")]))
+
     if test_framework == "unittest":
-        new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
+        has_timeout_import = any(
+            isinstance(n, ast.Import) and any(a.name == "timeout_decorator" for a in n.names)
+            for n in new_body
+        )
+        if not has_timeout_import:
+            new_body.insert(insert_at + (0 if not has_os_import else 1), ast.Import(names=[ast.alias(name="timeout_decorator")]))
 
-    tree.body = [*new_imports, *tree.body]
+    tree.body = new_body
     return True, isort.code(ast.unparse(tree), float_to_top=True)
Suggestion importance[1-10]: 6

__

Why: Correctly notes potential duplicate or misplaced imports; proposed change preserves docstrings and avoids duplication while keeping behavior. Moderate impact on robustness, accurate to the diff context.

Low
Limit env injection to simple await statements

The environment assignment is inserted before the entire statement containing the
await, which can be a compound statement or non-simple node, risking invalid AST or
misordered execution. Restrict instrumentation to simple statements and inject the
env assignment immediately before the specific Await statement at the same block
level.

codeflash/code_utils/instrument_existing_tests.py [399-411]

 class AsyncCallInstrumenter(ast.NodeTransformer):
     ...
     def _instrument_statement(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]:
-        for node in ast.walk(stmt):
-            if (
-                isinstance(node, ast.Await)
-                and isinstance(node.value, ast.Call)
-                and self._is_target_call(node.value)
-                and self._call_in_positions(node.value)
-            ):
-                # Check if this call is in one of our target positions
-                return stmt, True  # Return original statement but signal we added env var
-
+        # Only instrument simple Expr statements that await the target call
+        if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Await) and isinstance(stmt.value.value, ast.Call):
+            call = stmt.value.value
+            if self._is_target_call(call) and self._call_in_positions(call):
+                return stmt, True
         return stmt, False
Suggestion importance[1-10]: 3

__

Why: While it reduces risk of invalid AST edits, it limits coverage and may miss valid await usages inside more complex statements; correctness trade-off makes it a minor, debatable improvement.

Low
General
Ensure nested defs are visited

Use the node visitor's generic traversal to ensure nested functions and deeper
constructs are visited. Manually iterating only over node.body can miss nested defs,
leading to incomplete annotation updates.

codeflash/code_utils/edit_generated_tests.py [33-41]

 def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
     self.context_stack.append(node.name)
+    # Visit immediate children via NodeVisitor to ensure nested defs are reached
     for inner_node in node.body:
-        if isinstance(inner_node, ast.FunctionDef):
-            self.visit_FunctionDef(inner_node)
-        elif isinstance(inner_node, ast.AsyncFunctionDef):
-            self.visit_AsyncFunctionDef(inner_node)
+        self.visit(inner_node)
     self.context_stack.pop()
     return node
Suggestion importance[1-10]: 7

__

Why: Using self.visit(inner_node) ensures deeper nested nodes are traversed, addressing potential misses from only iterating node.body; this improves robustness though existing code already handles async and sync immediate defs.

Medium
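
The traversal point can be illustrated with a small standalone visitor. This is a sketch, not the visitor in `edit_generated_tests.py`: the class name `QualifiedNameCollector`, the `names` list, `collect`, and the sample source are assumptions made for illustration; only the `context_stack` idea and the `self.visit(inner_node)` loop come from the suggestion.

```python
import ast


class QualifiedNameCollector(ast.NodeVisitor):
    """Collect dotted names of every function, including nested ones."""

    def __init__(self) -> None:
        self.context_stack: list[str] = []
        self.names: list[str] = []

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self.context_stack.append(node.name)
        for inner_node in node.body:
            # Dispatch through visit() so nested classes are reached too.
            self.visit(inner_node)
        self.context_stack.pop()

    def visit_FunctionDef(self, node) -> None:
        self.names.append(".".join([*self.context_stack, node.name]))
        self.context_stack.append(node.name)
        self.generic_visit(node)  # descend into nested defs
        self.context_stack.pop()

    # Async defs are handled exactly like sync ones.
    visit_AsyncFunctionDef = visit_FunctionDef


def collect(source: str) -> list[str]:
    collector = QualifiedNameCollector()
    collector.visit(ast.parse(source))
    return collector.names


SAMPLE = '''
class Outer:
    class Inner:
        async def fetch(self):
            pass

    def run(self):
        def helper():
            pass
'''
FOUND = collect(SAMPLE)
```

A loop that only checks `isinstance(inner_node, ast.FunctionDef)` or `ast.AsyncFunctionDef` would skip `Outer.Inner.fetch` here, because `Inner` is a `ClassDef`.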
Validate throughput inputs

Guard against negative or None inputs and ensure non-negative throughput values
before computing gain. This avoids misleading gains or division logic issues if
inputs are invalid.

codeflash/result/critic.py [29-37]

 def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> float:
-    """Calculate the throughput gain of an optimized code over the original code.
+    """Calculate the throughput gain of an optimized code over the original code."""
+    if original_throughput is None or optimized_throughput is None:
+        return 0.0
+    # Clamp to non-negative to avoid misleading ratios
+    original = max(0, original_throughput)
+    optimized = max(0, optimized_throughput)
+    if original == 0:
+        return 0.0
+    return (optimized - original) / original
 
-    This value multiplied by 100 gives the percentage improvement in throughput.
-    For throughput, higher values are better (more executions per time period).
-    """
-    if original_throughput == 0:
-        return 0.0
-    return (optimized_throughput - original_throughput) / original_throughput
-
Suggestion importance[1-10]: 5

__

Why: Adding input validation is reasonable but partially redundant since callers already gate None and negatives likely indicate upstream issues; the change is minor and not critical to correctness here.

Low
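
As a worked example of the formula, the block below repeats the pre-suggestion body of `throughput_gain` from the diff above and evaluates it on made-up throughput counts. The input numbers are purely illustrative and are not measurements from this PR.

```python
def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> float:
    """Relative throughput change; multiply by 100 for a percentage."""
    if original_throughput == 0:
        return 0.0  # no baseline to compare against
    return (optimized_throughput - original_throughput) / original_throughput


# Illustrative inputs only (executions completed in a fixed time window).
improved = throughput_gain(original_throughput=200, optimized_throughput=260)
regressed = throughput_gain(original_throughput=200, optimized_throughput=150)
no_baseline = throughput_gain(original_throughput=0, optimized_throughput=50)
```

Here `improved` is 0.3 (a 30% gain), `regressed` is -0.25, and `no_baseline` is 0.0. Since higher throughput is better, a negative value signals a regression rather than an invalid input.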
Fix alias handling to match calls

When tracking imported aliases, the parents may refer to classes that don't match
the import context, causing false positives in call matching. Preserve the original
qualified_name behavior by clearing parents for unqualified imports so call
detection matches top-level usage.

codeflash/code_utils/instrument_existing_tests.py [451-465]

 class FunctionImportedAsVisitor(ast.NodeVisitor):
     ...
     def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
-        ...
-                    self.imported_as = FunctionToOptimize(
-                        function_name=alias,  # type: ignore[arg-type]
-                        parents=self.function.parents,
-                        file_path=self.function.file_path,
-                        starting_line=self.function.starting_line,
-                        ending_line=self.function.ending_line,
-                        is_async=self.function.is_async,
-                    )
-            else:
+        module_name = ".".join([n for n in node.module.split(".") if n]) if node.module else ""
+        for alias in node.names:
+            # imported function directly
+            if getattr(alias, "name", None) == self.function.function_name and module_name.endswith(self.original_module_name()):
+                imported_name = alias.asname or alias.name
                 self.imported_as = FunctionToOptimize(
-                    function_name=alias.asname,  # type: ignore[arg-type]
-                    parents=self.function.parents,
+                    function_name=imported_name,  # type: ignore[arg-type]
+                    parents=[],  # imported into module scope; match unqualified calls
                     file_path=self.function.file_path,
                     starting_line=self.function.starting_line,
                     ending_line=self.function.ending_line,
                     is_async=self.function.is_async,
                 )
+                return
+    def original_module_name(self) -> str:
+        return self.function.file_path.stem
Suggestion importance[1-10]: 2

Why: The current code already propagates is_async and parents appropriately; clearing parents based only on import form may cause incorrect matching. The proposal is speculative and risks breaking qualified call detection.

Low
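
For readers following the alias discussion in this suggestion: each entry in `ast.ImportFrom.names` is an `ast.alias` with a `name` and an optional `asname`, so the locally bound name is `alias.asname or alias.name`. A minimal sketch, assuming a hypothetical helper `local_import_names` that is not part of the PR:

```python
from __future__ import annotations

import ast


def local_import_names(source: str, function_name: str) -> list[str]:
    """Return the local names bound by `from ... import function_name [as alias]`.

    Hypothetical helper for illustration only; it is not part of the PR.
    """
    names: list[str] = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.ImportFrom):
            for alias in node.names:
                if alias.name == function_name:
                    # `asname` is None when the import has no `as` clause.
                    names.append(alias.asname or alias.name)
    return names
```

This only shows how the bound name is recovered; whether `parents` should be cleared for such imports is the open question the review raises.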

@github-actions github-actions bot added the workflow-modified This PR modifies GitHub Actions workflows label Sep 26, 2025

codeflash-ai bot commented Sep 26, 2025

⚡️ Codeflash found optimizations for this PR

📄 24% (0.24x) speedup for InjectPerfOnly.find_and_update_line_node in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime : 21.6 milliseconds → 17.4 milliseconds (best of 49 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch clean-async-branch).

codeflash-ai bot commented Sep 26, 2025

⚡️ Codeflash found optimizations for this PR

📄 30% (0.30x) speedup for AsyncCallInstrumenter._process_test_function in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime : 2.08 milliseconds → 1.61 milliseconds (best of 14 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch clean-async-branch).

KRRT7 and others added 2 commits September 26, 2025 18:22
The optimization achieves a 22% speedup by eliminating redundant regex compilation and reducing unnecessary string operations.

**Key optimizations:**

1. **Pre-compiled regex patterns**: The original code compiled the same regex pattern multiple times (3,114 compilations taking 43.4% of total time). The optimized version compiles each pattern only once upfront using `_compile_function_patterns()`, moving this expensive operation outside the nested loops.

2. **Efficient string manipulation**: Instead of using `re.sub()` which searches the entire string again, the optimized version uses `finditer()` to get match positions directly, then performs string slicing (`source[:start] + source[end:]`) to remove matched functions. This avoids the overhead of regex substitution.

3. **Early termination**: After finding and removing a function match, the code breaks from the inner loop since only one match per function is expected, preventing unnecessary continued iteration.

**Performance impact by test case:**
- The optimizations are most effective for scenarios with multiple test functions to remove across multiple generated tests (the typical use case)
- For edge cases like empty test lists, there's minimal overhead from pre-compilation but no significant benefit
- The approach maintains correct behavior for decorated functions (skipping `@pytest.mark.parametrize` functions as intended)

The line profiler shows that regex compilation, 43.4% of total time in the original, is now a one-time upfront cost (89.8% of the optimized profile), while the substitution overhead (51.7% in the original) is eliminated entirely.

codeflash-ai bot commented Sep 27, 2025

⚡️ Codeflash found optimizations for this PR

📄 23% (0.23x) speedup for remove_functions_from_generated_tests in codeflash/code_utils/edit_generated_tests.py

⏱️ Runtime : 1.46 milliseconds → 1.19 milliseconds (best of 11 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch clean-async-branch).

codeflash-ai bot commented Sep 27, 2025

This PR is now faster! 🚀 @KRRT7 accepted my optimizations from:

The optimized code achieves a **123% speedup** by replacing expensive AST traversal operations with more efficient alternatives:

**Key Optimizations:**

1. **Decorator Search Optimization**: Replaced the `any()` generator expression with a simple loop that breaks early when finding `timeout_decorator.timeout`. Since `any()` also short-circuits, the gain comes from avoiding generator creation and per-item generator resumption rather than from fewer iterations through the decorator list.

2. **AST Traversal Replacement**: The most significant optimization replaces `ast.walk(stmt)` with a manual stack-based depth-first search in `_optimized_instrument_statement()`. `ast.walk()` is a generator that visits every node in the AST subtree breadth-first through an internal queue, so it adds per-node generator overhead and touches many irrelevant nodes. The optimized version:
   - Uses a stack to traverse nodes manually
   - Only explores child nodes via `_fields` attribute access
   - Immediately returns when finding an `ast.Await` node that matches criteria
   - Avoids creating intermediate collections

**Performance Impact by Test Case:**
- **Large-scale tests** see the biggest improvements (125-129% faster) because they have many statements to traverse
- **Nested structures** benefit significantly (57-93% faster) as the optimization avoids deep, unnecessary traversals
- **Simple test cases** still see 29-48% improvements from the decorator optimization
- **Functions with many await calls** show excellent scaling (123-127% faster) due to reduced per-statement traversal costs

The line profiler shows the critical bottleneck was in `_instrument_statement()` (96.4% of time originally), which is now reduced to 93.3% but with much lower absolute time, demonstrating the effectiveness of the AST traversal optimization.
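
A minimal sketch of the stack-based traversal described above, assuming a hypothetical function name `find_await` and using `ast.iter_child_nodes` rather than direct `_fields` access (the PR's actual `_optimized_instrument_statement()` is not shown in this thread and also applies matching criteria omitted here):

```python
from __future__ import annotations

import ast


def find_await(stmt: ast.AST) -> ast.Await | None:
    """Depth-first search for any `await` expression under `stmt`.

    Illustrative only: returns the first Await node popped from the stack,
    which is not necessarily the first one in source order.
    """
    stack: list[ast.AST] = [stmt]
    while stack:
        node = stack.pop()
        if isinstance(node, ast.Await):
            return node  # stop as soon as a match is found
        stack.extend(ast.iter_child_nodes(node))
    return None
```

The early `return` is what lets the search skip the rest of the subtree once a match is found.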

codeflash-ai bot commented Sep 27, 2025

⚡️ Codeflash found optimizations for this PR

📄 123% (1.23x) speedup for AsyncCallInstrumenter.visit_AsyncFunctionDef in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime : 9.25 milliseconds → 4.14 milliseconds (best of 186 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch clean-async-branch).
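
The speedup figures in these comments are easier to read with the arithmetic spelled out. The reported percentages look consistent with `original / optimized - 1`; this is inferred from the runtimes shown in this thread, not taken from Codeflash documentation. Under that assumption, "123% (1.23x)" means the optimized code runs roughly 2.23 times as fast as the original.

```python
def speedup(original: float, optimized: float) -> float:
    """Relative speedup as a fraction: 0.24 means 24% faster.

    Assumed formula, inferred from the numbers reported in this thread.
    """
    return original / optimized - 1.0


# Runtimes in milliseconds, copied from the comments above.
visit_async = speedup(9.25, 4.14)
find_and_update = speedup(21.6, 17.4)
```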

codeflash-ai bot commented Sep 27, 2025

⚡️ Codeflash found optimizations for this PR

📄 39% (0.39x) speedup for get_first_top_level_function_or_method_ast in codeflash/code_utils/static_analysis.py

⏱️ Runtime : 1.43 milliseconds → 1.03 milliseconds (best of 61 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch clean-async-branch).

…25-09-27T02.50.03

⚡️ Speed up method `AsyncCallInstrumenter.visit_AsyncFunctionDef` by 123% in PR #769 (`clean-async-branch`)

codeflash-ai bot commented Sep 27, 2025

misrasaurabh1 previously approved these changes Sep 29, 2025
@KRRT7 KRRT7 enabled auto-merge September 29, 2025 19:23
@KRRT7 KRRT7 requested a review from misrasaurabh1 September 29, 2025 20:59
Comment on lines +325 to +335
    if (
        isinstance(item, ast.FunctionDef)
        and item.name.startswith("test_")
        and not any(
            isinstance(d, ast.Call)
            and isinstance(d.func, ast.Name)
            and d.func.id == "timeout_decorator.timeout"
            for d in item.decorator_list
        )
    ):
        item.decorator_list.append(timeout_decorator)

⚡️Codeflash found 10% (0.10x) speedup for AsyncCallInstrumenter.visit_ClassDef in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime : 4.11 milliseconds → 3.72 milliseconds (best of 216 runs)

📝 Explanation and details

The optimization replaces the any() generator expression with an explicit loop and an early break. Note that any() already short-circuits on the first truthy value, so both versions stop at the first matching decorator; the saving comes from not creating and driving a generator object for every test method. The optimized version uses a flag, has_timeout_decorator = False, and an explicit loop that breaks as soon as it finds the target decorator.

This change eliminates the overhead of:

  • Generator object creation and the iteration protocol
  • Resuming the generator frame for each decorator that is checked
  • The any() builtin call overhead

The optimization is most visible for classes with many test methods (13.5% speedup on 500 test functions) because it reduces the per-method decorator checking cost. For smaller classes, the speedup is more modest (5-11%) but still consistent.
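
To make the comparison concrete, here is a small sketch with hypothetical names (`make_probe`, `found_with_any`, `found_with_loop`) that records which items each form inspects. Both `any()` over a generator expression and the explicit loop stop at the first match, so any measured difference comes from generator overhead rather than from the number of items examined.

```python
def make_probe(target, calls):
    """Return a predicate that records every item it is asked about."""
    def probe(x):
        calls.append(x)
        return x == target
    return probe


def found_with_any(items, pred):
    # any() stops consuming the generator at the first truthy result.
    return any(pred(x) for x in items)


def found_with_loop(items, pred):
    # Explicit flag-and-break form, mirroring the suggested change.
    found = False
    for x in items:
        if pred(x):
            found = True
            break
    return found
```

Calling either function on `[1, 2, 3, 4]` with a probe for `2` leaves `[1, 2]` in the recorded calls.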

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 50 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |

🌀 Generated Regression Tests and Runtime
from __future__ import annotations

import ast

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import \
    AsyncCallInstrumenter


# Dummy classes for dependencies
class FunctionToOptimize:
    def __init__(self, function_name, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name

class CodePosition:
    pass

class TestingMode:
    BEHAVIOR = "behavior"
    OTHER = "other"


# Helper to parse code and return AST
def parse_classdef(code: str) -> ast.ClassDef:
    mod = ast.parse(code)
    for node in mod.body:
        if isinstance(node, ast.ClassDef):
            return node
    raise ValueError("No ClassDef found in code")

# Helper to check if a function has the timeout decorator
def has_timeout_decorator(func: ast.FunctionDef) -> bool:
    for dec in func.decorator_list:
        if (
            isinstance(dec, ast.Call)
            and isinstance(dec.func, ast.Name)
            and dec.func.id == "timeout_decorator.timeout"
        ):
            return True
    return False

# ---------------- BASIC TEST CASES ----------------

def test_no_test_methods_unittest():
    # Class with no test methods: nothing should be decorated
    code = """
class MyTestCase:
    def helper(self): pass
    def not_a_test(self): pass
"""
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "unittest",
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 10.2μs -> 10.2μs (0.098% slower)
    # No function should have the timeout decorator
    for item in result.body:
        if isinstance(item, ast.FunctionDef):
            pass

def test_single_test_method_unittest():
    # Class with one test method: should be decorated
    code = """
class MyTestCase:
    def test_simple(self): pass
"""
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "unittest",
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 12.2μs -> 11.6μs (5.90% faster)
    # The test_simple method should have the timeout decorator
    for item in result.body:
        if isinstance(item, ast.FunctionDef) and item.name == "test_simple":
            pass

def test_multiple_test_methods_unittest():
    # Class with multiple test methods: all should be decorated
    code = """
class MyTestCase:
    def test_one(self): pass
    def test_two(self): pass
    def helper(self): pass
"""
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "unittest",
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 15.2μs -> 14.2μs (7.18% faster)
    for item in result.body:
        if isinstance(item, ast.FunctionDef):
            if item.name.startswith("test_"):
                pass
            else:
                pass

def test_non_unittest_framework():
    # Should not decorate anything if not using unittest
    code = """
class MyTestCase:
    def test_one(self): pass
    def test_two(self): pass
"""
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "pytest",  # not unittest
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 9.34μs -> 9.33μs (0.096% faster)
    for item in result.body:
        if isinstance(item, ast.FunctionDef):
            pass

def test_already_has_timeout_decorator():
    # Should not add duplicate timeout decorator
    code = """
class MyTestCase:
    @timeout_decorator.timeout(15)
    def test_one(self): pass
    def test_two(self): pass
"""
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "unittest",
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 14.9μs -> 13.6μs (9.52% faster)
    found = 0
    for item in result.body:
        if isinstance(item, ast.FunctionDef):
            if item.name == "test_one":
                found += 1
            elif item.name == "test_two":
                found += 1

# ---------------- EDGE TEST CASES ----------------

def test_decorator_list_is_empty():
    # Handles function with no decorators
    code = """
class MyTestCase:
    def test_empty_decorator(self): pass
"""
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "unittest",
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 11.3μs -> 10.6μs (7.12% faster)
    for item in result.body:
        if isinstance(item, ast.FunctionDef) and item.name == "test_empty_decorator":
            pass

def test_non_function_items_in_class():
    # Handles classes with non-function items
    code = """
class MyTestCase:
    x = 5
    def test_x(self): pass
    @property
    def y(self): return 10
"""
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "unittest",
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 23.0μs -> 22.0μs (4.69% faster)
    for item in result.body:
        if isinstance(item, ast.FunctionDef):
            if item.name == "test_x":
                pass
            elif item.name == "y":
                pass

def test_function_name_edge_cases():
    # Only functions starting with test_ get decorated
    code = """
class MyTestCase:
    def test(self): pass
    def testCase(self): pass
    def test_abc(self): pass
    def test123(self): pass
"""
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "unittest",
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 14.3μs -> 13.4μs (6.51% faster)
    for item in result.body:
        if isinstance(item, ast.FunctionDef):
            if item.name == "test_abc":
                pass
            else:
                pass

def test_decorator_is_not_call():
    # Handles decorators that are not ast.Call
    code = """
class MyTestCase:
    @staticmethod
    def test_static(self): pass
"""
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "unittest",
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 11.9μs -> 10.8μs (9.88% faster)
    for item in result.body:
        if isinstance(item, ast.FunctionDef) and item.name == "test_static":
            # Should have both staticmethod and timeout_decorator
            found_static = any(isinstance(dec, ast.Name) and dec.id == "staticmethod" for dec in item.decorator_list)

def test_empty_class_body():
    # Class with no body should not raise
    code = """
class MyTestCase:
    pass
"""
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "unittest",
        [],
    )
    # Should not raise
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 8.12μs -> 8.03μs (1.12% faster)

# ---------------- LARGE SCALE TEST CASES ----------------

def test_large_number_of_test_methods():
    # Large class with many test methods
    N = 500
    code = "class BigTestCase:\n" + "\n".join(
        f"    def test_{i}(self): pass" for i in range(N)
    )
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "unittest",
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 954μs -> 840μs (13.6% faster)
    count = 0
    for item in result.body:
        if isinstance(item, ast.FunctionDef):
            count += 1

def test_large_class_mixed_methods():
    # Large class with mixed test and non-test methods
    N = 250
    code = "class MixedTestCase:\n" + "\n".join(
        f"    def test_{i}(self): pass" for i in range(N)
    ) + "\n" + "\n".join(
        f"    def helper_{i}(self): pass" for i in range(N)
    )
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "unittest",
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 641μs -> 569μs (12.5% faster)
    test_count = helper_count = 0
    for item in result.body:
        if isinstance(item, ast.FunctionDef):
            if item.name.startswith("test_"):
                test_count += 1
            else:
                helper_count += 1

def test_large_non_unittest():
    # Large class with unittest-like methods but non-unittest framework
    N = 300
    code = "class BigTestCase:\n" + "\n".join(
        f"    def test_{i}(self): pass" for i in range(N)
    )
    classdef = parse_classdef(code)
    instrumenter = AsyncCallInstrumenter(
        FunctionToOptimize("foo"),
        "mod.py",
        "pytest",  # not unittest
        [],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); result = codeflash_output # 367μs -> 356μs (3.08% faster)
    for item in result.body:
        if isinstance(item, ast.FunctionDef):
            pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from __future__ import annotations

import ast

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import \
    AsyncCallInstrumenter


class DummyFunction:
    def __init__(self, function_name, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name

# Helper to parse code and get ClassDef node
def get_classdef_node(source: str) -> ast.ClassDef:
    tree = ast.parse(source)
    for node in tree.body:
        if isinstance(node, ast.ClassDef):
            return node
    raise ValueError("No ClassDef found in source")

# Helper to check if a functiondef has the timeout decorator
def has_timeout_decorator(funcdef: ast.FunctionDef) -> bool:
    for d in funcdef.decorator_list:
        if (
            isinstance(d, ast.Call)
            and isinstance(d.func, ast.Name)
            and d.func.id == "timeout_decorator.timeout"
            and len(d.args) == 1
            and isinstance(d.args[0], ast.Constant)
            and d.args[0].value == 15
        ):
            return True
    return False

# Basic Test Cases

def test_unittest_adds_timeout_to_test_functions():
    # Test that the decorator is added to test functions for unittest
    src = '''
class TestFoo:
    def test_one(self): pass
    def test_two(self): pass
    def helper(self): pass
    '''
    classdef = get_classdef_node(src)
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("test_one"),
        module_path="foo.py",
        test_framework="unittest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); new_node = codeflash_output # 16.1μs -> 15.1μs (6.38% faster)
    test_funcs = [item for item in new_node.body if isinstance(item, ast.FunctionDef)]
    # Only test_ functions should get the decorator
    for func in test_funcs:
        if func.name.startswith("test_"):
            pass
        else:
            pass

def test_pytest_does_not_add_timeout():
    # Test that no decorator is added for pytest
    src = '''
class TestBar:
    def test_a(self): pass
    def test_b(self): pass
    '''
    classdef = get_classdef_node(src)
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("test_a"),
        module_path="bar.py",
        test_framework="pytest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); new_node = codeflash_output # 9.48μs -> 9.42μs (0.637% faster)
    for func in [item for item in new_node.body if isinstance(item, ast.FunctionDef)]:
        pass

def test_no_test_functions():
    # Class with no test_ functions should not get any decorator
    src = '''
class HelperClass:
    def foo(self): pass
    def bar(self): pass
    '''
    classdef = get_classdef_node(src)
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("foo"),
        module_path="helper.py",
        test_framework="unittest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); new_node = codeflash_output # 9.62μs -> 9.32μs (3.22% faster)
    for func in [item for item in new_node.body if isinstance(item, ast.FunctionDef)]:
        pass

def test_existing_timeout_decorator_not_duplicated():
    # If timeout_decorator.timeout is already present, do not duplicate
    src = '''
class TestDup:
    @timeout_decorator.timeout(15)
    def test_dup(self): pass
    '''
    classdef = get_classdef_node(src)
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("test_dup"),
        module_path="dup.py",
        test_framework="unittest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); new_node = codeflash_output # 12.1μs -> 11.1μs (8.65% faster)
    func = [item for item in new_node.body if isinstance(item, ast.FunctionDef)][0]

# Edge Test Cases

def test_class_with_non_function_body_items():
    # Class with assignments and other nodes in body
    src = '''
class TestEdge:
    x = 1
    def test_func(self): pass
    y = 2
    '''
    classdef = get_classdef_node(src)
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("test_func"),
        module_path="edge.py",
        test_framework="unittest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); new_node = codeflash_output # 28.1μs -> 26.7μs (5.18% faster)
    # Only test_func should get the decorator
    for item in new_node.body:
        if isinstance(item, ast.FunctionDef):
            if item.name == "test_func":
                pass
            else:
                pass
        else:
            pass

def test_class_with_decorators_on_test_function():
    # Test function with other decorators should still get timeout_decorator
    src = '''
class TestDecorators:
    @other_decorator
    def test_deco(self): pass
    '''
    classdef = get_classdef_node(src)
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("test_deco"),
        module_path="deco.py",
        test_framework="unittest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); new_node = codeflash_output # 12.1μs -> 10.9μs (11.1% faster)
    func = [item for item in new_node.body if isinstance(item, ast.FunctionDef)][0]

def test_class_with_inherited_test_functions():
    # Class with test functions that are inherited (should not be touched)
    src = '''
class Base:
    def test_base(self): pass

class Derived(Base):
    def test_derived(self): pass
    '''
    tree = ast.parse(src)
    # Only Derived should be instrumented
    derived = [n for n in tree.body if isinstance(n, ast.ClassDef) and n.name == "Derived"][0]
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("test_derived"),
        module_path="inherit.py",
        test_framework="unittest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(derived); new_node = codeflash_output # 14.9μs -> 13.7μs (8.64% faster)
    func = [item for item in new_node.body if isinstance(item, ast.FunctionDef)][0]

def test_class_with_test_function_named_exactly_test():
    # Function named 'test' (not 'test_') should not get decorator
    src = '''
class TestExact:
    def test(self): pass
    def test_one(self): pass
    '''
    classdef = get_classdef_node(src)
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("test_one"),
        module_path="exact.py",
        test_framework="unittest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); new_node = codeflash_output # 12.6μs -> 11.8μs (6.43% faster)
    for func in [item for item in new_node.body if isinstance(item, ast.FunctionDef)]:
        if func.name == "test_one":
            pass
        else:
            pass

def test_class_with_async_test_functions():
    # Async test functions should also be instrumented
    src = '''
class TestAsync:
    async def test_async(self): pass
    '''
    classdef = get_classdef_node(src)
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("test_async"),
        module_path="async.py",
        test_framework="unittest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); new_node = codeflash_output # 12.2μs -> 12.1μs (0.577% faster)
    func = [item for item in new_node.body if isinstance(item, ast.AsyncFunctionDef)][0]

# Large Scale Test Cases

def test_large_number_of_test_functions():
    # Class with many test_ functions
    NUM_FUNCS = 500
    src_lines = ["class TestLarge:"]
    for i in range(NUM_FUNCS):
        src_lines.append(f"    def test_func_{i}(self): pass")
    src = "\n".join(src_lines)
    classdef = get_classdef_node(src)
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("test_func_0"),
        module_path="large.py",
        test_framework="unittest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); new_node = codeflash_output # 959μs -> 845μs (13.5% faster)
    test_funcs = [item for item in new_node.body if isinstance(item, ast.FunctionDef)]
    for func in test_funcs:
        pass

def test_large_number_of_non_test_functions():
    # Class with many non-test functions
    NUM_FUNCS = 500
    src_lines = ["class HelperLarge:"]
    for i in range(NUM_FUNCS):
        src_lines.append(f"    def helper_func_{i}(self): pass")
    src = "\n".join(src_lines)
    classdef = get_classdef_node(src)
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("helper_func_0"),
        module_path="largehelper.py",
        test_framework="unittest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); new_node = codeflash_output # 299μs -> 294μs (1.85% faster)
    helper_funcs = [item for item in new_node.body if isinstance(item, ast.FunctionDef)]
    for func in helper_funcs:
        pass

def test_large_mixed_class():
    # Class with both test_ and non-test functions
    NUM_TEST_FUNCS = 250
    NUM_HELPER_FUNCS = 250
    src_lines = ["class MixedLarge:"]
    for i in range(NUM_TEST_FUNCS):
        src_lines.append(f"    def test_func_{i}(self): pass")
    for i in range(NUM_HELPER_FUNCS):
        src_lines.append(f"    def helper_func_{i}(self): pass")
    src = "\n".join(src_lines)
    classdef = get_classdef_node(src)
    instrumenter = AsyncCallInstrumenter(
        DummyFunction("test_func_0"),
        module_path="mixedlarge.py",
        test_framework="unittest",
        call_positions=[],
    )
    codeflash_output = instrumenter.visit_ClassDef(classdef); new_node = codeflash_output # 631μs -> 573μs (10.1% faster)
    funcs = [item for item in new_node.body if isinstance(item, ast.FunctionDef)]
    for func in funcs:
        if func.name.startswith("test_"):
            pass
        else:
            pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally, run `git merge codeflash/optimize-pr769-2025-09-29T21.00.36`.

Suggested change
-if (
-    isinstance(item, ast.FunctionDef)
-    and item.name.startswith("test_")
-    and not any(
-        isinstance(d, ast.Call)
-        and isinstance(d.func, ast.Name)
-        and d.func.id == "timeout_decorator.timeout"
-        for d in item.decorator_list
-    )
-):
-    item.decorator_list.append(timeout_decorator)
+if isinstance(item, ast.FunctionDef) and item.name.startswith("test_"):
+    has_timeout_decorator = False
+    for d in item.decorator_list:
+        if (
+            isinstance(d, ast.Call)
+            and isinstance(d.func, ast.Name)
+            and d.func.id == "timeout_decorator.timeout"
+        ):
+            has_timeout_decorator = True
+            break
+    if not has_timeout_decorator:
+        item.decorator_list.append(timeout_decorator)
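
One thing worth checking in both versions of this condition: when a decorator is parsed from source text, `@timeout_decorator.timeout(15)` yields an `ast.Call` whose `func` is an `ast.Attribute`, not an `ast.Name`. A check on `ast.Name` with the id `"timeout_decorator.timeout"` can therefore only match a node that was built programmatically with that dotted id. Whether this matters depends on how `timeout_decorator` is constructed in the PR, which is not shown in this thread. A hedged sketch of a matcher that accepts both shapes, using a hypothetical helper name `is_timeout_decorator`:

```python
import ast


def is_timeout_decorator(dec: ast.expr) -> bool:
    """Match `timeout_decorator.timeout(...)` in parsed or synthetic form.

    Hypothetical helper for illustration; it is not part of the PR.
    """
    if not isinstance(dec, ast.Call):
        return False
    func = dec.func
    if isinstance(func, ast.Name):
        # Synthetic node built with a dotted id.
        return func.id == "timeout_decorator.timeout"
    if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
        # Node produced by ast.parse() from real source text.
        return func.value.id == "timeout_decorator" and func.attr == "timeout"
    return False


parsed = ast.parse("@timeout_decorator.timeout(15)\ndef test_one(): pass").body[0]
synthetic = ast.Call(
    func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
    args=[ast.Constant(value=15)],
    keywords=[],
)
```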

Contributor Author

LGTM - but I'll merge this in later.

@KRRT7 KRRT7 merged commit b8e01f3 into main Sep 29, 2025
19 of 22 checks passed
@KRRT7 KRRT7 deleted the clean-async-branch branch September 29, 2025 21:12

Labels

Review effort 5/5 workflow-modified This PR modifies GitHub Actions workflows

3 participants