diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 180abe105..cacd15f3e 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -196,10 +196,9 @@ def generate_regression_tests( - Dict[str, str] | None: The generated regression tests and instrumented tests, or None if an error occurred. """ - assert test_framework in [ - "pytest", - "unittest", - ], f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'" + assert test_framework in ["pytest", "unittest"], ( + f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'" + ) payload = { "source_code_being_tested": source_code_being_tested, "function_to_optimize": function_to_optimize, diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 90f58f515..c5e3299a4 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -462,7 +462,7 @@ def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyMa return DependencyManager.POETRY # Check for uv - if any(key.startswith("uv") for key in tool_section.keys()): + if any(key.startswith("uv") for key in tool_section): return DependencyManager.UV # Look for pip-specific markers @@ -555,9 +555,7 @@ def customize_codeflash_yaml_content( # Add codeflash command codeflash_cmd = get_codeflash_github_action_command(dep_manager) - optimize_yml_content = optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd) - - return optimize_yml_content + return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd) # Create or update the pyproject.toml file with the Codeflash dependency & configuration @@ -596,8 +594,8 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: formatter_cmds.append("disabled") if formatter in ["black", "ruff"]: try: - result = subprocess.run([formatter], capture_output=True, check=False) - except FileNotFoundError as e: + subprocess.run([formatter], capture_output=True, check=False) + except FileNotFoundError: click.echo(f"⚠️ Formatter not found: {formatter}, please ensure it is installed") codeflash_section["formatter-cmds"] = formatter_cmds # Add the 'codeflash' section, ensuring 'tool' section exists diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index 9661e9509..3ef0f2eca 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -3,7 +3,7 @@ import logging from contextlib import contextmanager from itertools import cycle -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING from rich.console import Console from rich.logging import RichHandler @@ -13,6 +13,8 @@ from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT if TYPE_CHECKING: + from collections.abc import Generator + from rich.progress import TaskID DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index f9ed39241..409551d0a 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -115,8 +115,7 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | Non or (functions_to_optimize[0].parents and functions_to_optimize[0].parents[0].type != "ClassDef") or ( len(functions_to_optimize[0].parents) > 1 - or (len(functions_to_optimize) > 1) - and len({fn.parents[0] for fn in functions_to_optimize}) != 1 + or ((len(functions_to_optimize) > 1) and len({fn.parents[0] for fn in functions_to_optimize}) != 1) ) ): return None, set() diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 42c9ead9d..2c169136d 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -import re from collections import defaultdict from functools import lru_cache from typing import TYPE_CHECKING, Optional, TypeVar @@ -91,10 +90,10 @@ def leave_ClassDef(self, node: cst.ClassDef) -> None: class OptimFunctionReplacer(cst.CSTTransformer): def __init__( self, - modified_functions: dict[tuple[str | None, str], cst.FunctionDef] = None, - new_functions: list[cst.FunctionDef] = None, - new_class_functions: dict[str, list[cst.FunctionDef]] = None, - modified_init_functions: dict[str, cst.FunctionDef] = None, + modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None, + new_functions: Optional[list[cst.FunctionDef]] = None, + new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None, + modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None, ) -> None: super().__init__() self.modified_functions = modified_functions if modified_functions is not None else {} diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index d4456bb62..5e2aac42e 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -12,7 +12,8 @@ def get_qualified_name(module_name: str, full_qualified_name: str) -> str: if not full_qualified_name: - raise ValueError("full_qualified_name cannot be empty") + msg = "full_qualified_name cannot be empty" + raise ValueError(msg) if not full_qualified_name.startswith(module_name): msg = f"{full_qualified_name} does not start with {module_name}" raise ValueError(msg) @@ -46,9 +47,9 @@ def file_name_from_test_module_name(test_module_name: str, base_dir: Path) -> Pa def get_imports_from_file( file_path: Path | None = None, file_string: str | None = None, file_ast: ast.AST | None = None ) -> list[ast.Import | ast.ImportFrom]: - assert ( - sum([file_path is not None, file_string is not None, file_ast is not None]) == 1 - ), "Must provide exactly one of file_path, file_string, or file_ast" + assert sum([file_path is not None, file_string is not None, file_ast is not None]) == 1, ( + "Must provide exactly one of file_path, file_string, or file_ast" + ) if file_path: with file_path.open(encoding="utf8") as file: file_string = file.read() diff --git a/codeflash/code_utils/compat.py b/codeflash/code_utils/compat.py index 0dfa0210e..8bdf093bb 100644 --- a/codeflash/code_utils/compat.py +++ b/codeflash/code_utils/compat.py @@ -11,4 +11,4 @@ SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix() -IS_POSIX = os.name != "nt" \ No newline at end of file +IS_POSIX = os.name != "nt" diff --git a/codeflash/code_utils/concolic_utils.py b/codeflash/code_utils/concolic_utils.py index bad02f49e..81cc71a6c 100644 --- a/codeflash/code_utils/concolic_utils.py +++ b/codeflash/code_utils/concolic_utils.py @@ -62,7 +62,7 @@ def _split_top_level_args(self, args_str: str) -> list[str]: return result - def __init__(self): + def __init__(self) -> None: # Pre-compiling regular expressions for faster execution self.assert_re = re.compile(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$") self.unittest_re = re.compile(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$") diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index 89832503b..a03330614 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -81,10 +81,9 @@ def parse_config_file(config_file_path: Path | None = None) -> tuple[dict[str, A else: # Default to empty list config[key] = [] - assert config["test-framework"] in [ - "pytest", - "unittest", - ], "In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest." + assert config["test-framework"] in ["pytest", "unittest"], ( + "In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest." + ) if len(config["formatter-cmds"]) > 0: assert config["formatter-cmds"][0] != "your-formatter $file", ( "The formatter command is not set correctly in pyproject.toml. Please set the " diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 6328948dd..875fd0a1f 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -3,7 +3,6 @@ import os import shlex import subprocess -import sys from typing import TYPE_CHECKING import isort diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index 1a2e87f14..ce97183b4 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -16,13 +16,13 @@ def humanize_runtime(time_in_ns: int) -> str: units = re.split(r",|\s", runtime_human)[1] - if units == "microseconds" or units == "microsecond": - runtime_human = "%.3g" % time_micro - elif units == "milliseconds" or units == "millisecond": + if units in ("microseconds", "microsecond"): + runtime_human = f"{time_micro:.3g}" + elif units in ("milliseconds", "millisecond"): runtime_human = "%.3g" % (time_micro / 1000) - elif units == "seconds" or units == "second": + elif units in ("seconds", "second"): runtime_human = "%.3g" % (time_micro / (1000**2)) - elif units == "minutes" or units == "minute": + elif units in ("minutes", "minute"): runtime_human = "%.3g" % (time_micro / (60 * 1000**2)) else: # hours runtime_human = "%.3g" % (time_micro / (3600 * 1000**2)) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 0a7caab45..b6dab0162 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -70,7 +70,6 @@ def get_code_optimization_context( read_only_context_code=read_only_code_markdown.markdown, helper_functions=helpers_of_fto_obj_list, preexisting_objects=preexisting_objects, - ) logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index a36774f66..ae744ea97 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -59,8 +59,7 @@ def discover_tests_pytest( ], cwd=project_root, check=False, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, text=True, ) try: diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 6e4dd0571..a234a2827 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -import json import os import random import warnings diff --git a/codeflash/discovery/pytest_new_process_discovery.py b/codeflash/discovery/pytest_new_process_discovery.py index e85bbdcce..397eabe01 100644 --- a/codeflash/discovery/pytest_new_process_discovery.py +++ b/codeflash/discovery/pytest_new_process_discovery.py @@ -28,17 +28,19 @@ def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, s if __name__ == "__main__": + from pathlib import Path + import pytest try: exitcode = pytest.main( [tests_root, "-pno:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()] ) - except Exception as e: - print(f"Failed to collect tests: {e!s}") + except Exception as e: # noqa: BLE001 + print(f"Failed to collect tests: {e!s}") # noqa: T201 exitcode = -1 tests = parse_pytest_collection_results(collected_tests) import pickle - with open(pickle_path, "wb") as f: + with Path(pickle_path).open("wb") as f: pickle.dump((exitcode, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py index a6f6aa892..f266a039d 100644 --- a/codeflash/github/PrComment.py +++ b/codeflash/github/PrComment.py @@ -20,7 +20,6 @@ class PrComment: winning_benchmarking_test_results: TestResults def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]: - report_table = { test_type.to_name(): result for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items() @@ -36,7 +35,7 @@ def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]: "speedup_x": self.speedup_x, "speedup_pct": self.speedup_pct, "loop_count": self.winning_benchmarking_test_results.number_of_loops(), - "report_table": report_table + "report_table": report_table, } diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 7b067a094..6afaef2c5 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -75,9 +75,6 @@ if TYPE_CHECKING: from argparse import Namespace - import numpy as np - import numpy.typing as npt - from codeflash.either import Result from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate from codeflash.verification.verification_utils import TestConfig @@ -246,7 +243,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: best_optimization = None - for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]): + for _u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]): if candidates is None: continue @@ -855,9 +852,7 @@ def establish_original_code_baseline( ) console.rule() return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.") - if not coverage_critic( - coverage_results, self.args.test_framework - ): + if not coverage_critic(coverage_results, self.args.test_framework): return Failure("The threshold for test coverage was not met.") if test_framework == "pytest": benchmarking_results, _ = self.run_and_parse_tests( @@ -898,7 +893,6 @@ def establish_original_code_baseline( ) console.rule() - total_timing = benchmarking_results.total_passed_runtime() # caution: doesn't handle the loop index functions_to_remove = [ result.id.test_function_name @@ -1094,16 +1088,17 @@ def run_and_parse_tests( test_framework=self.test_cfg.test_framework, ) else: - raise ValueError(f"Unexpected testing type: {testing_type}") + msg = f"Unexpected testing type: {testing_type}" + raise ValueError(msg) except subprocess.TimeoutExpired: logger.exception( - f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error' + f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error" ) return TestResults(), None if run_result.returncode != 0 and testing_type == TestingMode.BEHAVIOR: logger.debug( - f'Nonzero return code {run_result.returncode} when running tests in ' - f'{", ".join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n' + f"Nonzero return code {run_result.returncode} when running tests in " + f"{', '.join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n" f"stdout: {run_result.stdout}\n" f"stderr: {run_result.stderr}\n" ) @@ -1149,4 +1144,3 @@ def generate_and_instrument_tests( zip(generated_test_paths, generated_perf_test_paths) ) ] - diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index cae78a153..423ce62ce 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -16,7 +16,7 @@ from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.discovery.functions_to_optimize import get_functions_to_optimize from codeflash.either import is_successful -from codeflash.models.models import TestFiles, ValidCode +from codeflash.models.models import ValidCode from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.telemetry.posthog_cf import ph from codeflash.verification.test_results import TestType @@ -60,7 +60,6 @@ def create_function_optimizer( function_to_optimize_ast=function_to_optimize_ast, aiservice_client=self.aiservice_client, args=self.args, - ) def run(self) -> None: @@ -162,7 +161,10 @@ def run(self) -> None: continue function_optimizer = self.create_function_optimizer( - function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code + function_to_optimize, + function_to_optimize_ast, + function_to_tests, + validated_original_code[original_module_path].source_code, ) best_optimization = function_optimizer.optimize_function() if is_successful(best_optimization): @@ -192,7 +194,6 @@ def run(self) -> None: get_run_tmp_file.tmpdir.cleanup() - def run_with_args(args: Namespace) -> None: optimizer = Optimizer(args) optimizer.run() diff --git a/codeflash/result/explanation.py b/codeflash/result/explanation.py index 63ae5395a..1dd53ceb5 100644 --- a/codeflash/result/explanation.py +++ b/codeflash/result/explanation.py @@ -38,7 +38,7 @@ def to_console_string(self) -> str: original_runtime_human = humanize_runtime(self.original_runtime_ns) best_runtime_human = humanize_runtime(self.best_runtime_ns) - explanation = ( + return ( f"Optimized {self.function_name} in {self.file_path}\n" f"{self.perf_improvement_line}\n" f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n" @@ -49,7 +49,5 @@ def to_console_string(self) -> str: + f"{TestResults.report_to_string(self.winning_behavioral_test_results.get_test_pass_fail_report_by_type())}\n" ) - return explanation - def explanation_message(self) -> str: return self.raw_explanation_message diff --git a/codeflash/telemetry/sentry.py b/codeflash/telemetry/sentry.py index 75d79778a..81dee2957 100644 --- a/codeflash/telemetry/sentry.py +++ b/codeflash/telemetry/sentry.py @@ -4,7 +4,7 @@ from sentry_sdk.integrations.logging import LoggingIntegration -def init_sentry(enabled: bool = False, exclude_errors: bool = False): +def init_sentry(enabled: bool = False, exclude_errors: bool = False) -> None: if enabled: sentry_logging = LoggingIntegration( level=logging.INFO, # Capture info and above as breadcrumbs diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 39b05e01f..96c0202f1 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -83,7 +83,7 @@ def __init__( self.con = None self.output_file = Path(output).resolve() self.functions = functions - self.function_modules: List[FunctionModules] = [] + self.function_modules: list[FunctionModules] = [] self.function_count = defaultdict(int) self.current_file_path = Path(__file__).resolve() self.ignored_qualified_functions = { @@ -208,14 +208,13 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: overflow="ignore", ) - def tracer_logic(self, frame: FrameType, event: str): + def tracer_logic(self, frame: FrameType, event: str) -> None: if event != "call": return - if self.timeout is not None: - if (time.time() - self.start_time) > self.timeout: - sys.setprofile(None) - console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.") - return + if self.timeout is not None and (time.time() - self.start_time) > self.timeout: + sys.setprofile(None) + console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.") + return code = frame.f_code file_name = Path(code.co_filename).resolve() # TODO : It currently doesn't log the last return call from the first function @@ -224,9 +223,8 @@ def tracer_logic(self, frame: FrameType, event: str): return if not file_name.exists(): return - if self.functions: - if code.co_name not in self.functions: - return + if self.functions and code.co_name not in self.functions: + return class_name = None arguments = frame.f_locals try: @@ -325,10 +323,7 @@ def trace_callback(self, frame: FrameType, event: str, arg: Any) -> None: if event == "c_call": self.c_func_name = arg.__name__ - if self.dispatch[event](self, frame, t): - prof_success = True - else: - prof_success = False + prof_success = bool(self.dispatch[event](self, frame, t)) # tracer section self.tracer_logic(frame, event) # measure the time as the last thing before return @@ -337,7 +332,7 @@ def trace_callback(self, frame: FrameType, event: str, arg: Any) -> None: else: self.t = timer() - t # put back unrecorded delta - def trace_dispatch_call(self, frame, t): + def trace_dispatch_call(self, frame, t) -> int: if self.cur and frame.f_back is not self.cur[-2]: rpt, rit, ret, rfn, rframe, rcur = self.cur if not isinstance(rframe, Tracer.fake_frame): @@ -375,7 +370,7 @@ def trace_dispatch_exception(self, frame, t): self.cur = rpt, rit + t, ret, rfn, rframe, rcur return 1 - def trace_dispatch_c_call(self, frame, t): + def trace_dispatch_c_call(self, frame, t) -> int: fn = ("", 0, self.c_func_name, None) self.cur = (t, 0, 0, fn, frame, self.cur) timings = self.timings @@ -386,7 +381,7 @@ def trace_dispatch_c_call(self, frame, t): timings[fn] = 0, 0, 0, 0, {} return 1 - def trace_dispatch_return(self, frame, t): + def trace_dispatch_return(self, frame, t) -> int: if frame is not self.cur[-2]: assert frame is self.cur[-2].f_back, ("Bad return", self.cur[-3]) self.trace_dispatch_return(self.cur[-2], 0) @@ -433,31 +428,28 @@ def trace_dispatch_return(self, frame, t): } class fake_code: - def __init__(self, filename, line, name): + def __init__(self, filename, line, name) -> None: self.co_filename = filename self.co_line = line self.co_name = name self.co_firstlineno = 0 - def __repr__(self): + def __repr__(self) -> str: return repr((self.co_filename, self.co_line, self.co_name, None)) class fake_frame: - def __init__(self, code, prior): + def __init__(self, code, prior) -> None: self.f_code = code self.f_back = prior self.f_locals = {} - def simulate_call(self, name): + def simulate_call(self, name) -> None: code = self.fake_code("profiler", 0, name) - if self.cur: - pframe = self.cur[-2] - else: - pframe = None + pframe = self.cur[-2] if self.cur else None frame = self.fake_frame(code, pframe) self.dispatch["call"](self, frame, 0) - def simulate_cmd_complete(self): + def simulate_cmd_complete(self) -> None: get_time = self.timer t = get_time() - self.t while self.cur[-1]: @@ -467,7 +459,7 @@ def simulate_cmd_complete(self): t = 0 self.t = get_time() - t - def print_stats(self, sort=-1): + def print_stats(self, sort=-1) -> None: import pstats if not isinstance(sort, tuple): @@ -520,7 +512,7 @@ def print_stats(self, sort=-1): console.print("\n".join(new_stats)) - def make_pstats_compatible(self): + def make_pstats_compatible(self) -> None: # delete the extra class_name item from the function tuple self.files = [] self.top_level = [] @@ -535,18 +527,18 @@ def make_pstats_compatible(self): self.stats = new_stats self.timings = new_timings - def dump_stats(self, file): + def dump_stats(self, file) -> None: with open(file, "wb") as f: self.create_stats() marshal.dump(self.stats, f) - def create_stats(self): + def create_stats(self) -> None: self.simulate_cmd_complete() self.snapshot_stats() - def snapshot_stats(self): + def snapshot_stats(self) -> None: self.stats = {} - for func, (cc, ns, tt, ct, callers) in self.timings.items(): + for func, (cc, _ns, tt, ct, callers) in self.timings.items(): callers = callers.copy() nc = 0 for callcnt in callers.values(): diff --git a/codeflash/tracing/profile_stats.py b/codeflash/tracing/profile_stats.py index 023c50f55..50b8dae2e 100644 --- a/codeflash/tracing/profile_stats.py +++ b/codeflash/tracing/profile_stats.py @@ -8,7 +8,7 @@ class ProfileStats(pstats.Stats): - def __init__(self, trace_file_path: str, time_unit: str = "ns"): + def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None: assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist" assert time_unit in ["ns", "us", "ms", "s"], f"Invalid time unit {time_unit}" self.trace_file_path = trace_file_path @@ -16,7 +16,7 @@ def __init__(self, trace_file_path: str, time_unit: str = "ns"): logger.debug(hasattr(self, "create_stats")) super().__init__(copy(self)) - def create_stats(self): + def create_stats(self) -> None: self.con = sqlite3.connect(self.trace_file_path) cur = self.con.cursor() pdata = cur.execute("SELECT * FROM pstats").fetchall() @@ -57,7 +57,7 @@ def print_stats(self, *amount): if self.total_calls != self.prim_calls: print("(%d primitive calls)" % self.prim_calls, end=" ", file=self.stream) time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit] - print("in %.3f %s" % (self.total_tt, time_unit), file=self.stream) + print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream) print(file=self.stream) width, list = self.get_print_list(amount) if list: diff --git a/codeflash/tracing/replay_test.py b/codeflash/tracing/replay_test.py index 6150df3cc..62d9dbbe6 100644 --- a/codeflash/tracing/replay_test.py +++ b/codeflash/tracing/replay_test.py @@ -2,7 +2,8 @@ import sqlite3 import textwrap -from typing import Any, Generator, List, Optional +from collections.abc import Generator +from typing import Any, List, Optional from codeflash.discovery.functions_to_optimize import FunctionProperties, inspect_top_level_functions_or_methods from codeflash.tracing.tracing_utils import FunctionModules @@ -30,7 +31,8 @@ def get_next_arg_and_return( if event_type == "call": yield val[7] else: - raise ValueError("Invalid Trace event type") + msg = "Invalid Trace event type" + raise ValueError(msg) def get_function_alias(module: str, function_name: str) -> str: @@ -38,7 +40,7 @@ def get_function_alias(module: str, function_name: str) -> str: def create_trace_replay_test( - trace_file: str, functions: List[FunctionModules], test_framework: str = "pytest", max_run_count=100 + trace_file: str, functions: list[FunctionModules], test_framework: str = "pytest", max_run_count=100 ) -> str: assert test_framework in ["pytest", "unittest"] diff --git a/codeflash/update_license_version.py b/codeflash/update_license_version.py index b5c9989fe..6aad189b4 100644 --- a/codeflash/update_license_version.py +++ b/codeflash/update_license_version.py @@ -1,4 +1,3 @@ -import os import re from datetime import datetime from pathlib import Path @@ -6,7 +5,7 @@ from version import __version_tuple__ -def main(): +def main() -> None: # Use the version tuple from version.py version = __version_tuple__ diff --git a/codeflash/verification/concolic_testing.py b/codeflash/verification/concolic_testing.py index c8e032ede..14acdbec6 100644 --- a/codeflash/verification/concolic_testing.py +++ b/codeflash/verification/concolic_testing.py @@ -7,8 +7,8 @@ from pathlib import Path from codeflash.cli_cmds.console import console, logger -from codeflash.code_utils.concolic_utils import clean_concolic_tests from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE +from codeflash.code_utils.concolic_utils import clean_concolic_tests from codeflash.code_utils.static_analysis import has_typed_parameters from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.discovery.functions_to_optimize import FunctionToOptimize diff --git a/codeflash/verification/pytest_plugin.py b/codeflash/verification/pytest_plugin.py index 99fcd55e8..1080ac3e2 100644 --- a/codeflash/verification/pytest_plugin.py +++ b/codeflash/verification/pytest_plugin.py @@ -97,10 +97,10 @@ def pytest_addoption(parser: Parser) -> None: @pytest.hookimpl(trylast=True) def pytest_configure(config: Config) -> None: config.addinivalue_line("markers", "loops(n): run the given test function `n` times.") - config.pluginmanager.register(PyTest_Loops(config), PyTest_Loops.name) + config.pluginmanager.register(PytestLoops(config), PytestLoops.name) -class PyTest_Loops: +class PytestLoops: name: str = "pytest-loops" def __init__(self, config: Config) -> None: @@ -113,9 +113,8 @@ def __init__(self, config: Config) -> None: def pytest_runtestloop(self, session: Session) -> bool: """Reimplement the test loop but loop for the user defined amount of time.""" if session.testsfailed and not session.config.option.continue_on_collection_errors: - raise session.Interrupted( - "%d error%s during collection" % (session.testsfailed, "s" if session.testsfailed != 1 else "") - ) + msg = "{} error{} during collection".format(session.testsfailed, "s" if session.testsfailed != 1 else "") + raise session.Interrupted(msg) if session.config.option.collectonly: return True @@ -130,11 +129,11 @@ def pytest_runtestloop(self, session: Session) -> bool: total_time = self._get_total_time(session) for index, item in enumerate(session.items): - item: pytest.Item = item - item._report_sections.clear() # clear reports for new test + item: pytest.Item = item # noqa: PLW0127, PLW2901 + item._report_sections.clear() # clear reports for new test # noqa: SLF001 if total_time > SHORTEST_AMOUNT_OF_TIME: - item._nodeid = self._set_nodeid(item._nodeid, count) + item._nodeid = self._set_nodeid(item._nodeid, count) # noqa: SLF001 next_item: pytest.Item = session.items[index + 1] if index + 1 < len(session.items) else None @@ -234,7 +233,8 @@ def _get_total_time(self, session: Session) -> float: seconds = session.config.option.codeflash_seconds total_time = hours_in_seconds + minutes_in_seconds + seconds if total_time < SHORTEST_AMOUNT_OF_TIME: - raise InvalidTimeParameterError(f"Total time cannot be less than: {SHORTEST_AMOUNT_OF_TIME}!") + msg = f"Total time cannot be less than: {SHORTEST_AMOUNT_OF_TIME}!" + raise InvalidTimeParameterError(msg) return total_time def _timed_out(self, session: Session, start_time: float, count: int) -> bool: @@ -262,11 +262,10 @@ def __pytest_loop_step_number(self, request: pytest.FixtureRequest) -> int: return request.param except AttributeError: if issubclass(request.cls, TestCase): - warnings.warn("Repeating unittest class tests not supported") + warnings.warn("Repeating unittest class tests not supported", stacklevel=2) else: - raise UnexpectedError( - "This call couldn't work with pytest-loops. Please consider raising an issue with your usage." - ) + msg = "This call couldn't work with pytest-loops. Please consider raising an issue with your usage." + raise UnexpectedError(msg) from None return count @pytest.hookimpl(trylast=True) diff --git a/codeflash/verification/test_results.py b/codeflash/verification/test_results.py index 28d8bfc0d..a4ecea816 100644 --- a/codeflash/verification/test_results.py +++ b/codeflash/verification/test_results.py @@ -1,10 +1,9 @@ from __future__ import annotations import sys -from collections.abc import Iterator from enum import Enum from pathlib import Path -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, cast from pydantic import BaseModel from pydantic.dataclasses import dataclass @@ -13,6 +12,9 @@ from codeflash.cli_cmds.console import DEBUG_MODE, logger from codeflash.verification.comparator import comparator +if TYPE_CHECKING: + from collections.abc import Iterator + class VerificationType(str, Enum): FUNCTION_CALL = ( @@ -53,7 +55,11 @@ class InvocationId: # test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id def id(self) -> str: - return f"{self.test_module_path}:{(self.test_class_name + '.' if self.test_class_name else '')}{self.test_function_name}:{self.function_getting_tested}:{self.iteration_id}" + class_prefix = f"{self.test_class_name}." if self.test_class_name else "" + return ( + f"{self.test_module_path}:{class_prefix}{self.test_function_name}:" + f"{self.function_getting_tested}:{self.iteration_id}" + ) @staticmethod def from_str_id(string_id: str, iteration_id: Optional[str] = None) -> InvocationId: @@ -167,9 +173,13 @@ def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree: def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]: for result in self.test_results: if result.did_pass and not result.runtime: - logger.debug( - f"Ignoring test case that passed but had no runtime -> {result.id}, Loop # {result.loop_index}, Test Type: {result.test_type}, Verification Type: {result.verification_type}" + msg = ( + f"Ignoring test case that passed but had no runtime -> {result.id}, " + f"Loop # {result.loop_index}, Test Type: {result.test_type}, " + f"Verification Type: {result.verification_type}" ) + logger.debug(msg) + usable_runtimes = [ (result.id, result.runtime) for result in self.test_results if result.did_pass and result.runtime ] @@ -179,16 +189,14 @@ def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]: } def total_passed_runtime(self) -> int: - """Calculate the sum of runtimes of all test cases that passed, where a testcase runtime - is the minimum value of all looped execution runtimes. + """Calculate the sum of runtimes of all test cases that passed. + + A testcase runtime is the minimum value of all looped execution runtimes. :return: The runtime in nanoseconds. """ return sum( - [ - min(usable_runtime_data) - for invocation_id, usable_runtime_data in self.usable_runtime_data_by_test_case().items() - ] + [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()] ) def __iter__(self) -> Iterator[FunctionTestInvocation]: diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index 46203e65a..fcac7e756 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -30,6 +30,7 @@ def run_behavioral_tests( test_framework: str, test_env: dict[str, str], cwd: Path, + *, pytest_timeout: int | None = None, pytest_cmd: str = "pytest", verbose: bool = False, @@ -59,7 +60,7 @@ def run_behavioral_tests( "--codeflash_loops_scope=session", "--codeflash_min_loops=1", "--codeflash_max_loops=1", - f"--codeflash_seconds={pytest_target_runtime_seconds}", # TODO :This is unnecessary, update the plugin to not ask for this + f"--codeflash_seconds={pytest_target_runtime_seconds}", # TODO : This is unnecessary, update the plugin to not ask for this # noqa: E501 ] result_file_path = get_run_tmp_file(Path("pytest_results.xml")) @@ -77,18 +78,17 @@ def run_behavioral_tests( # then the current run will be appended to the previous data, which skews the results logger.debug(cov_erase) + coverage_cmd = f"{SAFE_SYS_EXECUTABLE} -m coverage run --rcfile={coveragercfile.as_posix()} -m" results = execute_test_subprocess( - shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage run --rcfile={coveragercfile.as_posix()} -m") - + pytest_cmd_list - + common_pytest_args - + result_args - + test_files, + shlex.split(coverage_cmd) + pytest_cmd_list + common_pytest_args + result_args + test_files, cwd=cwd, env=pytest_test_env, timeout=600, ) logger.debug( - f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""") + f"Result return code: {results.returncode}" + + (f", Result stderr: {results.stderr}" if results.stderr else "") + ) else: results = execute_test_subprocess( pytest_cmd_list + common_pytest_args + result_args + test_files, @@ -97,17 +97,25 @@ def run_behavioral_tests( timeout=600, # TODO: Make this dynamic ) logger.debug( - f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""") + f"Result return code: {results.returncode}" + + (f", Result stderr: {results.stderr}" if results.stderr else "") + ) elif test_framework == "unittest": if enable_coverage: - raise ValueError("Coverage is not supported yet for unittest framework") + msg = "Coverage is not supported yet for unittest framework" + raise ValueError(msg) test_env["CODEFLASH_LOOP_INDEX"] = "1" test_files = [file.instrumented_behavior_file_path for file in test_paths.test_files] - result_file_path, results = run_unittest_tests(verbose, test_files, test_env, cwd) + result_file_path, results = run_unittest_tests( + verbose=verbose, test_file_paths=test_files, test_env=test_env, cwd=cwd + ) logger.debug( - f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""") + f"Result return code: {results.returncode}" + + (f", Result stderr: {results.stderr}" if results.stderr else "") + ) else: - raise ValueError(f"Unsupported test framework: {test_framework}") + msg = f"Unsupported test framework: {test_framework}" + raise ValueError(msg) return result_file_path, results, coverage_database_file if enable_coverage else None @@ -118,12 +126,13 @@ def run_benchmarking_tests( test_env: dict[str, str], cwd: Path, test_framework: str, + *, pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME, verbose: bool = False, pytest_timeout: int | None = None, pytest_min_loops: int = 5, pytest_max_loops: int = 100_000, -): +) -> tuple[Path, subprocess.CompletedProcess]: if test_framework == "pytest": pytest_cmd_list = shlex.split(pytest_cmd, posix=IS_POSIX) test_files: list[str] = [] @@ -162,13 +171,18 @@ def run_benchmarking_tests( ) elif test_framework == "unittest": test_files = [file.benchmarking_file_path for file in test_paths.test_files] - result_file_path, results = run_unittest_tests(verbose, test_files, test_env, cwd) + result_file_path, results = run_unittest_tests( + verbose=verbose, test_file_paths=test_files, test_env=test_env, cwd=cwd + ) else: - raise ValueError(f"Unsupported test framework: {test_framework}") + msg = f"Unsupported test framework: {test_framework}" + raise ValueError(msg) return result_file_path, results -def run_unittest_tests(verbose: bool, test_file_paths: list[Path], test_env: dict[str, str], cwd: Path): +def run_unittest_tests( + *, verbose: bool, test_file_paths: list[Path], test_env: dict[str, str], cwd: Path +) -> tuple[Path, subprocess.CompletedProcess]: result_file_path = get_run_tmp_file(Path("unittest_results.xml")) unittest_cmd_list = [SAFE_SYS_EXECUTABLE, "-m", "xmlrunner"] log_level = ["-v"] if verbose else [] diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py index 3d30f89f9..79f1b9656 100644 --- a/codeflash/verification/verification_utils.py +++ b/codeflash/verification/verification_utils.py @@ -31,33 +31,34 @@ def delete_multiple_if_name_main(test_ast: ast.Module) -> ast.Module: class ModifyInspiredTests(ast.NodeTransformer): - """This isn't being used right now""" + """Transformer for modifying inspired test classes. - def __init__(self, import_list, test_framework): + Class is currently not in active use. + """ + + def __init__(self, import_list: list[ast.AST], test_framework: str) -> None: self.import_list = import_list self.test_framework = test_framework - def visit_Import(self, node: ast.Import): + def visit_Import(self, node: ast.Import) -> None: self.import_list.append(node) - def visit_ImportFrom(self, node: ast.ImportFrom): + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: self.import_list.append(node) - def visit_ClassDef(self, node: ast.ClassDef): + def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: if self.test_framework != "unittest": return node found = False if node.bases: for base in node.bases: - if isinstance(base, ast.Attribute): - if base.attr == "TestCase" and base.value.id == "unittest": - found = True - break - if isinstance(base, ast.Name): - # TODO: Possibility that this is not a unittest.TestCase - if base.id == "TestCase": - found = True - break + if isinstance(base, ast.Attribute) and base.attr == "TestCase" and base.value.id == "unittest": + found = True + break + # TODO: Check if this is actually a unittest.TestCase + if isinstance(base, ast.Name) and base.id == "TestCase": + found = True + break if not found: return node node.name = node.name + "Inspired" diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index f60dcc5b6..aba8f956e 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -71,7 +71,7 @@ def merge_unit_tests(unit_test_source: str, inspired_unit_tests: str, test_frame except SyntaxError as e: logger.exception(f"Syntax error in code: {e}") return unit_test_source - import_list: list[ast.stmt] = list() + import_list: list[ast.stmt] = [] modified_ast = ModifyInspiredTests(import_list, test_framework).visit(inspired_unit_tests_ast) if test_framework == "pytest": # Because we only want to modify the top level test functions diff --git a/pyproject.toml b/pyproject.toml index 706aec52a..a4c023ed6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,10 +163,11 @@ ignore = [ "D103", "D105", "D107", + "D203", # incorrect-blank-line-before-class (incompatible with D211) + "D213", # multi-line-summary-second-line (incompatible with D212) "S101", "S603", "S607", - "ANN101", "COM812", "FIX002", "PLR0912", @@ -176,13 +177,14 @@ ignore = [ "TD003", "TD004", "PLR2004", - "UP007" + "UP007", + "N802", # we use a lot of stdlib which follows this convention ] [tool.ruff.lint.flake8-type-checking] strict = true runtime-evaluated-base-classes = ["pydantic.BaseModel"] -runtime-evaluated-decorators = ["pydantic.validate_call"] +runtime-evaluated-decorators = ["pydantic.validate_call", "pydantic.dataclasses.dataclass"] [tool.ruff.lint.pep8-naming] classmethod-decorators = [ @@ -190,6 +192,9 @@ classmethod-decorators = [ "pydantic.validator", ] +[tool.ruff.lint.isort] +split-on-trailing-comma = false + [tool.ruff.format] docstring-code-format = true skip-magic-trailing-comma = true diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 7f4a94845..88b46e87c 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest + from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent diff --git a/tests/test_comparator.py b/tests/test_comparator.py index aad9a0e30..0fc292f09 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -10,7 +10,6 @@ import pydantic import pytest -from pathlib import Path from codeflash.either import Failure, Success from codeflash.verification.comparator import comparator @@ -302,7 +301,7 @@ def test_numpy(): def test_scipy(): try: - import scipy as sp # type: ignore + import scipy as sp # type: ignore except ImportError: pytest.skip() a = sp.sparse.csr_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) @@ -472,7 +471,7 @@ def test_pandas(): def test_pyrsistent(): try: - from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector # type: ignore + from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector # type: ignore except ImportError: pytest.skip() @@ -1039,7 +1038,7 @@ def raise_specific_exception(): assert not comparator(..., None) - assert not comparator(Ellipsis, None) + assert not comparator(Ellipsis, None) code7 = "a = 1 + 2" module7 = ast.parse(code7) @@ -1053,4 +1052,4 @@ def raise_specific_exception(): module2 = ast.parse(code2) - assert not comparator(module7, module2) \ No newline at end of file + assert not comparator(module7, module2) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 59e3f3ca0..5c0a91c38 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -3,6 +3,7 @@ from pathlib import Path import pytest + from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.formatter import format_code, sort_imports diff --git a/tests/test_function_dependencies.py b/tests/test_function_dependencies.py index faa754af9..8534cb803 100644 --- a/tests/test_function_dependencies.py +++ b/tests/test_function_dependencies.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import pytest + from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import is_successful from codeflash.models.models import FunctionParent diff --git a/tests/test_get_read_writable_code.py b/tests/test_get_read_writable_code.py index 1680f2403..5040eabe2 100644 --- a/tests/test_get_read_writable_code.py +++ b/tests/test_get_read_writable_code.py @@ -1,6 +1,7 @@ from textwrap import dedent import pytest + from codeflash.context.code_context_extractor import get_read_writable_code diff --git a/tests/test_instrumentation_run_results_aiservice.py b/tests/test_instrumentation_run_results_aiservice.py index ee237cfca..a35556dfd 100644 --- a/tests/test_instrumentation_run_results_aiservice.py +++ b/tests/test_instrumentation_run_results_aiservice.py @@ -6,6 +6,7 @@ from pathlib import Path import isort + from code_to_optimize.bubble_sort_method import BubbleSorter from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.discovery.functions_to_optimize import FunctionToOptimize diff --git a/tests/test_lru_cache_clear.py b/tests/test_lru_cache_clear.py index ecf385fed..43c08a0ed 100644 --- a/tests/test_lru_cache_clear.py +++ b/tests/test_lru_cache_clear.py @@ -4,12 +4,12 @@ import pytest from _pytest.config import Config -from codeflash.verification.pytest_plugin import PyTest_Loops +from codeflash.verification.pytest_plugin import PytestLoops @pytest.fixture -def pytest_loops_instance(pytestconfig: Config) -> PyTest_Loops: - return PyTest_Loops(pytestconfig) +def pytest_loops_instance(pytestconfig: Config) -> PytestLoops: + return PytestLoops(pytestconfig) @pytest.fixture @@ -27,7 +27,7 @@ def create_mock_module(module_name: str, source_code: str) -> types.ModuleType: return module -def test_clear_lru_caches_function(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None: +def test_clear_lru_caches_function(pytest_loops_instance: PytestLoops, mock_item: type) -> None: source_code = """ import functools @@ -46,7 +46,7 @@ def my_func(x): assert mock_module.my_func.cache_info().currsize == 0 -def test_clear_lru_caches_class_method(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None: +def test_clear_lru_caches_class_method(pytest_loops_instance: PytestLoops, mock_item: type) -> None: source_code = """ import functools @@ -67,7 +67,7 @@ def my_method(self, x): assert mock_module.MyClass.my_method.cache_info().currsize == 0 -def test_clear_lru_caches_exception_handling(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None: +def test_clear_lru_caches_exception_handling(pytest_loops_instance: PytestLoops, mock_item: type) -> None: """Test that exceptions during clearing are handled.""" class BrokenCache: @@ -79,7 +79,7 @@ def cache_clear(self) -> NoReturn: pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001 -def test_clear_lru_caches_no_cache(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None: +def test_clear_lru_caches_no_cache(pytest_loops_instance: PytestLoops, mock_item: type) -> None: def no_cache_func(x: int) -> int: return x diff --git a/tests/test_remove_functions_from_generated_tests.py b/tests/test_remove_functions_from_generated_tests.py index 0e926a14d..dc2a14468 100644 --- a/tests/test_remove_functions_from_generated_tests.py +++ b/tests/test_remove_functions_from_generated_tests.py @@ -1,6 +1,7 @@ from pathlib import Path import pytest + from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests from codeflash.models.models import GeneratedTests, GeneratedTestsList diff --git a/tests/test_validate_python_code.py b/tests/test_validate_python_code.py index 7fa760526..70128662d 100644 --- a/tests/test_validate_python_code.py +++ b/tests/test_validate_python_code.py @@ -1,7 +1,9 @@ -from codeflash.models.models import CodeString import pytest from pydantic import ValidationError +from codeflash.models.models import CodeString + + def test_python_string(): code = CodeString(code="print('Hello, World!')") assert code.code == "print('Hello, World!')" @@ -38,4 +40,4 @@ def test_whitespace_only(): # Whitespace is still syntactically valid (no-op) whitespace_code = " " cs = CodeString(code=whitespace_code) - assert cs.code == whitespace_code \ No newline at end of file + assert cs.code == whitespace_code