Skip to content

Commit 1e103bd

Browse files
authored
Merge pull request #687 from codeflash-ai/granular-async-instrumentation
Granular async instrumentation
2 parents 95a149b + 7bbb1e7 commit 1e103bd

39 files changed

+5445
-694
lines changed

.github/workflows/e2e-async.yaml

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
name: E2E - Async
2+
3+
on:
4+
pull_request:
5+
paths:
6+
- '**' # Trigger for all paths
7+
8+
workflow_dispatch:
9+
10+
jobs:
11+
async-optimization:
12+
# Dynamically determine if environment is needed only when workflow files change and contributor is external
13+
environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
14+
15+
runs-on: ubuntu-latest
16+
env:
17+
CODEFLASH_AIS_SERVER: prod
18+
POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
19+
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
20+
COLUMNS: 110
21+
MAX_RETRIES: 3
22+
RETRY_DELAY: 5
23+
EXPECTED_IMPROVEMENT_PCT: 10
24+
CODEFLASH_END_TO_END: 1
25+
steps:
26+
- name: 🛎️ Checkout
27+
uses: actions/checkout@v4
28+
with:
29+
ref: ${{ github.event.pull_request.head.ref }}
30+
repository: ${{ github.event.pull_request.head.repo.full_name }}
31+
fetch-depth: 0
32+
token: ${{ secrets.GITHUB_TOKEN }}
33+
34+
- name: Validate PR
35+
run: |
36+
# Check for any workflow changes
37+
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
38+
echo "⚠️ Workflow changes detected."
39+
40+
# Get the PR author
41+
AUTHOR="${{ github.event.pull_request.user.login }}"
42+
echo "PR Author: $AUTHOR"
43+
44+
# Allowlist check
45+
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
46+
echo "✅ Authorized user ($AUTHOR). Proceeding."
47+
elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
48+
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
49+
else
50+
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
51+
exit 1
52+
fi
53+
else
54+
echo "✅ No workflow file changes detected. Proceeding."
55+
fi
56+
57+
- name: Set up Python 3.11 for CLI
58+
uses: astral-sh/setup-uv@v5
59+
with:
60+
python-version: 3.11.6
61+
62+
- name: Install dependencies (CLI)
63+
run: |
64+
uv sync
65+
66+
- name: Run Codeflash to optimize async code
67+
id: optimize_async_code
68+
run: |
69+
uv run python tests/scripts/end_to_end_test_async.py

.github/workflows/pre-commit.yaml

Lines changed: 0 additions & 19 deletions
This file was deleted.

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
3-
rev: v0.12.7
3+
rev: v0.13.1
44
hooks:
55
# Run the linter.
66
- id: ruff-check
7+
args: [ --config=pyproject.toml ]
78
# Run the formatter.
89
- id: ruff-format
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import asyncio
2+
from typing import List, Union
3+
4+
5+
async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
6+
"""
7+
Async bubble sort implementation for testing.
8+
"""
9+
print("codeflash stdout: Async sorting list")
10+
11+
await asyncio.sleep(0.01)
12+
13+
n = len(lst)
14+
for i in range(n):
15+
for j in range(0, n - i - 1):
16+
if lst[j] > lst[j + 1]:
17+
lst[j], lst[j + 1] = lst[j + 1], lst[j]
18+
19+
result = lst.copy()
20+
print(f"result: {result}")
21+
return result
22+
23+
24+
class AsyncBubbleSorter:
25+
"""Class with async sorting method for testing."""
26+
27+
async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
28+
"""
29+
Async bubble sort implementation within a class.
30+
"""
31+
print("codeflash stdout: AsyncBubbleSorter.sorter() called")
32+
33+
# Add some async delay
34+
await asyncio.sleep(0.005)
35+
36+
n = len(lst)
37+
for i in range(n):
38+
for j in range(0, n - i - 1):
39+
if lst[j] > lst[j + 1]:
40+
lst[j], lst[j + 1] = lst[j + 1], lst[j]
41+
42+
result = lst.copy()
43+
return result
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import time
2+
import asyncio
3+
4+
5+
async def retry_with_backoff(func, max_retries=3):
6+
if max_retries < 1:
7+
raise ValueError("max_retries must be at least 1")
8+
last_exception = None
9+
for attempt in range(max_retries):
10+
try:
11+
return await func()
12+
except Exception as e:
13+
last_exception = e
14+
if attempt < max_retries - 1:
15+
time.sleep(0.0001 * attempt)
16+
raise last_exception
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[tool.codeflash]
2+
disable-telemetry = true
3+
formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
4+
module-root = "."
5+
test-framework = "pytest"
6+
tests-root = "tests"

code_to_optimize/code_directories/async_e2e/tests/__init__.py

Whitespace-only changes.

codeflash.code-workspace

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@
7070
"request": "launch",
7171
"program": "${workspaceFolder:codeflash}/codeflash/main.py",
7272
"args": [
73-
"--all",
73+
"--file",
74+
"src/async_examples/concurrency.py",
75+
"--function",
76+
"task",
77+
"--verbose"
7478
],
7579
"cwd": "${input:chooseCwd}",
7680
"console": "integratedTerminal",

