Skip to content

Commit f978a40

Browse files
committed
Merge branch 'main' into part-1-windows-fixes
2 parents 2c504ee + 565d65b commit f978a40

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+5835
-839
lines changed

.github/workflows/e2e-async.yaml

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
name: E2E - Async

# Runs the async end-to-end optimization script on pull requests and on manual dispatch.
on:
  pull_request:
    paths:
      - '**' # Trigger for all paths

  workflow_dispatch:

jobs:
  async-optimization:
    # Dynamically determine if environment is needed only when workflow files change and contributor is external
    environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}

    runs-on: ubuntu-latest
    env:
      CODEFLASH_AIS_SERVER: prod
      POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
      CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
      COLUMNS: 110
      MAX_RETRIES: 3
      RETRY_DELAY: 5
      EXPECTED_IMPROVEMENT_PCT: 10
      CODEFLASH_END_TO_END: 1
    steps:
      - name: 🛎️ Checkout
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.head.ref }}
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Validate PR
        run: |
          # Check for any workflow changes.
          # The dot is escaped so only the literal ".github/workflows/" prefix matches.
          if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^\.github/workflows/"; then
            echo "⚠️ Workflow changes detected."

            # Get the PR author
            AUTHOR="${{ github.event.pull_request.user.login }}"
            echo "PR Author: $AUTHOR"

            # Allowlist check
            if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
              echo "✅ Authorized user ($AUTHOR). Proceeding."
            elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
              # This workflow is triggered by 'pull_request' (see the `on:` section above).
              echo "✅ PR triggered by 'pull_request' and is open. Assuming protection rules are in place. Proceeding."
            else
              echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
              exit 1
            fi
          else
            echo "✅ No workflow file changes detected. Proceeding."
          fi

      - name: Set up Python 3.11 for CLI
        uses: astral-sh/setup-uv@v5
        with:
          python-version: 3.11.6

      - name: Install dependencies (CLI)
        run: |
          uv sync

      - name: Run Codeflash to optimize async code
        id: optimize_async_code
        run: |
          uv run python tests/scripts/end_to_end_test_async.py

.github/workflows/e2e-bubblesort-pytest-nogit.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
COLUMNS: 110
2121
MAX_RETRIES: 3
2222
RETRY_DELAY: 5
23-
EXPECTED_IMPROVEMENT_PCT: 300
23+
EXPECTED_IMPROVEMENT_PCT: 70
2424
CODEFLASH_END_TO_END: 1
2525
steps:
2626
- name: 🛎️ Checkout

.github/workflows/e2e-bubblesort-unittest.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
COLUMNS: 110
2121
MAX_RETRIES: 3
2222
RETRY_DELAY: 5
23-
EXPECTED_IMPROVEMENT_PCT: 300
23+
EXPECTED_IMPROVEMENT_PCT: 40
2424
CODEFLASH_END_TO_END: 1
2525
steps:
2626
- name: 🛎️ Checkout
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import asyncio
2+
from typing import List, Union
3+
4+
5+
async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
    """Bubble-sort ``lst`` in place and return a copy of the sorted list.

    Prints a marker line to stdout, awaits a 10 ms ``asyncio.sleep``, sorts the
    caller's list in ascending order, then prints and returns a shallow copy
    of it. The input list itself is left sorted.
    """
    print("codeflash stdout: Async sorting list")

    await asyncio.sleep(0.01)

    # Each pass bubbles the largest remaining value up to index ``last``.
    for last in range(len(lst) - 1, -1, -1):
        for left in range(last):
            right = left + 1
            if lst[left] > lst[right]:
                lst[left], lst[right] = lst[right], lst[left]

    result = lst.copy()
    print(f"result: {result}")
    return result
