24 changes: 20 additions & 4 deletions codeflash/discovery/discover_unit_tests.py
@@ -2,6 +2,7 @@
from __future__ import annotations

import ast
import enum
import hashlib
import os
import pickle
@@ -11,12 +12,11 @@
import unittest
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, final

if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize

import pytest
from pydantic.dataclasses import dataclass
from rich.panel import Panel
from rich.text import Text
@@ -35,6 +35,22 @@
from codeflash.verification.verification_utils import TestConfig


@final
class PytestExitCode(enum.IntEnum): # don't need to import entire pytest just for this
#: Tests passed.
OK = 0
#: Tests failed.
TESTS_FAILED = 1
#: pytest was interrupted.
INTERRUPTED = 2
#: An internal error got in the way.
INTERNAL_ERROR = 3
#: pytest was misused.
USAGE_ERROR = 4
#: pytest couldn't find tests.
NO_TESTS_COLLECTED = 5


@dataclass(frozen=True)
class TestFunction:
function_name: str
@@ -412,15 +428,15 @@ def discover_tests_pytest(
error_section = match.group(1) if match else result.stdout

logger.warning(
f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}\n {error_section}"
f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}\n {error_section}"
)
if "ModuleNotFoundError" in result.stdout:
match = ImportErrorPattern.search(result.stdout).group()
panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False)
console.print(panel)

elif 0 <= exitcode <= 5:
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}")
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}")
else:
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}")
console.rule()
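The PytestExitCode enum added above mirrors pytest's documented exit codes, so a collection failure can still be reported by name without importing pytest in this module. As a rough illustration of the idea (the subprocess invocation, variable names, and the UNKNOWN fallback below are hypothetical, not taken from this PR):

# Illustrative only: a local IntEnum mirroring pytest's exit codes (values from the
# pytest docs), used to name a subprocess return code without importing pytest.
import enum
import subprocess

class PytestExitCode(enum.IntEnum):
    OK = 0
    TESTS_FAILED = 1
    INTERRUPTED = 2
    INTERNAL_ERROR = 3
    USAGE_ERROR = 4
    NO_TESTS_COLLECTED = 5

# Hypothetical invocation; assumes pytest is installed and on PATH.
result = subprocess.run(["pytest", "--collect-only", "-q"], capture_output=True, text=True)
code = result.returncode
name = PytestExitCode(code).name if 0 <= code <= 5 else "UNKNOWN"
print(f"pytest collection finished with exit code {code}={name}")

Because enum.IntEnum members compare equal to plain integers, existing checks such as 0 <= exitcode <= 5 keep working unchanged.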
204 changes: 92 additions & 112 deletions codeflash/verification/comparator.py
@@ -1,4 +1,3 @@
# ruff: noqa: PGH003
import array
import ast
import datetime
@@ -8,65 +7,22 @@
import re
import types
from collections import ChainMap, OrderedDict, deque
from importlib.util import find_spec
from typing import Any

import sentry_sdk

from codeflash.cli_cmds.console import logger
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError

try:
import numpy as np

HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
try:
import sqlalchemy # type: ignore

HAS_SQLALCHEMY = True
except ImportError:
HAS_SQLALCHEMY = False
try:
import scipy # type: ignore

HAS_SCIPY = True
except ImportError:
HAS_SCIPY = False

try:
import pandas # type: ignore # noqa: ICN001

HAS_PANDAS = True
except ImportError:
HAS_PANDAS = False

try:
import pyrsistent # type: ignore

HAS_PYRSISTENT = True
except ImportError:
HAS_PYRSISTENT = False
try:
import torch # type: ignore

HAS_TORCH = True
except ImportError:
HAS_TORCH = False
try:
import jax # type: ignore
import jax.numpy as jnp # type: ignore

HAS_JAX = True
except ImportError:
HAS_JAX = False

try:
import xarray # type: ignore

HAS_XARRAY = True
except ImportError:
HAS_XARRAY = False
HAS_NUMPY = find_spec("numpy") is not None
HAS_SQLALCHEMY = find_spec("sqlalchemy") is not None
HAS_SCIPY = find_spec("scipy") is not None
HAS_PANDAS = find_spec("pandas") is not None
HAS_PYRSISTENT = find_spec("pyrsistent") is not None
HAS_TORCH = find_spec("torch") is not None
HAS_JAX = find_spec("jax") is not None
HAS_XARRAY = find_spec("xarray") is not None


def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -122,19 +78,28 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
return comparator(orig_dict, new_dict, superset_obj)

# Handle JAX arrays first to avoid boolean context errors in other conditions
if HAS_JAX and isinstance(orig, jax.Array):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
return bool(jnp.allclose(orig, new, equal_nan=True))
if HAS_JAX:
import jax # type: ignore # noqa: PGH003
import jax.numpy as jnp # type: ignore # noqa: PGH003

# Handle JAX arrays first to avoid boolean context errors in other conditions
if isinstance(orig, jax.Array):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
return bool(jnp.allclose(orig, new, equal_nan=True))

# Handle xarray objects before numpy to avoid boolean context errors
if HAS_XARRAY and isinstance(orig, (xarray.Dataset, xarray.DataArray)):
return orig.identical(new)
if HAS_XARRAY:
import xarray # type: ignore # noqa: PGH003

if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
return orig.identical(new)

if HAS_SQLALCHEMY:
import sqlalchemy # type: ignore # noqa: PGH003

try:
insp = sqlalchemy.inspection.inspect(orig)
insp = sqlalchemy.inspection.inspect(new) # noqa: F841
@@ -149,6 +114,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001

except sqlalchemy.exc.NoInspectionAvailable:
pass

if HAS_SCIPY:
import scipy # type: ignore # noqa: PGH003
# scipy condition because dok_matrix type is also an instance of dict, but dict comparison doesn't work for it
if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
if superset_obj:
@@ -162,27 +130,30 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
return False
return True

if HAS_NUMPY and isinstance(orig, np.ndarray):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
try:
return np.allclose(orig, new, equal_nan=True)
except Exception:
# fails at "ufunc 'isfinite' not supported for the input types"
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
if HAS_NUMPY:
import numpy as np # type: ignore # noqa: PGH003

if HAS_NUMPY and isinstance(orig, (np.floating, np.complex64, np.complex128)):
return np.isclose(orig, new)
if isinstance(orig, np.ndarray):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
try:
return np.allclose(orig, new, equal_nan=True)
except Exception:
# fails at "ufunc 'isfinite' not supported for the input types"
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])

if HAS_NUMPY and isinstance(orig, (np.integer, np.bool_, np.byte)):
return orig == new
if isinstance(orig, (np.floating, np.complex64, np.complex128)):
return np.isclose(orig, new)

if HAS_NUMPY and isinstance(orig, np.void):
if orig.dtype != new.dtype:
return False
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
if isinstance(orig, (np.integer, np.bool_, np.byte)):
return orig == new

if isinstance(orig, np.void):
if orig.dtype != new.dtype:
return False
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)

if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
if orig.dtype != new.dtype:
@@ -191,15 +162,18 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
return False
return (orig != new).nnz == 0

if HAS_PANDAS and isinstance(
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
):
return orig.equals(new)
if HAS_PANDAS:
import pandas # type: ignore # noqa: ICN001, PGH003

if HAS_PANDAS and isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
return orig == new
if HAS_PANDAS and pandas.isna(orig) and pandas.isna(new):
return True
if isinstance(
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
):
return orig.equals(new)

if isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
return orig == new
if pandas.isna(orig) and pandas.isna(new):
return True

if isinstance(orig, array.array):
if orig.typecode != new.typecode:
@@ -220,31 +194,37 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
except Exception: # noqa: S110
pass

if HAS_TORCH and isinstance(orig, torch.Tensor):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
if orig.requires_grad != new.requires_grad:
return False
if orig.device != new.device:
return False
return torch.allclose(orig, new, equal_nan=True)
if HAS_TORCH:
import torch # type: ignore # noqa: PGH003

if HAS_PYRSISTENT and isinstance(
orig,
(
pyrsistent.PMap,
pyrsistent.PVector,
pyrsistent.PSet,
pyrsistent.PRecord,
pyrsistent.PClass,
pyrsistent.PBag,
pyrsistent.PList,
pyrsistent.PDeque,
),
):
return orig == new
if isinstance(orig, torch.Tensor):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
if orig.requires_grad != new.requires_grad:
return False
if orig.device != new.device:
return False
return torch.allclose(orig, new, equal_nan=True)

if HAS_PYRSISTENT:
import pyrsistent # type: ignore # noqa: PGH003

if isinstance(
orig,
(
pyrsistent.PMap,
pyrsistent.PVector,
pyrsistent.PSet,
pyrsistent.PRecord,
pyrsistent.PClass,
pyrsistent.PBag,
pyrsistent.PList,
pyrsistent.PDeque,
),
):
return orig == new

if hasattr(orig, "__attrs_attrs__") and hasattr(new, "__attrs_attrs__"):
orig_dict = {}
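The comparator.py rewrite above replaces the module-level try/except import probes with importlib.util.find_spec availability flags and moves each optional import into the branch that needs it. A minimal, self-contained sketch of the same pattern, assuming only that numpy may or may not be installed (the helper name values_close is made up for illustration and is not the PR's comparator):

# Illustrative only: availability flag via find_spec plus a deferred import.
from importlib.util import find_spec
from typing import Any

HAS_NUMPY = find_spec("numpy") is not None  # checks that numpy is installed without importing it

def values_close(orig: Any, new: Any) -> bool:
    # Hypothetical helper showing the lazy-import pattern used in comparator.py.
    if HAS_NUMPY:
        import numpy as np  # executed only when numpy exists; cached after the first import

        if isinstance(orig, np.ndarray) and isinstance(new, np.ndarray):
            return orig.shape == new.shape and bool(np.allclose(orig, new, equal_nan=True))
    return orig == new

print(values_close([1.0, 2.0], [1.0, 2.0]))  # True for plain Python objects as well

find_spec only asks the import system whether a module can be located; it does not execute the module, so importing this file no longer pays the startup cost of heavy optional dependencies such as torch or pandas unless a comparison actually reaches the branch that uses them.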
2 changes: 1 addition & 1 deletion tests/scripts/end_to_end_test_bubblesort_unittest.py
@@ -6,7 +6,7 @@

def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=0.40
file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=0.30
)
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
return run_codeflash_command(cwd, config, expected_improvement_pct)