codeflash/api/aiservice.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ def optimize_python_code( # noqa: D417
101101
trace_id: str,
102102
num_candidates: int = 10,
103103
experiment_metadata: ExperimentMetadata | None = None,
104+
*,
105+
is_async: bool = False,
104106
) -> list[OptimizedCandidate]:
105107
"""Optimize the given python code for performance by making a request to the Django endpoint.
106108
@@ -131,6 +133,7 @@ def optimize_python_code( # noqa: D417
131133
"current_username": get_last_commit_author_if_pr_exists(None),
132134
"repo_owner": git_repo_owner,
133135
"repo_name": git_repo_name,
136+
"is_async": is_async,
134137
}
135138

136139
logger.info("Generating optimized candidates…")
@@ -295,6 +298,9 @@ def get_new_explanation( # noqa: D417
295298
annotated_tests: str,
296299
optimization_id: str,
297300
original_explanation: str,
301+
original_throughput: str | None = None,
302+
optimized_throughput: str | None = None,
303+
throughput_improvement: str | None = None,
298304
) -> str:
299305
"""Optimize the given python code for performance by making a request to the Django endpoint.
300306
@@ -311,6 +317,9 @@ def get_new_explanation( # noqa: D417
311317
- annotated_tests: str - test functions annotated with runtime
312318
- optimization_id: str - unique id of opt candidate
313319
- original_explanation: str - original_explanation generated for the opt candidate
320+
- original_throughput: str | None - throughput for the baseline code (operations per second)
321+
- optimized_throughput: str | None - throughput for the optimized code (operations per second)
322+
- throughput_improvement: str | None - throughput improvement percentage
314323
315324
Returns
316325
-------
@@ -330,6 +339,9 @@ def get_new_explanation( # noqa: D417
330339
"optimization_id": optimization_id,
331340
"original_explanation": original_explanation,
332341
"dependency_code": dependency_code,
342+
"original_throughput": original_throughput,
343+
"optimized_throughput": optimized_throughput,
344+
"throughput_improvement": throughput_improvement,
333345
}
334346
logger.info("Generating explanation")
335347
console.rule()
@@ -439,7 +451,9 @@ def generate_regression_tests( # noqa: D417
439451
"test_index": test_index,
440452
"python_version": platform.python_version(),
441453
"codeflash_version": codeflash_version,
454+
"is_async": function_to_optimize.is_async,
442455
}
456+
443457
try:
444458
response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)
445459
except requests.exceptions.RequestException as e:

codeflash/code_utils/code_extractor.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
272272
if child.module is None:
273273
continue
274274
module = self.get_full_dotted_name(child.module)
275+
if isinstance(child.names, cst.ImportStar):
276+
continue
275277
for alias in child.names:
276278
if isinstance(alias, cst.ImportAlias):
277279
name = alias.name.value
@@ -403,6 +405,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
403405
return transformed_module.code
404406

405407

408+
def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
409+
try:
410+
module_path = module_name.replace(".", "/")
411+
possible_paths = [project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py"]
412+
413+
module_file = None
414+
for path in possible_paths:
415+
if path.exists():
416+
module_file = path
417+
break
418+
419+
if module_file is None:
420+
logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
421+
return set()
422+
423+
with module_file.open(encoding="utf8") as f:
424+
module_code = f.read()
425+
426+
tree = ast.parse(module_code)
427+
428+
all_names = None
429+
for node in ast.walk(tree):
430+
if (
431+
isinstance(node, ast.Assign)
432+
and len(node.targets) == 1
433+
and isinstance(node.targets[0], ast.Name)
434+
and node.targets[0].id == "__all__"
435+
):
436+
if isinstance(node.value, (ast.List, ast.Tuple)):
437+
all_names = []
438+
for elt in node.value.elts:
439+
if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
440+
all_names.append(elt.value)
441+
elif isinstance(elt, ast.Str): # Python < 3.8 compatibility
442+
all_names.append(elt.s)
443+
break
444+
445+
if all_names is not None:
446+
return set(all_names)
447+
448+
public_names = set()
449+
for node in tree.body:
450+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
451+
if not node.name.startswith("_"):
452+
public_names.add(node.name)
453+
elif isinstance(node, ast.Assign):
454+
for target in node.targets:
455+
if isinstance(target, ast.Name) and not target.id.startswith("_"):
456+
public_names.add(target.id)
457+
elif isinstance(node, ast.AnnAssign):
458+
if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
459+
public_names.add(node.target.id)
460+
elif isinstance(node, ast.Import) or (
461+
isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
462+
):
463+
for alias in node.names:
464+
name = alias.asname or alias.name
465+
if not name.startswith("_"):
466+
public_names.add(name)
467+
468+
return public_names # noqa: TRY300
469+
470+
except Exception as e:
471+
logger.warning(f"Error resolving star import for {module_name}: {e}")
472+
return set()
473+
474+
406475
def add_needed_imports_from_module(
407476
src_module_code: str,
408477
dst_module_code: str,
@@ -457,9 +526,23 @@ def add_needed_imports_from_module(
457526
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
458527
):
459528
continue # Skip adding imports for helper functions already in the context
460-
if f"{mod}.{obj}" not in dotted_import_collector.imports:
461-
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
462-
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
529+
530+
# Handle star imports by resolving them to actual symbol names
531+
if obj == "*":
532+
resolved_symbols = resolve_star_import(mod, project_root)
533+
logger.debug(f"Resolved star import from {mod}: {resolved_symbols}")
534+
535+
for symbol in resolved_symbols:
536+
if (
537+
f"{mod}.{symbol}" not in helper_functions_fqn
538+
and f"{mod}.{symbol}" not in dotted_import_collector.imports
539+
):
540+
AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
541+
RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
542+
else:
543+
if f"{mod}.{obj}" not in dotted_import_collector.imports:
544+
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
545+
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
463546
except Exception as e:
464547
logger.exception(f"Error adding imports to destination module code: {e}")
465548
return dst_module_code

0 commit comments

Comments
 (0)