Skip to content

Commit f884976

Browse files
committed
don't exclude comparator
1 parent f857569 commit f884976

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

codeflash/verification/comparator.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#ruff: noqa: PGH003
12
import array
23
import ast
34
import datetime
@@ -20,40 +21,40 @@
2021
except ImportError:
2122
HAS_NUMPY = False
2223
try:
23-
import sqlalchemy
24+
import sqlalchemy # type: ignore
2425

2526
HAS_SQLALCHEMY = True
2627
except ImportError:
2728
HAS_SQLALCHEMY = False
2829
try:
29-
import scipy
30+
import scipy # type: ignore
3031

3132
HAS_SCIPY = True
3233
except ImportError:
3334
HAS_SCIPY = False
3435

3536
try:
36-
import pandas
37+
import pandas # type: ignore # noqa: ICN001
3738

3839
HAS_PANDAS = True
3940
except ImportError:
4041
HAS_PANDAS = False
4142

4243
try:
43-
import pyrsistent
44+
import pyrsistent # type: ignore
4445

4546
HAS_PYRSISTENT = True
4647
except ImportError:
4748
HAS_PYRSISTENT = False
4849
try:
49-
import torch
50+
import torch # type: ignore
5051

5152
HAS_TORCH = True
5253
except ImportError:
5354
HAS_TORCH = False
5455

5556

56-
def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
57+
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
5758
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
5859
try:
5960
if type(orig) is not type(new):
@@ -108,22 +109,22 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
108109
if HAS_SQLALCHEMY:
109110
try:
110111
insp = sqlalchemy.inspection.inspect(orig)
111-
insp = sqlalchemy.inspection.inspect(new)
112+
insp = sqlalchemy.inspection.inspect(new) # noqa: F841
112113
orig_keys = orig.__dict__
113114
new_keys = new.__dict__
114115
for key in list(orig_keys.keys()):
115116
if key.startswith("_"):
116117
continue
117118
if key not in new_keys or not comparator(orig_keys[key], new_keys[key], superset_obj):
118119
return False
119-
return True
120+
return True # noqa: TRY300
120121

121122
except sqlalchemy.exc.NoInspectionAvailable:
122123
pass
123124
# scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
124125
if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
125126
if superset_obj:
126-
return all(k in new.keys() and comparator(v, new[k], superset_obj) for k, v in orig.items())
127+
return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items())
127128
if len(orig) != len(new):
128129
return False
129130
for key in orig:
@@ -183,12 +184,12 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
183184
try:
184185
if HAS_NUMPY and np.isnan(orig):
185186
return np.isnan(new)
186-
except Exception:
187+
except Exception: # noqa: S110
187188
pass
188189
try:
189190
if HAS_NUMPY and np.isinf(orig):
190191
return np.isinf(new)
191-
except Exception:
192+
except Exception: # noqa: S110
192193
pass
193194

194195
if HAS_TORCH and isinstance(orig, torch.Tensor):
@@ -228,14 +229,14 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
228229
try:
229230
if hasattr(orig, "__eq__") and str(type(orig.__eq__)) == "<class 'method'>":
230231
return orig == new
231-
except Exception:
232+
except Exception: # noqa: S110
232233
pass
233234

234235
# For class objects
235236
if hasattr(orig, "__dict__") and hasattr(new, "__dict__"):
236237
orig_keys = orig.__dict__
237238
new_keys = new.__dict__
238-
if type(orig_keys) == types.MappingProxyType and type(new_keys) == types.MappingProxyType:
239+
if type(orig_keys) == types.MappingProxyType and type(new_keys) == types.MappingProxyType: # noqa: E721
239240
# meta class objects
240241
if orig != new:
241242
return False
@@ -259,7 +260,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
259260
return True
260261
# TODO : Add other types here
261262
logger.warning(f"Unknown comparator input type: {type(orig)}")
262-
return False
263+
return False # noqa: TRY300
263264
except RecursionError as e:
264265
logger.error(f"RecursionError while comparing objects: {e}")
265266
sentry_sdk.capture_exception(e)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ warn_required_dynamic_aliases = true
154154
line-length = 120
155155
fix = true
156156
show-fixes = true
157-
exclude = ["code_to_optimize/", "pie_test_set/", "tests/", "codeflash/verification/comparator.py"]
157+
exclude = ["code_to_optimize/", "pie_test_set/", "tests/"]
158158

159159
[tool.ruff.lint]
160160
select = ["ALL"]

0 commit comments

Comments
 (0)