Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,51 @@ def get_new_explanation( # noqa: D417
console.rule()
return ""

def generate_ranking(  # noqa: D417
    self, trace_id: str, diffs: list[str], optimization_ids: list[str], speedups: list[float]
) -> list[int] | None:
    """Rank optimization candidates by making a request to the ``/rank`` endpoint of the AI service.

    Parameters
    ----------
    - trace_id : unique uuid of function
    - diffs : list of unified diff strings of opt candidates
    - optimization_ids : list of ids of opt candidates (presumably parallel to ``diffs`` and ``speedups``)
    - speedups : list of speedups of opt candidates

    Returns
    -------
    - List[int] | None: Ranking of opt candidates in decreasing order, taken from the ``"ranking"``
      field of the response. ``None`` when the request raises, the status code is not 200, or a
      200 response body cannot be parsed into a ranking.

    """
    payload = {
        "trace_id": trace_id,
        "diffs": diffs,
        "speedups": speedups,
        "optimization_ids": optimization_ids,
        "python_version": platform.python_version(),
    }
    logger.info("Generating ranking")
    console.rule()
    try:
        response = self.make_ai_service_request("/rank", payload=payload, timeout=60)
    except requests.exceptions.RequestException as e:
        logger.exception(f"Error generating ranking: {e}")
        ph("cli-optimize-error-caught", {"error": str(e)})
        return None

    if response.status_code == 200:
        try:
            ranking: list[int] = response.json()["ranking"]
        except (ValueError, KeyError, TypeError) as e:
            # A 200 whose body is not JSON, or is JSON without a "ranking" field, is handled
            # like every other failure here: report it and return None instead of raising.
            logger.error(f"Error generating ranking: malformed response - {e}")
            console.rule()
            return None
        console.rule()
        return ranking

    # Non-200: prefer the structured "error" field, fall back to the raw body text.
    try:
        error = response.json()["error"]
    except Exception:
        error = response.text
    logger.error(f"Error generating ranking: {response.status_code} - {error}")
    ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
    console.rule()
    return None

def log_results( # noqa: D417
self,
function_trace_id: str,
Expand Down
17 changes: 17 additions & 0 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,23 @@
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)


def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tofile: str = "modified") -> str:
    """Return the unified diff between two code strings as a single string.

    Every line of the returned diff (file headers, hunk headers and content lines) is
    terminated by a newline, so the result can be split back into lines or applied as a patch.
    An empty string is returned when the two inputs produce no diff.

    :param code1: First code string (original).
    :param code2: Second code string (modified).
    :param fromfile: Label for the first code string.
    :param tofile: Label for the second code string.
    :return: Unified diff as a string.
    """
    code1_lines = code1.splitlines(keepends=True)
    code2_lines = code2.splitlines(keepends=True)

    # The inputs keep their line endings, so the control lines ("---", "+++", "@@") must be
    # newline-terminated as well; that is difflib's default ``lineterm``. Passing ``lineterm=""``
    # here would glue the headers onto the first content line when joined.
    diff = difflib.unified_diff(code1_lines, code2_lines, fromfile=fromfile, tofile=tofile)

    # A final source line without a trailing newline yields a diff line without one; terminate it
    # so it does not run into the next diff line.
    return "".join(line if line.endswith("\n") else line + "\n" for line in diff)


