diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index a1e8c12eb..1653a6293 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -12,7 +12,8 @@ import sentry_sdk from codeflash.cli_cmds.console import logger -from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError +from codeflash.picklepatch.pickle_placeholder import \ + PicklePlaceholderAccessError try: import numpy as np @@ -64,6 +65,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911 """Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent.""" try: + if orig is new: + return True if type(orig) is not type(new): type_obj = type(orig) new_type_obj = type(new) @@ -73,7 +76,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if isinstance(orig, (list, tuple)): if len(orig) != len(new): return False - return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new)) + for elem1, elem2 in zip(orig, new): + if not comparator(elem1, elem2, superset_obj): + return False + return True if isinstance( orig, @@ -139,7 +145,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)): if superset_obj: - return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items()) + for k, v in orig.items(): + if k not in new or not comparator(v, new[k], superset_obj): + return False + return True if len(orig) != len(new): return False for key in orig: @@ -158,7 +167,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return np.allclose(orig, new, equal_nan=True) except Exception: # fails at "ufunc 'isfinite' not supported for the input types" - return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)]) + for x, y in zip(orig, new): + if not comparator(x, y, superset_obj): + return False + return True if HAS_NUMPY and isinstance(orig, (np.floating, np.complex64, np.complex128)): return np.isclose(orig, new) @@ -169,7 +181,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if HAS_NUMPY and isinstance(orig, np.void): if orig.dtype != new.dtype: return False - return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields) + for field in orig.dtype.fields: + if not comparator(orig[field], new[field], superset_obj): + return False + return True if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix): if orig.dtype != new.dtype: @@ -193,7 +208,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False if len(orig) != len(new): return False - return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new)) + for elem1, elem2 in zip(orig, new): + if not comparator(elem1, elem2, superset_obj): + return False + return True # This should be at the end of all numpy checking try: @@ -262,7 +280,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if superset_obj: # allow new object to be a superset of the original object - return all(k in new_keys and comparator(v, new_keys[k], superset_obj) for k, v in orig_keys.items()) + for k, v in orig_keys.items(): + if k not in new_keys or not comparator(v, new_keys[k], superset_obj): + return False + return True if isinstance(orig, ast.AST): orig_keys = {k: v for k, v in orig.__dict__.items() if k != "parent"}