Skip to content

Commit 9aa34d9

Browse files
authored
Merge pull request #752 from codeflash-ai/end-to-end-test
Add end-to-end test for async optimization
2 parents d8849a5 + 43615bd commit 9aa34d9

File tree

6 files changed

+153
-14
lines changed

6 files changed

+153
-14
lines changed

.github/workflows/e2e-async.yaml

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
name: E2E - Async
2+
3+
on:
4+
pull_request:
5+
paths:
6+
- '**' # Trigger for all paths
7+
8+
workflow_dispatch:
9+
10+
jobs:
11+
async-optimization:
12+
# Require the 'external-trusted-contributors' environment for manual dispatches, or when a PR touches workflow files and its author is not an allowlisted maintainer; otherwise use no environment
13+
environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
14+
15+
runs-on: ubuntu-latest
16+
env:
17+
CODEFLASH_AIS_SERVER: prod
18+
POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
19+
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
20+
COLUMNS: 110
21+
MAX_RETRIES: 3
22+
RETRY_DELAY: 5
23+
EXPECTED_IMPROVEMENT_PCT: 10
24+
CODEFLASH_END_TO_END: 1
25+
steps:
26+
- name: 🛎️ Checkout
27+
uses: actions/checkout@v4
28+
with:
29+
ref: ${{ github.event.pull_request.head.ref }}
30+
repository: ${{ github.event.pull_request.head.repo.full_name }}
31+
fetch-depth: 0
32+
token: ${{ secrets.GITHUB_TOKEN }}
33+
34+
- name: Validate PR
35+
run: |
36+
# Check for any workflow changes
37+
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
38+
echo "⚠️ Workflow changes detected."
39+
40+
# Get the PR author
41+
AUTHOR="${{ github.event.pull_request.user.login }}"
42+
echo "PR Author: $AUTHOR"
43+
44+
# Allowlist check
45+
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
46+
echo "✅ Authorized user ($AUTHOR). Proceeding."
47+
elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
48+
echo "✅ PR triggered by 'pull_request' and is open. Assuming protection rules are in place. Proceeding."
49+
else
50+
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
51+
exit 1
52+
fi
53+
else
54+
echo "✅ No workflow file changes detected. Proceeding."
55+
fi
56+
57+
- name: Set up Python 3.11 for CLI
58+
uses: astral-sh/setup-uv@v5
59+
with:
60+
python-version: 3.11.6
61+
62+
- name: Install dependencies (CLI)
63+
run: |
64+
uv sync
65+
66+
- name: Run Codeflash to optimize async code
67+
id: optimize_async_code
68+
run: |
69+
uv run python tests/scripts/end_to_end_test_async.py
Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
11
import time
2-
async def fake_api_call(delay, data):
3-
time.sleep(0.0001)
4-
return f"Processed: {data}"
2+
import asyncio
3+
4+
5+
async def retry_with_backoff(func, max_retries=3):
6+
if max_retries < 1:
7+
raise ValueError("max_retries must be at least 1")
8+
last_exception = None
9+
for attempt in range(max_retries):
10+
try:
11+
return await func()
12+
except Exception as e:
13+
last_exception = e
14+
if attempt < max_retries - 1:
15+
time.sleep(0.0001 * attempt)
16+
raise last_exception

codeflash/optimization/function_optimizer.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,15 @@ def determine_best_candidate(
658658
)
659659
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
660660
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
661+
if (
662+
original_code_baseline.async_throughput is not None
663+
and candidate_result.async_throughput is not None
664+
):
665+
throughput_gain_value = throughput_gain(
666+
original_throughput=original_code_baseline.async_throughput,
667+
optimized_throughput=candidate_result.async_throughput,
668+
)
669+
tree.add(f"Throughput gain: {throughput_gain_value * 100:.1f}%")
661670
console.print(tree)
662671
if self.args.benchmark and benchmark_tree:
663672
console.print(benchmark_tree)
@@ -1199,6 +1208,8 @@ def find_and_process_best_optimization(
11991208
function_name=function_to_optimize_qualified_name,
12001209
file_path=self.function_to_optimize.file_path,
12011210
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
1211+
original_async_throughput=original_code_baseline.async_throughput,
1212+
best_async_throughput=best_optimization.async_throughput,
12021213
)
12031214

