1212import sentry_sdk
1313
1414from codeflash .cli_cmds .console import logger
15- from codeflash .picklepatch .pickle_placeholder import PicklePlaceholderAccessError
15+ from codeflash .picklepatch .pickle_placeholder import \
16+ PicklePlaceholderAccessError
1617
1718try :
1819 import numpy as np
6465def comparator (orig : Any , new : Any , superset_obj = False ) -> bool : # noqa: ANN001, ANN401, FBT002, PLR0911
6566 """Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
6667 try :
68+ if orig is new :
69+ return True
6770 if type (orig ) is not type (new ):
6871 type_obj = type (orig )
6972 new_type_obj = type (new )
@@ -73,7 +76,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
7376 if isinstance (orig , (list , tuple )):
7477 if len (orig ) != len (new ):
7578 return False
76- return all (comparator (elem1 , elem2 , superset_obj ) for elem1 , elem2 in zip (orig , new ))
79+ for elem1 , elem2 in zip (orig , new ):
80+ if not comparator (elem1 , elem2 , superset_obj ):
81+ return False
82+ return True
7783
7884 if isinstance (
7985 orig ,
@@ -139,7 +145,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
139145 # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
140146 if isinstance (orig , dict ) and not (HAS_SCIPY and isinstance (orig , scipy .sparse .spmatrix )):
141147 if superset_obj :
142- return all (k in new and comparator (v , new [k ], superset_obj ) for k , v in orig .items ())
148+ for k , v in orig .items ():
149+ if k not in new or not comparator (v , new [k ], superset_obj ):
150+ return False
151+ return True
143152 if len (orig ) != len (new ):
144153 return False
145154 for key in orig :
@@ -158,7 +167,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
158167 return np .allclose (orig , new , equal_nan = True )
159168 except Exception :
160169 # fails at "ufunc 'isfinite' not supported for the input types"
161- return np .all ([comparator (x , y , superset_obj ) for x , y in zip (orig , new )])
170+ for x , y in zip (orig , new ):
171+ if not comparator (x , y , superset_obj ):
172+ return False
173+ return True
162174
163175 if HAS_NUMPY and isinstance (orig , (np .floating , np .complex64 , np .complex128 )):
164176 return np .isclose (orig , new )
@@ -169,7 +181,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
169181 if HAS_NUMPY and isinstance (orig , np .void ):
170182 if orig .dtype != new .dtype :
171183 return False
172- return all (comparator (orig [field ], new [field ], superset_obj ) for field in orig .dtype .fields )
184+ for field in orig .dtype .fields :
185+ if not comparator (orig [field ], new [field ], superset_obj ):
186+ return False
187+ return True
173188
174189 if HAS_SCIPY and isinstance (orig , scipy .sparse .spmatrix ):
175190 if orig .dtype != new .dtype :
@@ -193,7 +208,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
193208 return False
194209 if len (orig ) != len (new ):
195210 return False
196- return all (comparator (elem1 , elem2 , superset_obj ) for elem1 , elem2 in zip (orig , new ))
211+ for elem1 , elem2 in zip (orig , new ):
212+ if not comparator (elem1 , elem2 , superset_obj ):
213+ return False
214+ return True
197215
198216 # This should be at the end of all numpy checking
199217 try :
@@ -262,7 +280,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
262280
263281 if superset_obj :
264282 # allow new object to be a superset of the original object
265- return all (k in new_keys and comparator (v , new_keys [k ], superset_obj ) for k , v in orig_keys .items ())
283+ for k , v in orig_keys .items ():
284+ if k not in new_keys or not comparator (v , new_keys [k ], superset_obj ):
285+ return False
286+ return True
266287
267288 if isinstance (orig , ast .AST ):
268289 orig_keys = {k : v for k , v in orig .__dict__ .items () if k != "parent" }
0 commit comments