22+
23+
24+
class AsyncBubbleSorter:
    """Test fixture exposing bubble sort as an async method."""

    async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
        """Bubble-sort ``lst`` in place and return a shallow copy of it.

        Prints a marker line to stdout and awaits a 5 ms ``asyncio.sleep``
        before sorting. The caller's list is left sorted in ascending order.
        """
        print("codeflash stdout: AsyncBubbleSorter.sorter() called")

        # Short async pause before the sort starts.
        await asyncio.sleep(0.005)

        # ``remaining`` is the length of the still-unsorted prefix; every pass
        # moves its largest value to the end of that prefix.
        remaining = len(lst)
        while remaining > 0:
            for pos in range(remaining - 1):
                if lst[pos] > lst[pos + 1]:
                    lst[pos], lst[pos + 1] = lst[pos + 1], lst[pos]
            remaining -= 1

        return lst.copy()
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import time
2+
import asyncio
3+
4+
5+
async def retry_with_backoff(func, max_retries=3):
    """Await ``func()`` until it succeeds or ``max_retries`` attempts have failed.

    Args:
        func: Zero-argument callable whose return value is awaited.
        max_retries: Maximum number of attempts; must be at least 1.

    Returns:
        The result of the first ``await func()`` that does not raise.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: The exception raised by the final failed attempt.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    last_exception = None
    for attempt in range(max_retries):
        try:
            return await func()
        except Exception as e:
            last_exception = e
            if attempt < max_retries - 1:
                # Back-off grows linearly with the attempt index (0 s, 0.1 ms, ...).
                # asyncio.sleep suspends only this coroutine; a blocking
                # time.sleep here would stall every task on the event loop.
                await asyncio.sleep(0.0001 * attempt)
    raise last_exception
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[tool.codeflash]
2+
disable-telemetry = true
3+
formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
4+
module-root = "."
5+
test-framework = "pytest"
6+
tests-root = "tests"

code_to_optimize/code_directories/async_e2e/tests/__init__.py

Whitespace-only changes.

codeflash/api/aiservice.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def optimize_python_code( # noqa: D417
102102
trace_id: str,
103103
num_candidates: int = 10,
104104
experiment_metadata: ExperimentMetadata | None = None,
105+
*,
106+
is_async: bool = False,
105107
) -> list[OptimizedCandidate]:
106108
"""Optimize the given python code for performance by making a request to the Django endpoint.
107109
@@ -133,6 +135,7 @@ def optimize_python_code( # noqa: D417
133135
"repo_owner": git_repo_owner,
134136
"repo_name": git_repo_name,
135137
"n_candidates": N_CANDIDATES_EFFECTIVE,
138+
"is_async": is_async,
136139
}
137140

