Skip to content

Commit 2f46cc7

Browse files
authored
Merge pull request #990 from codeflash-ai/diversity
let's add diversity to our optimizations
2 parents 7606c8e + 1be8302 commit 2f46cc7

File tree

4 files changed: 58 additions (+58) and 37 deletions (-37).

codeflash/api/aiservice.py

Lines changed: 32 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import os
55
import platform
66
import time
7+
from itertools import count
78
from typing import TYPE_CHECKING, Any, cast
89

910
import requests
@@ -12,7 +13,6 @@
1213
from codeflash.cli_cmds.console import console, logger
1314
from codeflash.code_utils.code_replacer import is_zero_diff
1415
from codeflash.code_utils.code_utils import unified_diff_strings
15-
from codeflash.code_utils.config_consts import N_CANDIDATES_EFFECTIVE, N_CANDIDATES_LP_EFFECTIVE
1616
from codeflash.code_utils.env_utils import get_codeflash_api_key
1717
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
1818
from codeflash.code_utils.time_utils import humanize_runtime
@@ -40,6 +40,11 @@ class AiServiceClient:
4040
def __init__(self) -> None:
4141
self.base_url = self.get_aiservice_base_url()
4242
self.headers = {"Authorization": f"Bearer {get_codeflash_api_key()}", "Connection": "close"}
43+
self.llm_call_counter = count(1)
44+
45+
def get_next_sequence(self) -> int:
46+
"""Get the next LLM call sequence number."""
47+
return next(self.llm_call_counter)
4348