def diff_length(a: str, b: str) -> int:
"""Compute the length (in characters) of the unified diff between two strings.

Expand Down
10 changes: 4 additions & 6 deletions codeflash/code_utils/version_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import requests
from packaging import version

from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.console import logger
from codeflash.version import __version__

# Simple cache to avoid checking too frequently
_version_cache = {"version": '0.0.0', "timestamp": float(0)}
_version_cache = {"version": "0.0.0", "timestamp": float(0)}
_cache_duration = 3600 # 1 hour cache


Expand Down Expand Up @@ -69,10 +69,8 @@ def check_for_newer_minor_version() -> None:

# Check if there's a newer minor version available
# We only notify for minor version updates, not patch updates
if latest_parsed > current_parsed: # < > == operators can be directly applied on version objects
logger.warning(
f"A newer version({latest_version}) of Codeflash is available, please update soon!"
)
if latest_parsed > current_parsed: # < > == operators can be directly applied on version objects
logger.warning(f"A newer version({latest_version}) of Codeflash is available, please update soon!")

except version.InvalidVersion as e:
logger.debug(f"Invalid version format: {e}")
Expand Down
52 changes: 43 additions & 9 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
has_any_async_functions,
module_name_from_file_path,
restore_conftest,
unified_diff_strings,
)
from codeflash.code_utils.config_consts import (
INDIVIDUAL_TESTCASE_TIMEOUT,
Expand Down Expand Up @@ -171,9 +172,10 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
self.candidate_queue.put(candidate)

self.candidate_len += len(refinement_response)
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
)
if len(refinement_response) > 0:
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
)
self.refinement_done = True

return self.get_next_candidate()
Expand Down Expand Up @@ -537,7 +539,9 @@ def determine_best_candidate(
].markdown
optimizations_post[past_opt_id] = ast_code_to_id[normalized_code]["shorter_source_code"].markdown
new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
if new_diff_len < ast_code_to_id[normalized_code]["diff_len"]:
if (
new_diff_len < ast_code_to_id[normalized_code]["diff_len"]
): # new candidate has a shorter diff than the previously encountered one
ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
ast_code_to_id[normalized_code]["diff_len"] = new_diff_len
continue
Expand Down Expand Up @@ -660,6 +664,9 @@ def determine_best_candidate(
# reassign the shorter code here
valid_candidates_with_shorter_code = []
diff_lens_list = [] # character level diff
speedups_list = []
optimization_ids = []
diff_strs = []
runtimes_list = []
for valid_opt in valid_optimizations:
valid_opt_normalized_code = ast.unparse(ast.parse(valid_opt.candidate.source_code.flat.strip()))
Expand All @@ -683,12 +690,39 @@ def determine_best_candidate(
diff_lens_list.append(
diff_length(new_best_opt.candidate.source_code.flat, code_context.read_writable_code.flat)
) # char level diff
diff_strs.append(
unified_diff_strings(code_context.read_writable_code.flat, new_best_opt.candidate.source_code.flat)
)
speedups_list.append(
1
+ performance_gain(
original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=new_best_opt.runtime
)
)
optimization_ids.append(new_best_opt.candidate.optimization_id)
runtimes_list.append(new_best_opt.runtime)
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
# TODO: better way to resolve conflicts with same min ranking
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()} # noqa: SIM118
min_key = min(overall_ranking, key=overall_ranking.get)
if len(optimization_ids) > 1:
future_ranking = self.executor.submit(
ai_service_client.generate_ranking,
diffs=diff_strs,
optimization_ids=optimization_ids,
speedups=speedups_list,
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
)
concurrent.futures.wait([future_ranking])
ranking = future_ranking.result()
if ranking:
min_key = ranking[0]
else:
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
# TODO: better way to resolve conflicts with same min ranking
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking}
min_key = min(overall_ranking, key=overall_ranking.get)
elif len(optimization_ids) == 1:
min_key = 0 # only one candidate in valid _opts, already returns if there are no valid candidates
else: # 0? shouldn't happen but it's there to escape potential bugs
return None
best_optimization = valid_candidates_with_shorter_code[min_key]
# reassign code string which is the shortest
ai_service_client.log_results(
Expand Down
55 changes: 27 additions & 28 deletions tests/test_version_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,87 +120,86 @@ def test_get_latest_version_from_pypi_cache_expiry(self, mock_get):
self.assertEqual(mock_get.call_count, 2)

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_newer_available(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_newer_available(self, mock_logger,mock_get_version):
"""Test warning message when newer minor version is available."""
mock_get_version.return_value = "1.1.0"

check_for_newer_minor_version()

mock_console.print.assert_called_once()
call_args = mock_console.print.call_args[0][0]
self.assertIn("ℹ️ A newer version of Codeflash is available!", call_args)
self.assertIn("Current version: 1.0.0", call_args)
self.assertIn("Latest version: 1.1.0", call_args)
mock_logger.warning.assert_called_once()
call_args = mock_logger.warning.call_args[0][0]
self.assertIn("of Codeflash is available, please update soon!", call_args)
self.assertIn("1.1.0", call_args)

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_newer_major_available(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_newer_major_available(self, mock_logger,mock_get_version):
"""Test warning message when newer major version is available."""
mock_get_version.return_value = "2.0.0"

check_for_newer_minor_version()

mock_console.print.assert_called_once()
call_args = mock_console.print.call_args[0][0]
self.assertIn("ℹ️ A newer version of Codeflash is available!", call_args)
mock_logger.warning.assert_called_once()
call_args = mock_logger.warning.call_args[0][0]
self.assertIn("of Codeflash is available, please update soon!", call_args)

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.1.0')
def test_check_for_newer_minor_version_no_newer_available(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_no_newer_available(self, mock_logger,mock_get_version):
"""Test no warning when no newer version is available."""
mock_get_version.return_value = "1.0.0"

check_for_newer_minor_version()

mock_console.print.assert_not_called()
mock_logger.warning.assert_not_called()

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_patch_update_ignored(self, mock_console, mock_get_version):
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.1')
def test_check_for_newer_minor_version_patch_update_ignored(self, mock_logger,mock_get_version):
"""Test that patch updates don't trigger warnings."""
mock_get_version.return_value = "1.0.1"

check_for_newer_minor_version()

mock_console.print.assert_not_called()
mock_logger.warning.assert_not_called()

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_same_version(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_same_version(self, mock_logger,mock_get_version):
"""Test no warning when versions are the same."""
mock_get_version.return_value = "1.0.0"

check_for_newer_minor_version()

mock_console.print.assert_not_called()
mock_logger.warning.assert_not_called()

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_no_latest_version(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_no_latest_version(self, mock_logger,mock_get_version):
"""Test no warning when latest version cannot be fetched."""
mock_get_version.return_value = None

check_for_newer_minor_version()

mock_console.print.assert_not_called()
mock_logger.warning.assert_not_called()

@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
@patch('codeflash.code_utils.version_check.console')
@patch('codeflash.code_utils.version_check.logger')
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
def test_check_for_newer_minor_version_invalid_version_format(self, mock_console, mock_get_version):
def test_check_for_newer_minor_version_invalid_version_format(self, mock_logger,mock_get_version):
"""Test handling of invalid version format."""
mock_get_version.return_value = "invalid-version"

check_for_newer_minor_version()

mock_console.print.assert_not_called()
mock_logger.warning.assert_not_called()



Expand Down
Loading