Skip to content

Commit b205516

Browse files
committed
ranker wip
1 parent 8753e54 commit b205516

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

codeflash/api/aiservice.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,59 @@ def get_new_explanation( # noqa: D417
353353
console.rule()
354354
return ""
355355

356+
def generate_ranking( # noqa: D417
357+
self, trace_id: str, diffs: list[str], optimization_ids: list[str], speedups: list[int]
358+
) -> list[int] | None:
359+
"""Optimize the given python code for performance by making a request to the Django endpoint.
360+
361+
Parameters
362+
----------
363+
- source_code (str): The python code to optimize.
364+
- optimized_code (str): The python code generated by the AI service.
365+
- dependency_code (str): The dependency code used as read-only context for the optimization
366+
- original_line_profiler_results: str - line profiler results for the baseline code
367+
- optimized_line_profiler_results: str - line profiler results for the optimized code
368+
- original_code_runtime: str - runtime for the baseline code
369+
- optimized_code_runtime: str - runtime for the optimized code
370+
- speedup: str - speedup of the optimized code
371+
- annotated_tests: str - test functions annotated with runtime
372+
- optimization_id: str - unique id of opt candidate
373+
- original_explanation: str - original_explanation generated for the opt candidate
374+
375+
Returns
376+
-------
377+
- List[OptimizationCandidate]: A list of Optimization Candidates.
378+
379+
"""
380+
payload = {
381+
"trace_id": trace_id,
382+
"diffs": diffs,
383+
"speedups": speedups,
384+
"optimization_ids": optimization_ids,
385+
"python_version": platform.python_version(),
386+
}
387+
logger.info("Generating ranking")
388+
console.rule()
389+
try:
390+
response = self.make_ai_service_request("/ranker", payload=payload, timeout=60)
391+
except requests.exceptions.RequestException as e:
392+
logger.exception(f"Error generating ranking: {e}")
393+
ph("cli-optimize-error-caught", {"error": str(e)})
394+
return None
395+
396+
if response.status_code == 200:
397+
ranking: list[int] = response.json()["ranking"]
398+
console.rule()
399+
return ranking
400+
try:
401+
error = response.json()["error"]
402+
except Exception:
403+
error = response.text
404+
logger.error(f"Error generating ranking: {response.status_code} - {error}")
405+
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
406+
console.rule()
407+
return None
408+
356409
def log_results( # noqa: D417
357410
self,
358411
function_trace_id: str,

codeflash/optimization/function_optimizer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,14 @@ def determine_best_candidate(
656656
if not valid_optimizations:
657657
return None
658658
# need to figure out the best candidate here before we return best_optimization
659+
ranking = self.executor.submit(
660+
ai_service_client.generate_ranking,
661+
diffs=[],
662+
optimization_ids=[],
663+
speedups=[],
664+
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
665+
)
666+
print(ranking)
659667
# reassign the shorter code here
660668
valid_candidates_with_shorter_code = []
661669
diff_lens_list = [] # character level diff

0 commit comments

Comments
 (0)