12041215
self.replace_function_and_helpers_with_optimized_code(
@@ -1284,7 +1295,7 @@ def process_review(
12841295
original_throughput_str = None
12851296
optimized_throughput_str = None
12861297
throughput_improvement_str = None
1287-
1298+
12881299
if (
12891300
self.function_to_optimize.is_async
12901301
and original_code_baseline.async_throughput is not None
@@ -1297,7 +1308,7 @@ def process_review(
12971308
optimized_throughput=best_optimization.async_throughput,
12981309
)
12991310
throughput_improvement_str = f"{throughput_improvement_value * 100:.1f}%"
1300-
1311+
13011312
new_explanation_raw_str = self.aiservice_client.get_new_explanation(
13021313
source_code=code_context.read_writable_code.flat,
13031314
dependency_code=code_context.read_only_context_code,
@@ -1324,6 +1335,8 @@ def process_review(
13241335
function_name=explanation.function_name,
13251336
file_path=explanation.file_path,
13261337
benchmark_details=explanation.benchmark_details,
1338+
original_async_throughput=explanation.original_async_throughput,
1339+
best_async_throughput=explanation.best_async_throughput,
13271340
)
13281341
self.log_successful_optimization(new_explanation, generated_tests, exp_type)
13291342

@@ -1551,7 +1564,8 @@ def establish_original_code_baseline(
15511564
async_throughput = calculate_function_throughput_from_test_results(
15521565
benchmarking_results, self.function_to_optimize.function_name
15531566
)
1554-
logger.info(f"Original async function throughput: {async_throughput} calls/second")
1567+
logger.debug(f"Original async function throughput: {async_throughput} calls/second")
1568+
console.rule()
15551569

15561570
if self.args.benchmark:
15571571
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(

codeflash/result/critic.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,6 @@ def speedup_critic(
8282
original_throughput=original_async_throughput, optimized_throughput=candidate_result.async_throughput
8383
)
8484
throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
85-
logger.info(
86-
f"Async throughput gain: {throughput_gain_value * 100:.1f}% (original: {original_async_throughput}, optimized: {candidate_result.async_throughput})"
87-
)
8885

8986
throughput_is_best = (
9087
best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now

codeflash/result/explanation.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from codeflash.code_utils.time_utils import humanize_runtime
1313
from codeflash.models.models import BenchmarkDetail, TestResults
14+
from codeflash.result.critic import performance_gain, throughput_gain
1415

1516

1617
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
@@ -23,9 +24,29 @@ class Explanation:
2324
function_name: str
2425
file_path: Path
2526
benchmark_details: Optional[list[BenchmarkDetail]] = None
27+
original_async_throughput: Optional[int] = None
28+
best_async_throughput: Optional[int] = None
2629

2730
@property
2831
def perf_improvement_line(self) -> str:
32+
runtime_improvement = self.speedup
33+
34+
if (
35+
self.original_async_throughput is not None
36+
and self.best_async_throughput is not None
37+
and self.original_async_throughput > 0
38+
):
39+
throughput_improvement = throughput_gain(
40+
original_throughput=self.original_async_throughput,
41+
optimized_throughput=self.best_async_throughput,
42+
)
43+
44+
# Use throughput metrics if throughput improvement is better or runtime got worse
45+
if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
46+
throughput_pct = f"{throughput_improvement * 100:,.0f}%"
47+
throughput_x = f"{throughput_improvement + 1:,.2f}x"
48+
return f"{throughput_pct} improvement ({throughput_x} faster)."
49+
2950
return f"{self.speedup_pct} improvement ({self.speedup_x} faster)."
3051

3152
@property
@@ -45,6 +66,24 @@ def to_console_string(self) -> str:
4566
# TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts
4667
original_runtime_human = humanize_runtime(self.original_runtime_ns)
4768
best_runtime_human = humanize_runtime(self.best_runtime_ns)
69+
70+
# Determine if we're showing throughput or runtime improvements
71+
runtime_improvement = self.speedup
72+
is_using_throughput_metric = False
73+
74+
if (
75+
self.original_async_throughput is not None
76+
and self.best_async_throughput is not None
77+
and self.original_async_throughput > 0
78+
):
79+
throughput_improvement = throughput_gain(
80+
original_throughput=self.original_async_throughput,
81+
optimized_throughput=self.best_async_throughput,
82+
)
83+
84+
if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
85+
is_using_throughput_metric = True
86+
4887
benchmark_info = ""
4988

5089
if self.benchmark_details:
@@ -85,10 +124,18 @@ def to_console_string(self) -> str:
85124
console.print(table)
86125
benchmark_info = cast("StringIO", console.file).getvalue() + "\n" # Cast for mypy
87126

127+
if is_using_throughput_metric:
128+
performance_description = (
129+
f"Throughput improved from {self.original_async_throughput} to {self.best_async_throughput} operations/second "
130+
f"(runtime: {original_runtime_human} → {best_runtime_human})\n\n"
131+
)
132+
else:
133+
performance_description = f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
134+
88135
return (
89136
f"Optimized {self.function_name} in {self.file_path}\n"
90137
f"{self.perf_improvement_line}\n"
91-
f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
138+
+ performance_description
92139
+ (benchmark_info if benchmark_info else "")
93140
+ self.raw_explanation_message
94141
+ " \n\n"

tests/scripts/end_to_end_test_async.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
def run_test(expected_improvement_pct: int) -> bool:
88
config = TestConfig(
9-
file_path="workload.py",
10-
expected_unit_tests=1,
9+
file_path="main.py",
10+
expected_unit_tests=0,
1111
min_improvement_x=0.1,
1212
coverage_expectations=[
1313
CoverageExpectation(
14-
function_name="process_data_list",
14+
function_name="retry_with_backoff",
1515
expected_coverage=100.0,
16-
expected_lines=[5, 7, 8, 9, 10, 12],
16+
expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
1717
)
1818
],
1919
)

0 commit comments

Comments
 (0)