Skip to content

Commit 999ee9f

Browse files
Merge branch 'main' into pytest-looseing
2 parents 3649e9b + c29d5fc commit 999ee9f

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
242242
# request for new optimizations but don't block execution, check for completion later
243243
# adding to control and experiment set but with same traceid
244244
best_optimization = None
245-
246-
for _u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
245+
for _u, (candidates, exp_type) in enumerate(zip([optimizations_set.control, optimizations_set.experiment],["EXP0","EXP1"])):
247246
if candidates is None:
248247
continue
249248

@@ -253,8 +252,9 @@ def optimize_function(self) -> Result[BestOptimization, str]:
253252
original_code_baseline=original_code_baseline,
254253
original_helper_code=original_helper_code,
255254
file_path_to_helper_classes=file_path_to_helper_classes,
255+
exp_type=exp_type,
256256
)
257-
ph("cli-optimize-function-finished", {"function_trace_id": self.function_trace_id})
257+
ph("cli-optimize-function-finished", {"function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id})
258258

259259
generated_tests = remove_functions_from_generated_tests(
260260
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
@@ -286,7 +286,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
286286
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
287287
)
288288

289-
self.log_successful_optimization(explanation, generated_tests)
289+
self.log_successful_optimization(explanation, generated_tests, exp_type)
290290

291291
self.replace_function_and_helpers_with_optimized_code(
292292
code_context=code_context, optimized_code=best_optimization.candidate.source_code
@@ -324,7 +324,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
324324
explanation=explanation,
325325
existing_tests_source=existing_tests,
326326
generated_original_test_source=generated_tests_str,
327-
function_trace_id=self.function_trace_id,
327+
function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
328328
coverage_message=coverage_message,
329329
git_remote=self.args.git_remote,
330330
)
@@ -361,6 +361,7 @@ def determine_best_candidate(
361361
original_code_baseline: OriginalCodeBaseline,
362362
original_helper_code: dict[Path, str],
363363
file_path_to_helper_classes: dict[Path, set[str]],
364+
exp_type: str,
364365
) -> BestOptimization | None:
365366
best_optimization: BestOptimization | None = None
366367
best_runtime_until_now = original_code_baseline.runtime
@@ -377,27 +378,26 @@ def determine_best_candidate(
377378
candidates = deque(candidates)
378379
# Start a new thread for AI service request, start loop in main thread
379380
# check if aiservice request is complete, when it is complete, append result to the candidates list
380-
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
381+
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
382+
ai_service_client = self.aiservice_client if exp_type=="EXP0" else self.local_aiservice_client
381383
future_line_profile_results = executor.submit(
382-
self.aiservice_client.optimize_python_code_line_profiler,
384+
ai_service_client.optimize_python_code_line_profiler,
383385
source_code=code_context.read_writable_code,
384386
dependency_code=code_context.read_only_context_code,
385-
trace_id=self.function_trace_id,
387+
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
386388
line_profiler_results=original_code_baseline.line_profile_results["str_out"],
387389
num_candidates=10,
388-
experiment_metadata=None,
390+
experiment_metadata=ExperimentMetadata(id=self.experiment_id, group= "control" if exp_type == "EXP0" else "experiment") if self.experiment_id else None,
389391
)
390392
try:
391393
candidate_index = 0
392-
done = False
393394
original_len = len(candidates)
394395
while candidates:
395-
# for candidate_index, candidate in enumerate(candidates, start=1):
396396
done = True if future_line_profile_results is None else future_line_profile_results.done()
397397
if done and (future_line_profile_results is not None):
398398
line_profile_results = future_line_profile_results.result()
399399
candidates.extend(line_profile_results)
400-
original_len += len(candidates)
400+
original_len += len(line_profile_results)
401401
logger.info(
402402
f"Added results from line profiler to candidates, total candidates now: {original_len}"
403403
)
@@ -519,16 +519,16 @@ def determine_best_candidate(
519519
logger.exception(f"Optimization interrupted: {e}")
520520
raise
521521

522-
self.aiservice_client.log_results(
523-
function_trace_id=self.function_trace_id,
522+
ai_service_client.log_results(
523+
function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
524524
speedup_ratio=speedup_ratios,
525525
original_runtime=original_code_baseline.runtime,
526526
optimized_runtime=optimized_runtimes,
527527
is_correct=is_correct,
528528
)
529529
return best_optimization
530530

531-
def log_successful_optimization(self, explanation: Explanation, generated_tests: GeneratedTestsList) -> None:
531+
def log_successful_optimization(self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str) -> None:
532532
explanation_panel = Panel(
533533
f"⚡️ Optimization successful! 📄 {self.function_to_optimize.qualified_name} in {explanation.file_path}\n"
534534
f"📈 {explanation.perf_improvement_line}\n"
@@ -555,7 +555,7 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests:
555555
ph(
556556
"cli-optimize-success",
557557
{
558-
"function_trace_id": self.function_trace_id,
558+
"function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
559559
"speedup_x": explanation.speedup_x,
560560
"speedup_pct": explanation.speedup_pct,
561561
"best_runtime": explanation.best_runtime_ns,

0 commit comments

Comments
 (0)