Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,59 @@ def get_new_explanation( # noqa: D417
console.rule()
return ""

def generate_ranking( # noqa: D417
self, trace_id: str, diffs: list[str], optimization_ids: list[str], speedups: list[float]
) -> list[int] | None:
"""Optimize the given python code for performance by making a request to the Django endpoint.

Parameters
----------
- source_code (str): The python code to optimize.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO : change the docstring

- optimized_code (str): The python code generated by the AI service.
- dependency_code (str): The dependency code used as read-only context for the optimization
- original_line_profiler_results: str - line profiler results for the baseline code
- optimized_line_profiler_results: str - line profiler results for the optimized code
- original_code_runtime: str - runtime for the baseline code
- optimized_code_runtime: str - runtime for the optimized code
- speedup: str - speedup of the optimized code
- annotated_tests: str - test functions annotated with runtime
- optimization_id: str - unique id of opt candidate
- original_explanation: str - original_explanation generated for the opt candidate

Returns
-------
- List[OptimizationCandidate]: A list of Optimization Candidates.

"""
payload = {
"trace_id": trace_id,
"diffs": diffs,
"speedups": speedups,
"optimization_ids": optimization_ids,
"python_version": platform.python_version(),
}
logger.info("Generating ranking")
console.rule()
try:
response = self.make_ai_service_request("/rank", payload=payload, timeout=60)
except requests.exceptions.RequestException as e:
logger.exception(f"Error generating ranking: {e}")
ph("cli-optimize-error-caught", {"error": str(e)})
return None

if response.status_code == 200:
ranking: list[int] = response.json()["ranking"]
console.rule()
return ranking
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating ranking: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
console.rule()
return None

def log_results( # noqa: D417
self,
function_trace_id: str,
Expand Down
17 changes: 17 additions & 0 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,23 @@
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)


def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tofile: str = "modified") -> str:
"""Return the unified diff between two code strings as a single string.

:param code1: First code string (original).
:param code2: Second code string (modified).
:param fromfile: Label for the first code string.
:param tofile: Label for the second code string.
:return: Unified diff as a string.
"""
code1_lines = code1.splitlines(keepends=True)
code2_lines = code2.splitlines(keepends=True)

diff = difflib.unified_diff(code1_lines, code2_lines, fromfile=fromfile, tofile=tofile, lineterm="")

return "".join(diff)