4449
def get_aiservice_base_url(self) -> str:
4550
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
@@ -106,6 +111,7 @@ def _get_valid_candidates(
106111
optimization_id=opt["optimization_id"],
107112
source=source,
108113
parent_id=opt.get("parent_id", None),
114+
model=opt.get("model"),
109115
)
110116
)
111117
return candidates
@@ -115,7 +121,6 @@ def optimize_python_code( # noqa: D417
115121
source_code: str,
116122
dependency_code: str,
117123
trace_id: str,
118-
num_candidates: int = 10,
119124
experiment_metadata: ExperimentMetadata | None = None,
120125
*,
121126
is_async: bool = False,
@@ -127,46 +132,49 @@ def optimize_python_code( # noqa: D417
127132
- source_code (str): The python code to optimize.
128133
- dependency_code (str): The dependency code used as read-only context for the optimization
129134
- trace_id (str): Trace id of optimization run
130-
- num_candidates (int): Number of optimization variants to generate. Default is 10.
131135
- experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
136+
- is_async (bool): Whether the function being optimized is async
132137
133138
Returns
134139
-------
135140
- List[OptimizationCandidate]: A list of Optimization Candidates.
136141
137142
"""
143+
logger.info("Generating optimized candidates…")
144+
console.rule()
138145
start_time = time.perf_counter()
139146
git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()
140147

141148
payload = {
142149
"source_code": source_code,
143150
"dependency_code": dependency_code,
144-
"num_variants": num_candidates,
145151
"trace_id": trace_id,
146152
"python_version": platform.python_version(),
147153
"experiment_metadata": experiment_metadata,
148154
"codeflash_version": codeflash_version,
149155
"current_username": get_last_commit_author_if_pr_exists(None),
150156
"repo_owner": git_repo_owner,
151157
"repo_name": git_repo_name,
152-
"n_candidates": N_CANDIDATES_EFFECTIVE,
153158
"is_async": is_async,
159+
"lsp_mode": is_LSP_enabled(),
160+
"call_sequence": self.get_next_sequence(),
154161
}
162+
logger.debug(f"Sending optimize request: trace_id={trace_id}, lsp_mode={payload['lsp_mode']}")
155163

156-
logger.info("!lsp|Generating optimized candidates…")
157-
console.rule()
158164
try:
159165
response = self.make_ai_service_request("/optimize", payload=payload, timeout=60)
160166
except requests.exceptions.RequestException as e:
161167
logger.exception(f"Error generating optimized candidates: {e}")
162168
ph("cli-optimize-error-caught", {"error": str(e)})
169+
console.rule()
163170
return []
164171

165172
if response.status_code == 200:
166173
optimizations_json = response.json()["optimizations"]
167-
console.rule()
168174
end_time = time.perf_counter()
169175
logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")
176+
logger.info(f"!lsp|Received {len(optimizations_json)} optimization candidates.")
177+
console.rule()
170178
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE)
171179
try:
172180
error = response.json()["error"]
@@ -183,54 +191,53 @@ def optimize_python_code_line_profiler( # noqa: D417
183191
dependency_code: str,
184192
trace_id: str,
185193
line_profiler_results: str,
186-
num_candidates: int = 10,
187194
experiment_metadata: ExperimentMetadata | None = None,
188195
) -> list[OptimizedCandidate]:
189-
"""Optimize the given python code for performance by making a request to the Django endpoint.
196+
"""Optimize the given python code for performance using line profiler results.
190197
191198
Parameters
192199
----------
193200
- source_code (str): The python code to optimize.
194201
- dependency_code (str): The dependency code used as read-only context for the optimization
195202
- trace_id (str): Trace id of optimization run
196-
- num_candidates (int): Number of optimization variants to generate. Default is 10.
203+
- line_profiler_results (str): Line profiler output to guide optimization
197204
- experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
198205
199206
Returns
200207
-------
201208
- List[OptimizationCandidate]: A list of Optimization Candidates.
202209
203210
"""
211+
if line_profiler_results == "":
212+
logger.info("No LineProfiler results were provided, Skipping optimization.")
213+
return []
214+
215+
logger.info("Generating optimized candidates with line profiler…")
216+
console.rule()
217+
204218
payload = {
205219
"source_code": source_code,
206220
"dependency_code": dependency_code,
207-
"num_variants": num_candidates,
208221
"line_profiler_results": line_profiler_results,
209222
"trace_id": trace_id,
210223
"python_version": platform.python_version(),
211224
"experiment_metadata": experiment_metadata,
212225
"codeflash_version": codeflash_version,
213226
"lsp_mode": is_LSP_enabled(),
214-
"n_candidates_lp": N_CANDIDATES_LP_EFFECTIVE,
227+
"call_sequence": self.get_next_sequence(),
215228
}
216229

217-
console.rule()
218-
if line_profiler_results == "":
219-
logger.info("No LineProfiler results were provided, Skipping optimization.")
220-
console.rule()
221-
return []
222230
try:
223231
response = self.make_ai_service_request("/optimize-line-profiler", payload=payload, timeout=60)
224232
except requests.exceptions.RequestException as e:
225233
logger.exception(f"Error generating optimized candidates: {e}")
226234
ph("cli-optimize-error-caught", {"error": str(e)})
235+
console.rule()
227236
return []
228237

229238
if response.status_code == 200:
230239
optimizations_json = response.json()["optimizations"]
231-
logger.info(
232-
f"!lsp|Generated {len(optimizations_json)} candidate optimizations using line profiler information."
233-
)
240+
logger.info(f"!lsp|Received {len(optimizations_json)} line profiler optimization candidates.")
234241
console.rule()
235242
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE_LP)
236243
try:
@@ -268,6 +275,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
268275
"trace_id": opt.trace_id,
269276
"function_references": opt.function_references,
270277
"python_version": platform.python_version(),
278+
"call_sequence": self.get_next_sequence(),
271279
}
272280
for opt in request
273281
]
@@ -402,6 +410,7 @@ def get_new_explanation( # noqa: D417
402410
"throughput_improvement": throughput_improvement,
403411
"function_references": function_references,
404412
"codeflash_version": codeflash_version,
413+
"call_sequence": self.get_next_sequence(),
405414
}
406415
logger.info("loading|Generating explanation")
407416
console.rule()
@@ -564,6 +573,7 @@ def generate_regression_tests( # noqa: D417
564573
"python_version": platform.python_version(),
565574
"codeflash_version": codeflash_version,
566575
"is_async": function_to_optimize.is_async,
576+
"call_sequence": self.get_next_sequence(),
567577
}
568578
try:
569579
response = self.make_ai_service_request("/testgen", payload=payload, timeout=90)
@@ -650,6 +660,7 @@ def get_optimization_review(
650660
"codeflash_version": codeflash_version,
651661
"calling_fn_details": calling_fn_details,
652662
"python_version": platform.python_version(),
663+
"call_sequence": self.get_next_sequence(),
653664
}
654665
console.rule()
655666
try:

codeflash/discovery/discover_unit_tests.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -751,6 +751,7 @@ def process_test_files(
751751

752752
tests_cache = TestsCache(project_root_path)
753753
logger.info("!lsp|Discovering tests and processing unit tests")
754+
console.rule()
754755
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
755756
progress,
756757
task_id,

codeflash/models/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class AIServiceRefinerRequest:
4646
original_line_profiler_results: str
4747
optimized_line_profiler_results: str
4848
function_references: str | None = None
49+
call_sequence: int | None = None
4950

5051

5152
class TestDiffScope(str, Enum):
@@ -464,6 +465,7 @@ class OptimizedCandidate:
464465
optimization_id: str
465466
source: OptimizedCandidateSource
466467
parent_id: str | None = None
468+
model: str | None = None # Which LLM model generated this candidate
467469

468470

469471
@dataclass(frozen=True)

codeflash/optimization/function_optimizer.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@
4646
COVERAGE_THRESHOLD,
4747
INDIVIDUAL_TESTCASE_TIMEOUT,
4848
MAX_REPAIRS_PER_TRACE,
49-
N_CANDIDATES_EFFECTIVE,
50-
N_CANDIDATES_LP_EFFECTIVE,
5149
N_TESTS_TO_GENERATE_EFFECTIVE,
5250
REFINE_ALL_THRESHOLD,
5351
REFINED_CANDIDATE_RANKING_WEIGHTS,
@@ -146,6 +144,7 @@ def __init__(
146144
self.candidate_len = len(initial_candidates)
147145
self.ai_service_client = ai_service_client
148146
self.executor = executor
147+
self.refinement_calls_count = 0
149148

150149
# Initialize queue with initial candidates
151150
for candidate in initial_candidates:
@@ -155,6 +154,9 @@ def __init__(
155154
self.all_refinements_data = all_refinements_data
156155
self.future_all_code_repair = future_all_code_repair
157156

157+
def get_total_llm_calls(self) -> int:
158+
return self.refinement_calls_count
159+
158160
def get_next_candidate(self) -> OptimizedCandidate | None:
159161
"""Get the next candidate from the queue, handling async results as needed."""
160162
try:
@@ -193,10 +195,12 @@ def refine_optimizations(self, request: list[AIServiceRefinerRequest]) -> concur
193195
def _process_refinement_results(self) -> OptimizedCandidate | None:
194196
"""Process refinement results and add to queue. We generate a weighted ranking based on the runtime and diff lines and select the best (round of 45%) of valid optimizations to be refined."""
195197
future_refinements: list[concurrent.futures.Future] = []
198+
refinement_call_index = 0
196199

197200
if len(self.all_refinements_data) <= REFINE_ALL_THRESHOLD:
198201
for data in self.all_refinements_data:
199-
future_refinements.append(self.refine_optimizations([data])) # noqa: PERF401
202+
refinement_call_index += 1
203+
future_refinements.append(self.refine_optimizations([data]))
200204
else:
201205
diff_lens_list = []
202206
runtimes_list = []
@@ -215,9 +219,13 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
215219
top_indecies = sorted(score_dict, key=score_dict.get)[:top_n_candidates]
216220

217221
for idx in top_indecies:
222+
refinement_call_index += 1
218223
data = self.all_refinements_data[idx]
219224
future_refinements.append(self.refine_optimizations([data]))
220225

226+
# Track total refinement calls made
227+
self.refinement_calls_count = refinement_call_index
228+
221229
if future_refinements:
222230
logger.info("loading|Refining generated code for improved quality and performance...")
223231

@@ -237,6 +245,7 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
237245
logger.info(
238246
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
239247
)
248+
console.rule()
240249
self.refinement_done = True
241250

242251
return self.get_next_candidate()
@@ -322,7 +331,7 @@ def __init__(
322331

323332
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
324333
should_run_experiment = self.experiment_id is not None
325-
logger.debug(f"Function Trace ID: {self.function_trace_id}")
334+
logger.info(f"Function Trace ID: {self.function_trace_id}")
326335
ph("cli-optimize-function-start", {"function_trace_id": self.function_trace_id})
327336
self.cleanup_leftover_test_return_values()
328337
file_name_from_test_module_name.cache_clear()
@@ -927,7 +936,6 @@ def determine_best_candidate(
927936
dependency_code=code_context.read_only_context_code,
928937
trace_id=self.get_trace_id(exp_type),
929938
line_profiler_results=original_code_baseline.line_profile_results["str_out"],
930-
num_candidates=N_CANDIDATES_LP_EFFECTIVE,
931939
experiment_metadata=ExperimentMetadata(
932940
id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment"
933941
)
@@ -1206,7 +1214,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
12061214
func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
12071215
if func_qualname not in function_to_all_tests:
12081216
logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.")
1209-
console.rule()
12101217
else:
12111218
test_file_invocation_positions = defaultdict(list)
12121219
for tests_in_file in function_to_all_tests.get(func_qualname):
@@ -1346,7 +1353,8 @@ def generate_tests(
13461353
if concolic_test_str:
13471354
count_tests += 1
13481355

1349-
logger.info(f"!lsp|Generated '{count_tests}' tests for '{self.function_to_optimize.function_name}'")
1356+
logger.info(f"!lsp|Generated {count_tests} tests for '{self.function_to_optimize.function_name}'")
1357+
console.rule()
13501358

13511359
generated_tests = GeneratedTestsList(generated_tests=tests)
13521360
return Success((count_tests, generated_tests, function_to_concolic_tests, concolic_test_str))
@@ -1357,15 +1365,12 @@ def generate_optimizations(
13571365
read_only_context_code: str,
13581366
run_experiment: bool = False, # noqa: FBT001, FBT002
13591367
) -> Result[tuple[OptimizationSet, str], str]:
1360-
"""Generate optimization candidates for the function."""
1361-
n_candidates = N_CANDIDATES_EFFECTIVE
1362-
1368+
"""Generate optimization candidates for the function. Backend handles multi-model diversity."""
13631369
future_optimization_candidates = self.executor.submit(
13641370
self.aiservice_client.optimize_python_code,
13651371
read_writable_code.markdown,
13661372
read_only_context_code,
13671373
self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id,
1368-
n_candidates,
13691374
ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
13701375
is_async=self.function_to_optimize.is_async,
13711376
)
@@ -1388,7 +1393,6 @@ def generate_optimizations(
13881393
read_writable_code.markdown,
13891394
read_only_context_code,
13901395
self.function_trace_id[:-4] + "EXP1",
1391-
n_candidates,
13921396
ExperimentMetadata(id=self.experiment_id, group="experiment"),
13931397
is_async=self.function_to_optimize.is_async,
13941398
)
@@ -1397,14 +1401,16 @@ def generate_optimizations(
13971401
# Wait for optimization futures to complete
13981402
concurrent.futures.wait(futures)
13991403

1400-
# Retrieve results
1401-
candidates: list[OptimizedCandidate] = future_optimization_candidates.result()
1402-
logger.info(f"!lsp|Generated '{len(candidates)}' candidate optimizations.")
1404+
# Retrieve results - optimize_python_code returns list of candidates
1405+
candidates = future_optimization_candidates.result()
14031406

14041407
if not candidates:
14051408
return Failure(f"/!\\ NO OPTIMIZATIONS GENERATED for {self.function_to_optimize.function_name}")
14061409

1407-
candidates_experiment = future_candidates_exp.result() if future_candidates_exp else None
1410+
# Handle experiment results
1411+
candidates_experiment = None
1412+
if future_candidates_exp:
1413+
candidates_experiment = future_candidates_exp.result()
14081414
function_references = future_references.result()
14091415

14101416
return Success((OptimizationSet(control=candidates, experiment=candidates_experiment), function_references))
@@ -2024,6 +2030,7 @@ def run_optimized_candidate(
20242030
return self.get_results_not_matched_error()
20252031

20262032
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
2033+
console.rule()
20272034

20282035
# For async functions, instrument at definition site for performance benchmarking
20292036
if self.function_to_optimize.is_async:

0 commit comments

Comments
 (0)