Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,9 @@ def generate_regression_tests(
- Dict[str, str] | None: The generated regression tests and instrumented tests, or None if an error occurred.

"""
assert test_framework in [
"pytest",
"unittest",
], f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
assert test_framework in ["pytest", "unittest"], (
f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
)
payload = {
"source_code_being_tested": source_code_being_tested,
"function_to_optimize": function_to_optimize,
Expand Down
10 changes: 4 additions & 6 deletions codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyMa
return DependencyManager.POETRY

# Check for uv
if any(key.startswith("uv") for key in tool_section.keys()):
if any(key.startswith("uv") for key in tool_section):
return DependencyManager.UV

# Look for pip-specific markers
Expand Down Expand Up @@ -555,9 +555,7 @@ def customize_codeflash_yaml_content(

# Add codeflash command
codeflash_cmd = get_codeflash_github_action_command(dep_manager)
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)

return optimize_yml_content
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)


# Create or update the pyproject.toml file with the Codeflash dependency & configuration
Expand Down Expand Up @@ -596,8 +594,8 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
formatter_cmds.append("disabled")
if formatter in ["black", "ruff"]:
try:
result = subprocess.run([formatter], capture_output=True, check=False)
except FileNotFoundError as e:
subprocess.run([formatter], capture_output=True, check=False)
except FileNotFoundError:
click.echo(f"⚠️ Formatter not found: {formatter}, please ensure it is installed")
codeflash_section["formatter-cmds"] = formatter_cmds
# Add the 'codeflash' section, ensuring 'tool' section exists
Expand Down
4 changes: 3 additions & 1 deletion codeflash/cli_cmds/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from contextlib import contextmanager
from itertools import cycle
from typing import TYPE_CHECKING, Generator
from typing import TYPE_CHECKING

from rich.console import Console
from rich.logging import RichHandler
Expand All @@ -13,6 +13,8 @@
from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT

if TYPE_CHECKING:
from collections.abc import Generator

from rich.progress import TaskID

DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
Expand Down
3 changes: 1 addition & 2 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | Non
or (functions_to_optimize[0].parents and functions_to_optimize[0].parents[0].type != "ClassDef")
or (
len(functions_to_optimize[0].parents) > 1
or (len(functions_to_optimize) > 1)
and len({fn.parents[0] for fn in functions_to_optimize}) != 1
or ((len(functions_to_optimize) > 1) and len({fn.parents[0] for fn in functions_to_optimize}) != 1)
)
):
return None, set()
Expand Down
9 changes: 4 additions & 5 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import ast
import re
from collections import defaultdict
from functools import lru_cache
from typing import TYPE_CHECKING, Optional, TypeVar
Expand Down Expand Up @@ -91,10 +90,10 @@ def leave_ClassDef(self, node: cst.ClassDef) -> None:
class OptimFunctionReplacer(cst.CSTTransformer):
def __init__(
self,
modified_functions: dict[tuple[str | None, str], cst.FunctionDef] = None,
new_functions: list[cst.FunctionDef] = None,
new_class_functions: dict[str, list[cst.FunctionDef]] = None,
modified_init_functions: dict[str, cst.FunctionDef] = None,
modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None,
new_functions: Optional[list[cst.FunctionDef]] = None,
new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None,
modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None,
) -> None:
super().__init__()
self.modified_functions = modified_functions if modified_functions is not None else {}
Expand Down
9 changes: 5 additions & 4 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
if not full_qualified_name:
raise ValueError("full_qualified_name cannot be empty")
msg = "full_qualified_name cannot be empty"
raise ValueError(msg)
if not full_qualified_name.startswith(module_name):
msg = f"{full_qualified_name} does not start with {module_name}"
raise ValueError(msg)
Expand Down Expand Up @@ -46,9 +47,9 @@ def file_name_from_test_module_name(test_module_name: str, base_dir: Path) -> Pa
def get_imports_from_file(
file_path: Path | None = None, file_string: str | None = None, file_ast: ast.AST | None = None
) -> list[ast.Import | ast.ImportFrom]:
assert (
sum([file_path is not None, file_string is not None, file_ast is not None]) == 1
), "Must provide exactly one of file_path, file_string, or file_ast"
assert sum([file_path is not None, file_string is not None, file_ast is not None]) == 1, (
"Must provide exactly one of file_path, file_string, or file_ast"
)
if file_path:
with file_path.open(encoding="utf8") as file:
file_string = file.read()
Expand Down
2 changes: 1 addition & 1 deletion codeflash/code_utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()

IS_POSIX = os.name != "nt"
IS_POSIX = os.name != "nt"
2 changes: 1 addition & 1 deletion codeflash/code_utils/concolic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _split_top_level_args(self, args_str: str) -> list[str]:

return result

def __init__(self):
def __init__(self) -> None:
# Pre-compiling regular expressions for faster execution
self.assert_re = re.compile(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$")
self.unittest_re = re.compile(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$")
Expand Down
7 changes: 3 additions & 4 deletions codeflash/code_utils/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,9 @@ def parse_config_file(config_file_path: Path | None = None) -> tuple[dict[str, A
else: # Default to empty list
config[key] = []

assert config["test-framework"] in [
"pytest",
"unittest",
], "In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
assert config["test-framework"] in ["pytest", "unittest"], (
"In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
)
if len(config["formatter-cmds"]) > 0:
assert config["formatter-cmds"][0] != "your-formatter $file", (
"The formatter command is not set correctly in pyproject.toml. Please set the "
Expand Down
1 change: 0 additions & 1 deletion codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import shlex
import subprocess
import sys
from typing import TYPE_CHECKING

import isort
Expand Down
10 changes: 5 additions & 5 deletions codeflash/code_utils/time_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def humanize_runtime(time_in_ns: int) -> str:

units = re.split(r",|\s", runtime_human)[1]

if units == "microseconds" or units == "microsecond":
runtime_human = "%.3g" % time_micro
elif units == "milliseconds" or units == "millisecond":
if units in ("microseconds", "microsecond"):
runtime_human = f"{time_micro:.3g}"
elif units in ("milliseconds", "millisecond"):
runtime_human = "%.3g" % (time_micro / 1000)
elif units == "seconds" or units == "second":
elif units in ("seconds", "second"):
runtime_human = "%.3g" % (time_micro / (1000**2))
elif units == "minutes" or units == "minute":
elif units in ("minutes", "minute"):
runtime_human = "%.3g" % (time_micro / (60 * 1000**2))
else: # hours
runtime_human = "%.3g" % (time_micro / (3600 * 1000**2))
Expand Down
1 change: 0 additions & 1 deletion codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def get_code_optimization_context(
read_only_context_code=read_only_code_markdown.markdown,
helper_functions=helpers_of_fto_obj_list,
preexisting_objects=preexisting_objects,

)

logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
Expand Down
3 changes: 1 addition & 2 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def discover_tests_pytest(
],
cwd=project_root,
check=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
capture_output=True,
text=True,
)
try:
Expand Down
1 change: 0 additions & 1 deletion codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import ast
import json
import os
import random
import warnings
Expand Down
8 changes: 5 additions & 3 deletions codeflash/discovery/pytest_new_process_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,19 @@ def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, s


if __name__ == "__main__":
from pathlib import Path

import pytest

try:
exitcode = pytest.main(
[tests_root, "-pno:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()]
)
except Exception as e:
print(f"Failed to collect tests: {e!s}")
except Exception as e: # noqa: BLE001
print(f"Failed to collect tests: {e!s}") # noqa: T201
exitcode = -1
tests = parse_pytest_collection_results(collected_tests)
import pickle

with open(pickle_path, "wb") as f:
with Path(pickle_path).open("wb") as f:
pickle.dump((exitcode, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL)
3 changes: 1 addition & 2 deletions codeflash/github/PrComment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class PrComment:
winning_benchmarking_test_results: TestResults

def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]:

report_table = {
test_type.to_name(): result
for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items()
Expand All @@ -36,7 +35,7 @@ def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]:
"speedup_x": self.speedup_x,
"speedup_pct": self.speedup_pct,
"loop_count": self.winning_benchmarking_test_results.number_of_loops(),
"report_table": report_table
"report_table": report_table,
}


Expand Down
20 changes: 7 additions & 13 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,6 @@
if TYPE_CHECKING:
from argparse import Namespace

import numpy as np
import numpy.typing as npt

from codeflash.either import Result
from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate
from codeflash.verification.verification_utils import TestConfig
Expand Down Expand Up @@ -246,7 +243,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:

best_optimization = None

for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
for _u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
if candidates is None:
continue

Expand Down Expand Up @@ -855,9 +852,7 @@ def establish_original_code_baseline(
)
console.rule()
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
if not coverage_critic(
coverage_results, self.args.test_framework
):
if not coverage_critic(coverage_results, self.args.test_framework):
return Failure("The threshold for test coverage was not met.")
if test_framework == "pytest":
benchmarking_results, _ = self.run_and_parse_tests(
Expand Down Expand Up @@ -898,7 +893,6 @@ def establish_original_code_baseline(
)
console.rule()


total_timing = benchmarking_results.total_passed_runtime() # caution: doesn't handle the loop index
functions_to_remove = [
result.id.test_function_name
Expand Down Expand Up @@ -1094,16 +1088,17 @@ def run_and_parse_tests(
test_framework=self.test_cfg.test_framework,
)
else:
raise ValueError(f"Unexpected testing type: {testing_type}")
msg = f"Unexpected testing type: {testing_type}"
raise ValueError(msg)
except subprocess.TimeoutExpired:
logger.exception(
f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error'
f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error"
)
return TestResults(), None
if run_result.returncode != 0 and testing_type == TestingMode.BEHAVIOR:
logger.debug(
f'Nonzero return code {run_result.returncode} when running tests in '
f'{", ".join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n'
f"Nonzero return code {run_result.returncode} when running tests in "
f"{', '.join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n"
f"stdout: {run_result.stdout}\n"
f"stderr: {run_result.stderr}\n"
)
Expand Down Expand Up @@ -1149,4 +1144,3 @@ def generate_and_instrument_tests(
zip(generated_test_paths, generated_perf_test_paths)
)
]

9 changes: 5 additions & 4 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
from codeflash.either import is_successful
from codeflash.models.models import TestFiles, ValidCode
from codeflash.models.models import ValidCode
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.test_results import TestType
Expand Down Expand Up @@ -60,7 +60,6 @@ def create_function_optimizer(
function_to_optimize_ast=function_to_optimize_ast,
aiservice_client=self.aiservice_client,
args=self.args,

)

def run(self) -> None:
Expand Down Expand Up @@ -162,7 +161,10 @@ def run(self) -> None:
continue

function_optimizer = self.create_function_optimizer(
function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code
function_to_optimize,
function_to_optimize_ast,
function_to_tests,
validated_original_code[original_module_path].source_code,
)
best_optimization = function_optimizer.optimize_function()
if is_successful(best_optimization):
Expand Down Expand Up @@ -192,7 +194,6 @@ def run(self) -> None:
get_run_tmp_file.tmpdir.cleanup()



def run_with_args(args: Namespace) -> None:
optimizer = Optimizer(args)
optimizer.run()
4 changes: 1 addition & 3 deletions codeflash/result/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def to_console_string(self) -> str:
original_runtime_human = humanize_runtime(self.original_runtime_ns)
best_runtime_human = humanize_runtime(self.best_runtime_ns)

explanation = (
return (
f"Optimized {self.function_name} in {self.file_path}\n"
f"{self.perf_improvement_line}\n"
f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
Expand All @@ -49,7 +49,5 @@ def to_console_string(self) -> str:
+ f"{TestResults.report_to_string(self.winning_behavioral_test_results.get_test_pass_fail_report_by_type())}\n"
)

return explanation

def explanation_message(self) -> str:
return self.raw_explanation_message
2 changes: 1 addition & 1 deletion codeflash/telemetry/sentry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sentry_sdk.integrations.logging import LoggingIntegration


def init_sentry(enabled: bool = False, exclude_errors: bool = False):
def init_sentry(enabled: bool = False, exclude_errors: bool = False) -> None:
if enabled:
sentry_logging = LoggingIntegration(
level=logging.INFO, # Capture info and above as breadcrumbs
Expand Down
Loading
Loading