Skip to content

Commit ff8215a

Browse files
committed
almost ready for review
1 parent bd320ee commit ff8215a

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

codeflash/api/aiservice.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ def log_results( # noqa: D417
360360
is_correct: dict[str, bool] | None,
361361
optimized_line_profiler_results: dict[str, str] | None,
362362
metadata: dict[str, Any] | None,
363+
optimizations_post: dict[str, str] | None = None,
363364
) -> None:
364365
"""Log features to the database.
365366
@@ -372,6 +373,7 @@ def log_results( # noqa: D417
372373
- is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
373374
- optimized_line_profiler_results: line_profiler results for every candidate mapped to their optimization_id
374375
- metadata: contains the best optimization id
376+
- optimizations_post - dict mapping opt id to code str after postprocessing
375377
376378
"""
377379
payload = {
@@ -383,6 +385,7 @@ def log_results( # noqa: D417
383385
"codeflash_version": codeflash_version,
384386
"optimized_line_profiler_results": optimized_line_profiler_results,
385387
"metadata": metadata,
388+
"optimizations_post": optimizations_post,
386389
}
387390
try:
388391
self.make_ai_service_request("/log_features", payload=payload, timeout=5)

codeflash/optimization/function_optimizer.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,12 @@
9494

9595
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
9696
from codeflash.either import Result
97-
from codeflash.models.models import BenchmarkKey, CoverageData, FunctionCalledInTest, FunctionSource
9897
from codeflash.models.models import (
9998
BenchmarkKey,
10099
CodeStringsMarkdown,
101100
CoverageData,
102101
FunctionCalledInTest,
103102
FunctionSource,
104-
OptimizedCandidate,
105103
)
106104
from codeflash.verification.verification_utils import TestConfig
107105

@@ -385,6 +383,7 @@ def determine_best_candidate(
385383
future_all_refinements: list[concurrent.futures.Future] = []
386384
ast_code_to_id = {}
387385
valid_optimizations = []
386+
optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
388387
# Start a new thread for AI service request, start loop in main thread
389388
# check if aiservice request is complete, when it is complete, append result to the candidates list
390389
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
@@ -438,17 +437,37 @@ def determine_best_candidate(
438437
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
439438
)
440439
continue
441-
normalized_code = ast.unparse(ast.parse(candidate.source_code.strip()))
440+
# check if this code has been evaluated before by checking the ast normalized code string
441+
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
442442
if normalized_code in ast_code_to_id:
443-
new_diff_len = diff_length(candidate.source_code, code_context.read_writable_code)
443+
# update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes
444+
speedup_ratios[candidate.optimization_id] = speedup_ratios[
445+
ast_code_to_id[normalized_code]["optimization_id"]
446+
]
447+
is_correct[candidate.optimization_id] = is_correct[
448+
ast_code_to_id[normalized_code]["optimization_id"]
449+
]
450+
optimized_runtimes[candidate.optimization_id] = optimized_runtimes[
451+
ast_code_to_id[normalized_code]["optimization_id"]
452+
]
453+
optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[
454+
ast_code_to_id[normalized_code]["optimization_id"]
455+
]
456+
optimizations_post[candidate.optimization_id] = ast_code_to_id[normalized_code][
457+
"shorter_source_code"
458+
].markdown
459+
optimizations_post[ast_code_to_id[normalized_code]["optimization_id"]] = ast_code_to_id[
460+
normalized_code
461+
]["shorter_source_code"].markdown
462+
new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
444463
if new_diff_len < ast_code_to_id[normalized_code]["diff_len"]:
445464
ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
446465
ast_code_to_id[normalized_code]["diff_len"] = new_diff_len
447466
continue
448467
ast_code_to_id[normalized_code] = {
449468
"optimization_id": candidate.optimization_id,
450469
"shorter_source_code": candidate.source_code,
451-
"diff_len": diff_length(candidate.source_code, code_context.read_writable_code),
470+
"diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
452471
}
453472
run_results = self.run_optimized_candidate(
454473
optimization_candidate_index=candidate_index,
@@ -592,7 +611,7 @@ def determine_best_candidate(
592611
diff_lens_list = [] # character level diff
593612
runtimes_list = []
594613
for valid_opt in valid_optimizations:
595-
valid_opt_normalized_code = ast.unparse(ast.parse(valid_opt.candidate.source_code.strip()))
614+
valid_opt_normalized_code = ast.unparse(ast.parse(valid_opt.candidate.source_code.flat.strip()))
596615
new_candidate_with_shorter_code = OptimizedCandidate(
597616
source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"],
598617
optimization_id=valid_opt.candidate.optimization_id,
@@ -628,6 +647,7 @@ def determine_best_candidate(
628647
optimized_runtime=optimized_runtimes,
629648
is_correct=is_correct,
630649
optimized_line_profiler_results=optimized_line_profiler_results,
650+
optimizations_post=optimizations_post,
631651
metadata={"best_optimization_id": best_optimization.candidate.optimization_id},
632652
)
633653
return best_optimization

0 commit comments

Comments
 (0)