138141
logger.info("!lsp|Generating optimized candidates…")
@@ -299,6 +302,9 @@ def get_new_explanation( # noqa: D417
299302
annotated_tests: str,
300303
optimization_id: str,
301304
original_explanation: str,
305+
original_throughput: str | None = None,
306+
optimized_throughput: str | None = None,
307+
throughput_improvement: str | None = None,
302308
) -> str:
303309
"""Optimize the given python code for performance by making a request to the Django endpoint.
304310
@@ -315,6 +321,9 @@ def get_new_explanation( # noqa: D417
315321
- annotated_tests: str - test functions annotated with runtime
316322
- optimization_id: str - unique id of opt candidate
317323
- original_explanation: str - original_explanation generated for the opt candidate
324+
- original_throughput: str | None - throughput for the baseline code (operations per second)
325+
- optimized_throughput: str | None - throughput for the optimized code (operations per second)
326+
- throughput_improvement: str | None - throughput improvement percentage
318327
319328
Returns
320329
-------
@@ -334,6 +343,9 @@ def get_new_explanation( # noqa: D417
334343
"optimization_id": optimization_id,
335344
"original_explanation": original_explanation,
336345
"dependency_code": dependency_code,
346+
"original_throughput": original_throughput,
347+
"optimized_throughput": optimized_throughput,
348+
"throughput_improvement": throughput_improvement,
337349
}
338350
logger.info("loading|Generating explanation")
339351
console.rule()
@@ -488,6 +500,7 @@ def generate_regression_tests( # noqa: D417
488500
"test_index": test_index,
489501
"python_version": platform.python_version(),
490502
"codeflash_version": codeflash_version,
503+
"is_async": function_to_optimize.is_async,
491504
}
492505
try:
493506
response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)

codeflash/cli_cmds/cli.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import importlib.util
12
import logging
23
import sys
34
from argparse import SUPPRESS, ArgumentParser, Namespace
@@ -96,6 +97,12 @@ def parse_args() -> Namespace:
9697
)
9798
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
9899
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
100+
parser.add_argument(
101+
"--async",
102+
default=False,
103+
action="store_true",
104+
help="Enable optimization of async functions. By default, async functions are excluded from optimization.",
105+
)
99106

100107
args, unknown_args = parser.parse_known_args()
101108
sys.argv[:] = [sys.argv[0], *unknown_args]
@@ -139,6 +146,14 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
139146
if env_utils.is_ci():
140147
args.no_pr = True
141148

149+
if getattr(args, "async", False) and importlib.util.find_spec("pytest_asyncio") is None:
150+
logger.warning(
151+
"Warning: The --async flag requires pytest-asyncio to be installed.\n"
152+
"Please install it using:\n"
153+
' pip install "codeflash[asyncio]"'
154+
)
155+
raise SystemExit(1)
156+
142157
return args
143158

144159

codeflash/code_utils/code_extractor.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
272272
if child.module is None:
273273
continue
274274
module = self.get_full_dotted_name(child.module)
275+
if isinstance(child.names, cst.ImportStar):
276+
continue
275277
for alias in child.names:
276278
if isinstance(alias, cst.ImportAlias):
277279
name = alias.name.value
@@ -414,6 +416,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
414416
return transformed_module.code
415417

416418

419+
def _explicit_all_names(tree: ast.Module) -> "set[str] | None":
    """Return the string entries of the first ``__all__ = [...]`` assignment, or None."""
    for node in ast.walk(tree):
        is_all_assignment = (
            isinstance(node, ast.Assign)
            and len(node.targets) == 1
            and isinstance(node.targets[0], ast.Name)
            and node.targets[0].id == "__all__"
        )
        if not is_all_assignment:
            continue
        if isinstance(node.value, (ast.List, ast.Tuple)):
            # String literals are ast.Constant on Python 3.8+; non-string or
            # computed entries cannot be resolved statically and are skipped.
            return {
                elt.value
                for elt in node.value.elts
                if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
            }
        # Only the first ``__all__`` assignment is considered; a non-literal
        # value means the caller falls back to the public-name scan.
        return None
    return None


def _public_top_level_names(tree: ast.Module) -> set[str]:
    """Collect the non-underscore names bound by top-level statements of ``tree``."""
    public_names: set[str] = set()
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            bound = [node.name]
        elif isinstance(node, ast.Assign):
            bound = [target.id for target in node.targets if isinstance(target, ast.Name)]
        elif isinstance(node, ast.AnnAssign):
            bound = [node.target.id] if isinstance(node.target, ast.Name) else []
        elif isinstance(node, ast.Import):
            # ``import a.b`` binds the top-level package ``a``, not ``a.b``.
            bound = [alias.asname or alias.name.split(".")[0] for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names):
            bound = [alias.asname or alias.name for alias in node.names]
        else:
            continue
        public_names.update(name for name in bound if not name.startswith("_"))
    return public_names


def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
    """Statically resolve the names that ``from <module_name> import *`` would provide.

    The module is looked up under ``project_root`` as ``<dotted/path>.py`` and then
    as ``<dotted/path>/__init__.py``, and parsed with ``ast`` without being imported.
    If it assigns a list or tuple literal to ``__all__``, the string entries of that
    literal are returned. Otherwise the result is every top-level function, class,
    assigned name and imported name that does not start with an underscore.

    Args:
        module_name: Dotted module name, e.g. ``"pkg.sub"``.
        project_root: Directory the dotted name is resolved against.

    Returns:
        The resolved symbol names. An empty set is returned, with a logged warning,
        when the module file is not found or any error occurs while resolving it.
    """
    try:
        module_path = module_name.replace(".", "/")
        possible_paths = (project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py")
        module_file = next((path for path in possible_paths if path.exists()), None)

        if module_file is None:
            logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
            return set()

        tree = ast.parse(module_file.read_text(encoding="utf8"))

        all_names = _explicit_all_names(tree)
        if all_names is not None:
            return all_names

        return _public_top_level_names(tree)  # noqa: TRY300

    except Exception as e:
        # Best-effort: star-import resolution must not abort the surrounding import handling.
        logger.warning(f"Error resolving star import for {module_name}: {e}")
        return set()
484+
485+
417486
def add_needed_imports_from_module(
418487
src_module_code: str,
419488
dst_module_code: str,
@@ -468,9 +537,23 @@ def add_needed_imports_from_module(
468537
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
469538
):
470539
continue # Skip adding imports for helper functions already in the context
471-
if f"{mod}.{obj}" not in dotted_import_collector.imports:
472-
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
473-
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
540+
541+
# Handle star imports by resolving them to actual symbol names
542+
if obj == "*":
543+
resolved_symbols = resolve_star_import(mod, project_root)
544+
logger.debug(f"Resolved star import from {mod}: {resolved_symbols}")
545+
546+
for symbol in resolved_symbols:
547+
if (
548+
f"{mod}.{symbol}" not in helper_functions_fqn
549+
and f"{mod}.{symbol}" not in dotted_import_collector.imports
550+
):
551+
AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
552+
RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
553+
else:
554+
if f"{mod}.{obj}" not in dotted_import_collector.imports:
555+
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
556+
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
474557
except Exception as e:
475558
logger.exception(f"Error adding imports to destination module code: {e}")
476559
return dst_module_code

0 commit comments

Comments
 (0)