diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py index 05056828b..a445576e0 100644 --- a/codeflash/code_utils/git_utils.py +++ b/codeflash/code_utils/git_utils.py @@ -9,7 +9,7 @@ from functools import cache from io import StringIO from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import git from rich.prompt import Confirm diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 1b8edcec7..e97bb45cd 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -53,6 +53,7 @@ class ProvideApiKeyParams: class OnPatchAppliedParams: patch_id: str + @dataclass class OptimizableFunctionsInCommitParams: commit_hash: str diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index a1e8c12eb..e6244b78a 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -7,6 +7,7 @@ import math import re import types +from collections import ChainMap, OrderedDict, deque from typing import Any import sentry_sdk @@ -70,7 +71,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 # distinct type objects are created at runtime, even if the class code is exactly the same, so we can only compare the names if type_obj.__name__ != new_type_obj.__name__ or type_obj.__qualname__ != new_type_obj.__qualname__: return False - if isinstance(orig, (list, tuple)): + if isinstance(orig, (list, tuple, deque, ChainMap)): if len(orig) != len(new): return False return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new)) @@ -93,6 +94,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 enum.Enum, type, range, + OrderedDict, ), ): return orig == new diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 06e692b39..4e404a512 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -4,6 +4,8 @@ import datetime import decimal import re +from collections import ChainMap, Counter, UserDict, UserList, UserString, defaultdict, deque, namedtuple, OrderedDict + import sys import uuid from enum import Enum, Flag, IntFlag, auto @@ -1394,3 +1396,110 @@ def raise_specific_exception(): module2 = ast.parse(code2) assert not comparator(module7, module2) + +def test_collections() -> None: + # Deque + a = deque([1, 2, 3]) + b = deque([1, 2, 3]) + c = deque([1, 2, 4]) + d = deque([1, 2]) + e = [1, 2, 3] + f = deque([1, 2, 3], maxlen=5) + assert comparator(a, b) + assert comparator(a, f) # same elements, different maxlen is ok + assert not comparator(a, c) + assert not comparator(a, d) + assert not comparator(a, e) + + g = deque([{"a": 1}, {"b": 2}]) + h = deque([{"a": 1}, {"b": 2}]) + i = deque([{"a": 1}, {"b": 3}]) + assert comparator(g, h) + assert not comparator(g, i) + + empty_deque1 = deque() + empty_deque2 = deque() + assert comparator(empty_deque1, empty_deque2) + assert not comparator(empty_deque1, a) + + # namedtuple + Point = namedtuple('Point', ['x', 'y']) + a = Point(x=1, y=2) + b = Point(x=1, y=2) + c = Point(x=1, y=3) + assert comparator(a, b) + assert not comparator(a, c) + + Point2 = namedtuple('Point2', ['x', 'y']) + d = Point2(x=1, y=2) + assert not comparator(a, d) + + e = (1, 2) + assert not comparator(a, e) + + # ChainMap + map1 = {'a': 1, 'b': 2} + map2 = {'c': 3, 'd': 4} + a = ChainMap(map1, map2) + b = ChainMap(map1, map2) + c = ChainMap(map2, map1) + d = {'a': 1, 'b': 2, 'c': 3, 'd': 4} + assert comparator(a, b) + assert not comparator(a, c) + assert not comparator(a, d) + + # Counter + a = Counter(['a', 'b', 'a', 'c', 'b', 'a']) + b = Counter({'a': 3, 'b': 2, 'c': 1}) + c = Counter({'a': 3, 'b': 2, 'c': 2}) + d = {'a': 3, 'b': 2, 'c': 1} + assert comparator(a, b) + assert not comparator(a, c) + assert not comparator(a, d) + + # OrderedDict + a = OrderedDict([('a', 1), ('b', 2)]) + b = OrderedDict([('a', 1), ('b', 2)]) + c = OrderedDict([('b', 2), ('a', 1)]) + d = {'a': 1, 'b': 2} + assert comparator(a, b) + assert not comparator(a, c) + assert not comparator(a, d) + + # defaultdict + a = defaultdict(int, {'a': 1, 'b': 2}) + b = defaultdict(int, {'a': 1, 'b': 2}) + c = defaultdict(list, {'a': 1, 'b': 2}) + d = {'a': 1, 'b': 2} + e = defaultdict(int, {'a': 1, 'b': 3}) + assert comparator(a, b) + assert comparator(a, c) + assert not comparator(a, d) + assert not comparator(a, e) + + # UserDict + a = UserDict({'a': 1, 'b': 2}) + b = UserDict({'a': 1, 'b': 2}) + c = UserDict({'a': 1, 'b': 3}) + d = {'a': 1, 'b': 2} + assert comparator(a, b) + assert not comparator(a, c) + assert not comparator(a, d) + + # UserList + a = UserList([1, 2, 3]) + b = UserList([1, 2, 3]) + c = UserList([1, 2, 4]) + d = [1, 2, 3] + assert comparator(a, b) + assert not comparator(a, c) + assert not comparator(a, d) + + # UserString + a = UserString("hello") + b = UserString("hello") + c = UserString("world") + d = "hello" + assert comparator(a, b) + assert not comparator(a, c) + assert not comparator(a, d) \ No newline at end of file