def diff_length(a: str, b: str) -> int:
"""Compute the length (in characters) of the unified diff between two strings.

Expand Down
10 changes: 4 additions & 6 deletions codeflash/code_utils/version_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import requests
from packaging import version

from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.console import logger
from codeflash.version import __version__

# Simple cache to avoid checking too frequently
_version_cache = {"version": '0.0.0', "timestamp": float(0)}
_version_cache = {"version": "0.0.0", "timestamp": float(0)}
_cache_duration = 3600 # 1 hour cache


Expand Down Expand Up @@ -69,10 +69,8 @@ def check_for_newer_minor_version() -> None:

# Check if there's a newer minor version available
# We only notify for minor version updates, not patch updates
if latest_parsed > current_parsed: # < > == operators can be directly applied on version objects
logger.warning(
f"A newer version({latest_version}) of Codeflash is available, please update soon!"
)
if latest_parsed > current_parsed: # < > == operators can be directly applied on version objects
logger.warning(f"A newer version({latest_version}) of Codeflash is available, please update soon!")

except version.InvalidVersion as e:
logger.debug(f"Invalid version format: {e}")
Expand Down
44 changes: 36 additions & 8 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
has_any_async_functions,
module_name_from_file_path,
restore_conftest,
unified_diff_strings,
)
from codeflash.code_utils.config_consts import (
INDIVIDUAL_TESTCASE_TIMEOUT,
Expand Down Expand Up @@ -171,9 +172,10 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
self.candidate_queue.put(candidate)

self.candidate_len += len(refinement_response)
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
)
if len(refinement_response) > 0:
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
)
self.refinement_done = True

return self.get_next_candidate()
Expand Down Expand Up @@ -660,6 +662,9 @@ def determine_best_candidate(
# reassign the shorter code here
valid_candidates_with_shorter_code = []
diff_lens_list = [] # character level diff
speedups_list = []
optimization_ids = []
diff_strs = []
runtimes_list = []
for valid_opt in valid_optimizations:
valid_opt_normalized_code = ast.unparse(ast.parse(valid_opt.candidate.source_code.flat.strip()))
Expand All @@ -683,12 +688,35 @@ def determine_best_candidate(
diff_lens_list.append(
diff_length(new_best_opt.candidate.source_code.flat, code_context.read_writable_code.flat)
) # char level diff
diff_strs.append(
unified_diff_strings(code_context.read_writable_code.flat, new_best_opt.candidate.source_code.flat)
)
speedups_list.append(
1
+ performance_gain(
original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=new_best_opt.runtime
)
)
optimization_ids.append(new_best_opt.candidate.optimization_id)
runtimes_list.append(new_best_opt.runtime)
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
# TODO: better way to resolve conflicts with same min ranking
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()} # noqa: SIM118
min_key = min(overall_ranking, key=overall_ranking.get)
future_ranking = self.executor.submit(
ai_service_client.generate_ranking,
diffs=diff_strs,
optimization_ids=optimization_ids,
speedups=speedups_list,
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
)
concurrent.futures.wait([future_ranking])
ranking = future_ranking.result()
if ranking:
ranking = [x - 1 for x in ranking]
min_key = ranking[0]
else:
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
# TODO: better way to resolve conflicts with same min ranking
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()} # noqa: SIM118
min_key = min(overall_ranking, key=overall_ranking.get)
best_optimization = valid_candidates_with_shorter_code[min_key]
# reassign code string which is the shortest
ai_service_client.log_results(
Expand Down
55 changes: 27 additions & 28 deletions tests/test_version_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,87 +120,86 @@ def test_get_latest_version_from_pypi_cache_expiry(self, mock_get):
self.assertEqual(mock_get.call_count, 2)

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_newer_available(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_newer_available(self, mock_logger,mock_get_version):
"""Test warning message when newer minor version is available."""
mock_get_version.return_value = "1.1.0"

check_for_newer_minor_version()

mock_console.print.assert_called_once()
call_args = mock_console.print.call_args[0][0]
self.assertIn("ℹ️ A newer version of Codeflash is available!", call_args)
self.assertIn("Current version: 1.0.0", call_args)
self.assertIn("Latest version: 1.1.0", call_args)
mock_logger.warning.assert_called_once()
call_args = mock_logger.warning.call_args[0][0]
self.assertIn("of Codeflash is available, please update soon!", call_args)
self.assertIn("1.1.0", call_args)

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_newer_major_available(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_newer_major_available(self, mock_logger,mock_get_version):
"""Test warning message when newer major version is available."""
mock_get_version.return_value = "2.0.0"

check_for_newer_minor_version()

mock_console.print.assert_called_once()
call_args = mock_console.print.call_args[0][0]
self.assertIn("ℹ️ A newer version of Codeflash is available!", call_args)
mock_logger.warning.assert_called_once()
call_args = mock_logger.warning.call_args[0][0]
self.assertIn("of Codeflash is available, please update soon!", call_args)

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.1.0')
def test_check_for_newer_minor_version_no_newer_available(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_no_newer_available(self, mock_logger,mock_get_version):
"""Test no warning when no newer version is available."""
mock_get_version.return_value = "1.0.0"

check_for_newer_minor_version()

mock_console.print.assert_not_called()
mock_logger.warning.assert_not_called()

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_patch_update_ignored(self, mock_console, mock_get_version):
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.1')
def test_check_for_newer_minor_version_patch_update_ignored(self, mock_logger,mock_get_version):
"""Test that patch updates don't trigger warnings."""
mock_get_version.return_value = "1.0.1"

check_for_newer_minor_version()

mock_console.print.assert_not_called()
mock_logger.warning.assert_not_called()

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_same_version(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_same_version(self, mock_logger,mock_get_version):
"""Test no warning when versions are the same."""
mock_get_version.return_value = "1.0.0"

check_for_newer_minor_version()

mock_console.print.assert_not_called()
mock_logger.warning.assert_not_called()

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_no_latest_version(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_no_latest_version(self, mock_logger,mock_get_version):
"""Test no warning when latest version cannot be fetched."""
mock_get_version.return_value = None

check_for_newer_minor_version()

mock_console.print.assert_not_called()
mock_logger.warning.assert_not_called()

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_invalid_version_format(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_invalid_version_format(self, mock_logger,mock_get_version):
"""Test handling of invalid version format."""
mock_get_version.return_value = "invalid-version"

check_for_newer_minor_version()

mock_console.print.assert_not_called()
mock_logger.warning.assert_not_called()



Expand Down
Loading