diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index d4994d0c3..288763926 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -2,6 +2,7 @@ from __future__ import annotations import ast +import enum import hashlib import os import pickle @@ -11,12 +12,11 @@ import unittest from collections import defaultdict from pathlib import Path -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, final if TYPE_CHECKING: from codeflash.discovery.functions_to_optimize import FunctionToOptimize -import pytest from pydantic.dataclasses import dataclass from rich.panel import Panel from rich.text import Text @@ -35,6 +35,22 @@ from codeflash.verification.verification_utils import TestConfig +@final +class PytestExitCode(enum.IntEnum): # don't need to import entire pytest just for this + #: Tests passed. + OK = 0 + #: Tests failed. + TESTS_FAILED = 1 + #: pytest was interrupted. + INTERRUPTED = 2 + #: An internal error got in the way. + INTERNAL_ERROR = 3 + #: pytest was misused. + USAGE_ERROR = 4 + #: pytest couldn't find tests. + NO_TESTS_COLLECTED = 5 + + @dataclass(frozen=True) class TestFunction: function_name: str @@ -412,7 +428,7 @@ def discover_tests_pytest( error_section = match.group(1) if match else result.stdout logger.warning( - f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}\n {error_section}" + f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}\n {error_section}" ) if "ModuleNotFoundError" in result.stdout: match = ImportErrorPattern.search(result.stdout).group() @@ -420,7 +436,7 @@ def discover_tests_pytest( console.print(panel) elif 0 <= exitcode <= 5: - logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}") + logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}") else: logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}") console.rule() diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 50a2fe33b..c78107106 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -1,4 +1,3 @@ -# ruff: noqa: PGH003 import array import ast import datetime @@ -8,6 +7,7 @@ import re import types from collections import ChainMap, OrderedDict, deque +from importlib.util import find_spec from typing import Any import sentry_sdk @@ -15,58 +15,14 @@ from codeflash.cli_cmds.console import logger from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError -try: - import numpy as np - - HAS_NUMPY = True -except ImportError: - HAS_NUMPY = False -try: - import sqlalchemy # type: ignore - - HAS_SQLALCHEMY = True -except ImportError: - HAS_SQLALCHEMY = False -try: - import scipy # type: ignore - - HAS_SCIPY = True -except ImportError: - HAS_SCIPY = False - -try: - import pandas # type: ignore # noqa: ICN001 - - HAS_PANDAS = True -except ImportError: - HAS_PANDAS = False - -try: - import pyrsistent # type: ignore - - HAS_PYRSISTENT = True -except ImportError: - HAS_PYRSISTENT = False -try: - import torch # type: ignore - - HAS_TORCH = True -except ImportError: - HAS_TORCH = False -try: - import jax # type: ignore - import jax.numpy as jnp # type: ignore - - HAS_JAX = True -except ImportError: - HAS_JAX = False - -try: - import xarray # type: ignore - - HAS_XARRAY = True -except ImportError: - HAS_XARRAY = False +HAS_NUMPY = find_spec("numpy") is not None +HAS_SQLALCHEMY = find_spec("sqlalchemy") is not None +HAS_SCIPY = find_spec("scipy") is not None +HAS_PANDAS = find_spec("pandas") is not None +HAS_PYRSISTENT = find_spec("pyrsistent") is not None +HAS_TORCH = find_spec("torch") is not None +HAS_JAX = find_spec("jax") is not None +HAS_XARRAY = find_spec("xarray") is not None def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911 @@ -122,19 +78,28 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")} return comparator(orig_dict, new_dict, superset_obj) - # Handle JAX arrays first to avoid boolean context errors in other conditions - if HAS_JAX and isinstance(orig, jax.Array): - if orig.dtype != new.dtype: - return False - if orig.shape != new.shape: - return False - return bool(jnp.allclose(orig, new, equal_nan=True)) + if HAS_JAX: + import jax # type: ignore # noqa: PGH003 + import jax.numpy as jnp # type: ignore # noqa: PGH003 + + # Handle JAX arrays first to avoid boolean context errors in other conditions + if isinstance(orig, jax.Array): + if orig.dtype != new.dtype: + return False + if orig.shape != new.shape: + return False + return bool(jnp.allclose(orig, new, equal_nan=True)) # Handle xarray objects before numpy to avoid boolean context errors - if HAS_XARRAY and isinstance(orig, (xarray.Dataset, xarray.DataArray)): - return orig.identical(new) + if HAS_XARRAY: + import xarray # type: ignore # noqa: PGH003 + + if isinstance(orig, (xarray.Dataset, xarray.DataArray)): + return orig.identical(new) if HAS_SQLALCHEMY: + import sqlalchemy # type: ignore # noqa: PGH003 + try: insp = sqlalchemy.inspection.inspect(orig) insp = sqlalchemy.inspection.inspect(new) # noqa: F841 @@ -149,6 +114,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 except sqlalchemy.exc.NoInspectionAvailable: pass + + if HAS_SCIPY: + import scipy # type: ignore # noqa: PGH003 # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)): if superset_obj: @@ -162,27 +130,30 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return True - if HAS_NUMPY and isinstance(orig, np.ndarray): - if orig.dtype != new.dtype: - return False - if orig.shape != new.shape: - return False - try: - return np.allclose(orig, new, equal_nan=True) - except Exception: - # fails at "ufunc 'isfinite' not supported for the input types" - return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)]) + if HAS_NUMPY: + import numpy as np # type: ignore # noqa: PGH003 - if HAS_NUMPY and isinstance(orig, (np.floating, np.complex64, np.complex128)): - return np.isclose(orig, new) + if isinstance(orig, np.ndarray): + if orig.dtype != new.dtype: + return False + if orig.shape != new.shape: + return False + try: + return np.allclose(orig, new, equal_nan=True) + except Exception: + # fails at "ufunc 'isfinite' not supported for the input types" + return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)]) - if HAS_NUMPY and isinstance(orig, (np.integer, np.bool_, np.byte)): - return orig == new + if isinstance(orig, (np.floating, np.complex64, np.complex128)): + return np.isclose(orig, new) - if HAS_NUMPY and isinstance(orig, np.void): - if orig.dtype != new.dtype: - return False - return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields) + if isinstance(orig, (np.integer, np.bool_, np.byte)): + return orig == new + + if isinstance(orig, np.void): + if orig.dtype != new.dtype: + return False + return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields) if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix): if orig.dtype != new.dtype: @@ -191,15 +162,18 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return (orig != new).nnz == 0 - if HAS_PANDAS and isinstance( - orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray) - ): - return orig.equals(new) + if HAS_PANDAS: + import pandas # type: ignore # noqa: ICN001, PGH003 - if HAS_PANDAS and isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)): - return orig == new - if HAS_PANDAS and pandas.isna(orig) and pandas.isna(new): - return True + if isinstance( + orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray) + ): + return orig.equals(new) + + if isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)): + return orig == new + if pandas.isna(orig) and pandas.isna(new): + return True if isinstance(orig, array.array): if orig.typecode != new.typecode: @@ -220,31 +194,37 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 except Exception: # noqa: S110 pass - if HAS_TORCH and isinstance(orig, torch.Tensor): - if orig.dtype != new.dtype: - return False - if orig.shape != new.shape: - return False - if orig.requires_grad != new.requires_grad: - return False - if orig.device != new.device: - return False - return torch.allclose(orig, new, equal_nan=True) + if HAS_TORCH: + import torch # type: ignore # noqa: PGH003 - if HAS_PYRSISTENT and isinstance( - orig, - ( - pyrsistent.PMap, - pyrsistent.PVector, - pyrsistent.PSet, - pyrsistent.PRecord, - pyrsistent.PClass, - pyrsistent.PBag, - pyrsistent.PList, - pyrsistent.PDeque, - ), - ): - return orig == new + if isinstance(orig, torch.Tensor): + if orig.dtype != new.dtype: + return False + if orig.shape != new.shape: + return False + if orig.requires_grad != new.requires_grad: + return False + if orig.device != new.device: + return False + return torch.allclose(orig, new, equal_nan=True) + + if HAS_PYRSISTENT: + import pyrsistent # type: ignore # noqa: PGH003 + + if isinstance( + orig, + ( + pyrsistent.PMap, + pyrsistent.PVector, + pyrsistent.PSet, + pyrsistent.PRecord, + pyrsistent.PClass, + pyrsistent.PBag, + pyrsistent.PList, + pyrsistent.PDeque, + ), + ): + return orig == new if hasattr(orig, "__attrs_attrs__") and hasattr(new, "__attrs_attrs__"): orig_dict = {} diff --git a/tests/scripts/end_to_end_test_bubblesort_unittest.py b/tests/scripts/end_to_end_test_bubblesort_unittest.py index 074ba662f..3cf6f7303 100644 --- a/tests/scripts/end_to_end_test_bubblesort_unittest.py +++ b/tests/scripts/end_to_end_test_bubblesort_unittest.py @@ -6,7 +6,7 @@ def run_test(expected_improvement_pct: int) -> bool: config = TestConfig( - file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=0.40 + file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=0.30 ) cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve() return run_codeflash_command(cwd, config, expected_improvement_pct)