Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import sentry_sdk

from codeflash.cli_cmds.console import logger
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
from codeflash.picklepatch.pickle_placeholder import \
PicklePlaceholderAccessError

try:
import numpy as np
Expand Down Expand Up @@ -64,6 +65,8 @@
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
try:
if orig is new:
return True
if type(orig) is not type(new):
type_obj = type(orig)
new_type_obj = type(new)
Expand All @@ -73,7 +76,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
if isinstance(orig, (list, tuple)):
if len(orig) != len(new):
return False
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
for elem1, elem2 in zip(orig, new):
if not comparator(elem1, elem2, superset_obj):
return False
return True

if isinstance(
orig,
Expand Down Expand Up @@ -139,7 +145,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
# scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
if superset_obj:
return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items())
for k, v in orig.items():
if k not in new or not comparator(v, new[k], superset_obj):
return False
return True
if len(orig) != len(new):
return False
for key in orig:
Expand All @@ -158,7 +167,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
return np.allclose(orig, new, equal_nan=True)
except Exception:
# fails at "ufunc 'isfinite' not supported for the input types"
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
for x, y in zip(orig, new):
if not comparator(x, y, superset_obj):
return False
return True

if HAS_NUMPY and isinstance(orig, (np.floating, np.complex64, np.complex128)):
return np.isclose(orig, new)
Expand All @@ -169,7 +181,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
if HAS_NUMPY and isinstance(orig, np.void):
if orig.dtype != new.dtype:
return False
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
for field in orig.dtype.fields:
if not comparator(orig[field], new[field], superset_obj):
return False
return True

if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
if orig.dtype != new.dtype:
Expand All @@ -193,7 +208,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
return False
if len(orig) != len(new):
return False
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
for elem1, elem2 in zip(orig, new):
if not comparator(elem1, elem2, superset_obj):
return False
return True

# This should be at the end of all numpy checking
try:
Expand Down Expand Up @@ -262,7 +280,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001

if superset_obj:
# allow new object to be a superset of the original object
return all(k in new_keys and comparator(v, new_keys[k], superset_obj) for k, v in orig_keys.items())
for k, v in orig_keys.items():
if k not in new_keys or not comparator(v, new_keys[k], superset_obj):
return False
return True

if isinstance(orig, ast.AST):
orig_keys = {k: v for k, v in orig.__dict__.items() if k != "parent"}
Expand Down
Loading