Skip to content

Commit 2c07d6d

Browse files
committed
optimize comparator
1 parent e973c69 commit 2c07d6d

File tree

1 file changed

+23
-48
lines changed

1 file changed

+23
-48
lines changed

codeflash/verification/comparator.py

Lines changed: 23 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# ruff: noqa: PGH003
21
import array
32
import ast
43
import datetime
@@ -8,58 +7,21 @@
87
import re
98
import types
109
from collections import ChainMap, OrderedDict, deque
10+
from importlib.util import find_spec
1111
from typing import Any
1212

1313
import sentry_sdk
1414

1515
from codeflash.cli_cmds.console import logger
1616
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
1717

18-
try:
19-
import numpy as np
20-
21-
HAS_NUMPY = True
22-
except ImportError:
23-
HAS_NUMPY = False
24-
try:
25-
import sqlalchemy # type: ignore
26-
27-
HAS_SQLALCHEMY = True
28-
except ImportError:
29-
HAS_SQLALCHEMY = False
30-
try:
31-
import scipy # type: ignore
32-
33-
HAS_SCIPY = True
34-
except ImportError:
35-
HAS_SCIPY = False
36-
37-
try:
38-
import pandas # type: ignore # noqa: ICN001
39-
40-
HAS_PANDAS = True
41-
except ImportError:
42-
HAS_PANDAS = False
43-
44-
try:
45-
import pyrsistent # type: ignore
46-
47-
HAS_PYRSISTENT = True
48-
except ImportError:
49-
HAS_PYRSISTENT = False
50-
try:
51-
import torch # type: ignore
52-
53-
HAS_TORCH = True
54-
except ImportError:
55-
HAS_TORCH = False
56-
try:
57-
import jax # type: ignore
58-
import jax.numpy as jnp # type: ignore
59-
60-
HAS_JAX = True
61-
except ImportError:
62-
HAS_JAX = False
18+
HAS_NUMPY = find_spec("numpy") is not None
19+
HAS_SQLALCHEMY = find_spec("sqlalchemy") is not None
20+
HAS_SCIPY = find_spec("scipy") is not None
21+
HAS_PANDAS = find_spec("pandas") is not None
22+
HAS_PYRSISTENT = find_spec("pyrsistent") is not None
23+
HAS_TORCH = find_spec("torch") is not None
24+
HAS_JAX = find_spec("jax") is not None
6325

6426

6527
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -114,7 +76,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
11476
orig_dict = {k: v for k, v in orig.__dict__.items() if not k.startswith("_")}
11577
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
11678
return comparator(orig_dict, new_dict, superset_obj)
117-
79+
if HAS_JAX:
80+
import jax # type: ignore # noqa: PGH003
81+
import jax.numpy as jnp # type: ignore # noqa: PGH003
11882
# Handle JAX arrays first to avoid boolean context errors in other conditions
11983
if HAS_JAX and isinstance(orig, jax.Array):
12084
if orig.dtype != new.dtype:
@@ -123,6 +87,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
12387
return False
12488
return bool(jnp.allclose(orig, new, equal_nan=True))
12589

90+
if HAS_SQLALCHEMY:
91+
import sqlalchemy # type: ignore # noqa: PGH003
12692
if HAS_SQLALCHEMY:
12793
try:
12894
insp = sqlalchemy.inspection.inspect(orig)
@@ -138,6 +104,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
138104

139105
except sqlalchemy.exc.NoInspectionAvailable:
140106
pass
107+
if HAS_SCIPY:
108+
import scipy # type: ignore # noqa: PGH003
141109
# scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
142110
if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
143111
if superset_obj:
@@ -151,6 +119,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
151119
return False
152120
return True
153121

122+
if HAS_NUMPY:
123+
import numpy as np # type: ignore # noqa: PGH003
154124
if HAS_NUMPY and isinstance(orig, np.ndarray):
155125
if orig.dtype != new.dtype:
156126
return False
@@ -180,6 +150,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
180150
return False
181151
return (orig != new).nnz == 0
182152

153+
if HAS_PANDAS:
154+
import pandas # type: ignore # noqa: ICN001, PGH003
183155
if HAS_PANDAS and isinstance(
184156
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
185157
):
@@ -209,6 +181,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
209181
except Exception: # noqa: S110
210182
pass
211183

184+
if HAS_TORCH:
185+
import torch # type: ignore # noqa: PGH003
212186
if HAS_TORCH and isinstance(orig, torch.Tensor):
213187
if orig.dtype != new.dtype:
214188
return False
@@ -219,7 +193,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
219193
if orig.device != new.device:
220194
return False
221195
return torch.allclose(orig, new, equal_nan=True)
222-
196+
if HAS_PYRSISTENT:
197+
import pyrsistent # type: ignore # noqa: PGH003
223198
if HAS_PYRSISTENT and isinstance(
224199
orig,
225200
(

0 commit comments

Comments
 (0)