1- # ruff: noqa: PGH003
21import array
32import ast
43import datetime
87import re
98import types
109from collections import ChainMap , OrderedDict , deque
10+ from importlib .util import find_spec
1111from typing import Any
1212
1313import sentry_sdk
1414
1515from codeflash .cli_cmds .console import logger
1616from codeflash .picklepatch .pickle_placeholder import PicklePlaceholderAccessError
1717
18- try :
19- import numpy as np
20-
21- HAS_NUMPY = True
22- except ImportError :
23- HAS_NUMPY = False
24- try :
25- import sqlalchemy # type: ignore
26-
27- HAS_SQLALCHEMY = True
28- except ImportError :
29- HAS_SQLALCHEMY = False
30- try :
31- import scipy # type: ignore
32-
33- HAS_SCIPY = True
34- except ImportError :
35- HAS_SCIPY = False
36-
37- try :
38- import pandas # type: ignore # noqa: ICN001
39-
40- HAS_PANDAS = True
41- except ImportError :
42- HAS_PANDAS = False
43-
44- try :
45- import pyrsistent # type: ignore
46-
47- HAS_PYRSISTENT = True
48- except ImportError :
49- HAS_PYRSISTENT = False
50- try :
51- import torch # type: ignore
52-
53- HAS_TORCH = True
54- except ImportError :
55- HAS_TORCH = False
56- try :
57- import jax # type: ignore
58- import jax .numpy as jnp # type: ignore
59-
60- HAS_JAX = True
61- except ImportError :
62- HAS_JAX = False
18+ HAS_NUMPY = find_spec ("numpy" ) is not None
19+ HAS_SQLALCHEMY = find_spec ("sqlalchemy" ) is not None
20+ HAS_SCIPY = find_spec ("scipy" ) is not None
21+ HAS_PANDAS = find_spec ("pandas" ) is not None
22+ HAS_PYRSISTENT = find_spec ("pyrsistent" ) is not None
23+ HAS_TORCH = find_spec ("torch" ) is not None
24+ HAS_JAX = find_spec ("jax" ) is not None
6325
6426
6527def comparator (orig : Any , new : Any , superset_obj = False ) -> bool : # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -114,7 +76,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
11476 orig_dict = {k : v for k , v in orig .__dict__ .items () if not k .startswith ("_" )}
11577 new_dict = {k : v for k , v in new .__dict__ .items () if not k .startswith ("_" )}
11678 return comparator (orig_dict , new_dict , superset_obj )
117-
79+ if HAS_JAX :
80+ import jax # type: ignore # noqa: PGH003
81+ import jax .numpy as jnp # type: ignore # noqa: PGH003
11882 # Handle JAX arrays first to avoid boolean context errors in other conditions
11983 if HAS_JAX and isinstance (orig , jax .Array ):
12084 if orig .dtype != new .dtype :
@@ -123,6 +87,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
12387 return False
12488 return bool (jnp .allclose (orig , new , equal_nan = True ))
12589
90+ if HAS_SQLALCHEMY :
91+ import sqlalchemy # type: ignore # noqa: PGH003
12692 if HAS_SQLALCHEMY :
12793 try :
12894 insp = sqlalchemy .inspection .inspect (orig )
@@ -138,6 +104,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
138104
139105 except sqlalchemy .exc .NoInspectionAvailable :
140106 pass
107+ if HAS_SCIPY :
108+ import scipy # type: ignore # noqa: PGH003
141109 # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
142110 if isinstance (orig , dict ) and not (HAS_SCIPY and isinstance (orig , scipy .sparse .spmatrix )):
143111 if superset_obj :
@@ -151,6 +119,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
151119 return False
152120 return True
153121
122+ if HAS_NUMPY :
123+ import numpy as np # type: ignore # noqa: PGH003
154124 if HAS_NUMPY and isinstance (orig , np .ndarray ):
155125 if orig .dtype != new .dtype :
156126 return False
@@ -180,6 +150,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
180150 return False
181151 return (orig != new ).nnz == 0
182152
153+ if HAS_PANDAS :
154+ import pandas # type: ignore # noqa: ICN001, PGH003
183155 if HAS_PANDAS and isinstance (
184156 orig , (pandas .DataFrame , pandas .Series , pandas .Index , pandas .Categorical , pandas .arrays .SparseArray )
185157 ):
@@ -209,6 +181,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
209181 except Exception : # noqa: S110
210182 pass
211183
184+ if HAS_TORCH :
185+ import torch # type: ignore # noqa: PGH003
212186 if HAS_TORCH and isinstance (orig , torch .Tensor ):
213187 if orig .dtype != new .dtype :
214188 return False
@@ -219,7 +193,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
219193 if orig .device != new .device :
220194 return False
221195 return torch .allclose (orig , new , equal_nan = True )
222-
196+ if HAS_PYRSISTENT :
197+ import pyrsistent # type: ignore # noqa: PGH003
223198 if HAS_PYRSISTENT and isinstance (
224199 orig ,
225200 (
0 commit comments