Skip to content
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,5 @@ fabric.properties

# Mac
.DS_Store

scratch/
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import sys


def lol():
print( "lol" )









class BubbleSorter:
def __init__(self, x=0):
self.x = x

def lol(self):
print( "lol" )








def sorter(self, arr):
print("codeflash stdout : BubbleSorter.sorter() called")
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
print("stderr test", file=sys.stderr)
return arr
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
def lol():
print( "lol" )







def sorter(arr):
print("codeflash stdout: Sorting list")
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
print(f"result: {arr}")
return arr
7 changes: 4 additions & 3 deletions codeflash/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
solved problem, please reach out to us at [email protected]. We're hiring!
"""

import os
from pathlib import Path

from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
Expand All @@ -20,12 +21,12 @@ def main() -> None:
CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"}
)
args = parse_args()

if args.command:
if args.config_file and Path.exists(args.config_file):
disable_telemetry = os.environ.get("CODEFLASH_DISABLE_TELEMETRY", "").lower() in {"true", "t", "1", "yes", "y"}
if (not disable_telemetry) and args.config_file and Path.exists(args.config_file):
pyproject_config, _ = parse_config_file(args.config_file)
disable_telemetry = pyproject_config.get("disable_telemetry", False)
else:
disable_telemetry = False
init_sentry(not disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not disable_telemetry)
args.func()
Expand Down
78 changes: 54 additions & 24 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import ast
import concurrent.futures
import dataclasses
import os
import shutil
import subprocess
import tempfile
import time
import uuid
from collections import defaultdict, deque
Expand Down Expand Up @@ -124,6 +126,7 @@ def __init__(
self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None
self.optimizer_temp_dir = Path(tempfile.mkdtemp(prefix="codeflash_opt_fmt_"))

def optimize_function(self) -> Result[BestOptimization, str]:
should_run_experiment = self.experiment_id is not None
Expand Down Expand Up @@ -301,9 +304,30 @@ def optimize_function(self) -> Result[BestOptimization, str]:
code_context=code_context, optimized_code=best_optimization.candidate.source_code
)

new_code, new_helper_code = self.reformat_code_and_helpers(
code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code
)
if not self.args.disable_imports_sorting:
main_file_path = self.function_to_optimize.file_path
if main_file_path.exists():
current_main_content = main_file_path.read_text(encoding="utf8")
sorted_main_content = sort_imports(current_main_content)
if sorted_main_content != current_main_content:
main_file_path.write_text(sorted_main_content, encoding="utf8")

writable_helper_file_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_file_path in writable_helper_file_paths:
if helper_file_path.exists():
current_helper_content = helper_file_path.read_text(encoding="utf8")
sorted_helper_content = sort_imports(current_helper_content)
if sorted_helper_content != current_helper_content:
helper_file_path.write_text(sorted_helper_content, encoding="utf8")

new_code = self.function_to_optimize.file_path.read_text(encoding="utf8")
new_helper_code: dict[Path, str] = {}
for helper_file_path_key in original_helper_code:
if helper_file_path_key.exists():
new_helper_code[helper_file_path_key] = helper_file_path_key.read_text(encoding="utf8")
else:
logger.warning(f"Helper file {helper_file_path_key} not found after optimization. It will not be included in new_helper_code for PR.")


existing_tests = existing_tests_source_for(
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
Expand Down Expand Up @@ -405,6 +429,33 @@ def determine_best_candidate(
future_line_profile_results = None
candidate_index += 1
candidate = candidates.popleft()

formatted_candidate_code = candidate.source_code
if self.args.formatter_cmds:
temp_code_file_path: Path | None = None
try:
with tempfile.NamedTemporaryFile(
mode="w",
suffix=".py",
delete=False,
encoding="utf8",
dir=self.optimizer_temp_dir
) as tmp_file:
tmp_file.write(candidate.source_code)
temp_code_file_path = Path(tmp_file.name)

formatted_candidate_code = format_code(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

formatting code is depenedent on the cwd of the code, imports are grouped according to what module they belong to, for the project's own module they are grouped together. This determination of what is the module they belong to is determined by the cwd.
So we should not format code in a temp directory, the results may not be the same. This is btw why your unit tests were failing today

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's true, as this is working.

This is only formatting the function snippet and the helper snippets -- I don't think this has to do with why those unit tests are failing.

formatter_cmds=self.args.formatter_cmds,
path=temp_code_file_path
)
except Exception as e:
logger.error(f"Error during formatting candidate code via temp file: {e}. Using original candidate code.")
finally:
if temp_code_file_path and temp_code_file_path.exists():
temp_code_file_path.unlink(missing_ok=True)

candidate = dataclasses.replace(candidate, source_code=formatted_candidate_code)

get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
Expand Down Expand Up @@ -580,27 +631,6 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
with Path(module_abspath).open("w", encoding="utf8") as f:
f.write(original_helper_code[module_abspath])

def reformat_code_and_helpers(
self, helper_functions: list[FunctionSource], path: Path, original_code: str
) -> tuple[str, dict[Path, str]]:
should_sort_imports = not self.args.disable_imports_sorting
if should_sort_imports and isort.code(original_code) != original_code:
should_sort_imports = False

new_code = format_code(self.args.formatter_cmds, path)
if should_sort_imports:
new_code = sort_imports(new_code)

new_helper_code: dict[Path, str] = {}
helper_functions_paths = {hf.file_path for hf in helper_functions}
for module_abspath in helper_functions_paths:
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
if should_sort_imports:
formatted_helper_code = sort_imports(formatted_helper_code)
new_helper_code[module_abspath] = formatted_helper_code

return new_code, new_helper_code

def replace_function_and_helpers_with_optimized_code(
self, code_context: CodeOptimizationContext, optimized_code: str
) -> bool:
Expand Down
Loading