Skip to content

Commit 726405b

Browse files
optimization source
1 parent b4474f3 commit 726405b

File tree

3 files changed

+28
-21
lines changed

3 files changed

+28
-21
lines changed

codeflash/api/aiservice.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
from codeflash.code_utils.time_utils import humanize_runtime
1919
from codeflash.lsp.helpers import is_LSP_enabled
2020
from codeflash.models.ExperimentMetadata import ExperimentMetadata
21-
from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate
21+
from codeflash.models.models import (
22+
AIServiceRefinerRequest,
23+
CodeStringsMarkdown,
24+
OptimizedCandidate,
25+
OptimizedCandidateSource,
26+
)
2227
from codeflash.telemetry.posthog_cf import ph
2328
from codeflash.version import __version__ as codeflash_version
2429

@@ -86,15 +91,20 @@ def make_ai_service_request(
8691
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
8792
return response
8893

89-
def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
94+
def _get_valid_candidates(
95+
self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource
96+
) -> list[OptimizedCandidate]:
9097
candidates: list[OptimizedCandidate] = []
9198
for opt in optimizations_json:
9299
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
93100
if not code.code_strings:
94101
continue
95102
candidates.append(
96103
OptimizedCandidate(
97-
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
104+
source_code=code,
105+
explanation=opt["explanation"],
106+
optimization_id=opt["optimization_id"],
107+
source=source,
98108
)
99109
)
100110
return candidates
@@ -157,7 +167,7 @@ def optimize_python_code( # noqa: D417
157167
console.rule()
158168
end_time = time.perf_counter()
159169
logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")
160-
return self._get_valid_candidates(optimizations_json)
170+
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE)
161171
try:
162172
error = response.json()["error"]
163173
except Exception:
@@ -222,7 +232,7 @@ def optimize_python_code_line_profiler( # noqa: D417
222232
f"!lsp|Generated {len(optimizations_json)} candidate optimizations using line profiler information."
223233
)
224234
console.rule()
225-
return self._get_valid_candidates(optimizations_json)
235+
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE_LP)
226236
try:
227237
error = response.json()["error"]
228238
except Exception:
@@ -275,15 +285,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
275285
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
276286
console.rule()
277287

278-
refinements = self._get_valid_candidates(refined_optimizations)
279-
return [
280-
OptimizedCandidate(
281-
source_code=c.source_code,
282-
explanation=c.explanation,
283-
optimization_id=c.optimization_id[:-4] + "refi",
284-
)
285-
for c in refinements
286-
]
288+
return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE)
287289

288290
try:
289291
error = response.json()["error"]
@@ -324,7 +326,7 @@ def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> Op
324326
fixed_optimization = response.json()
325327
console.rule()
326328

327-
valid_candidates = self._get_valid_candidates([fixed_optimization])
329+
valid_candidates = self._get_valid_candidates([fixed_optimization], OptimizedCandidateSource.REPAIR)
328330
if not valid_candidates:
329331
logger.error("Code repair failed to generate a valid candidate.")
330332
return None

codeflash/models/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,11 +383,19 @@ class TestsInFile:
383383
test_type: TestType
384384

385385

386+
class OptimizedCandidateSource(enum.Enum, str):
387+
OPTIMIZE = "OPTIMIZE"
388+
OPTIMIZE_LP = "OPTIMIZE_LP"
389+
REFINE = "REFINE"
390+
REPAIR = "REPAIR"
391+
392+
386393
@dataclass(frozen=True)
387394
class OptimizedCandidate:
388395
source_code: CodeStringsMarkdown
389396
explanation: str
390397
optimization_id: str
398+
source: OptimizedCandidateSource
391399

392400

393401
@dataclass(frozen=True)

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
OptimizationSet,
7878
OptimizedCandidate,
7979
OptimizedCandidateResult,
80+
OptimizedCandidateSource,
8081
OriginalCodeBaseline,
8182
TestFile,
8283
TestFiles,
@@ -739,7 +740,7 @@ def determine_best_candidate(
739740
)
740741
valid_optimizations.append(best_optimization)
741742
# queue corresponding refined optimization for best optimization
742-
if not candidate.optimization_id.endswith("refi"):
743+
if candidate.source != OptimizedCandidateSource.REFINE:
743744
self.future_all_refinements.append(
744745
self.refine_optimizations(
745746
valid_optimizations=[best_optimization],
@@ -1908,14 +1909,10 @@ def run_optimized_candidate(
19081909
console.rule()
19091910
else:
19101911
result_unmatched_perc = len(diffs) / len(candidate_behavior_results)
1911-
if result_unmatched_perc > 0.5:
1912+
if candidate.source == OptimizedCandidateSource.REPAIR or result_unmatched_perc > 0.5:
19121913
# if the test unmatched percentage is greater than 50%, we can't fix it
19131914
return self.get_results_not_matched_error()
19141915

1915-
if candidate.optimization_id.endswith("cdrp"):
1916-
# prevent looping for now
1917-
return self.get_results_not_matched_error()
1918-
19191916
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
19201917
logger.info("Adding this to the repair queue")
19211918
self.future_all_code_repair.append(

0 commit comments

Comments
 (0)