Skip to content

Commit a0afba2

Browse files
authored
Merge pull request #50 from codeflash-ai/remove-duplicate-preexisting-objects
Changed preexisting objects to be a set, which removes duplicates naturally…
2 parents a651fbf + 4aa194e commit a0afba2

File tree

5 files changed

+32
-32
lines changed

5 files changed

+32
-32
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,20 +235,20 @@ def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str |
235235
return edited_code, contextual_dunder_methods
236236

237237

238-
def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionParent]]]:
239-
"""Find all preexisting functions, classes or class methods in the source code"""
240-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
238+
def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionParent, ...]]]:
239+
"""Find all preexisting functions, classes or class methods in the source code."""
240+
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = set()
241241
try:
242242
module_node: ast.Module = ast.parse(source_code)
243243
except SyntaxError:
244244
logger.exception("find_preexisting_objects - Syntax error while parsing code")
245245
return preexisting_objects
246246
for node in module_node.body:
247247
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
248-
preexisting_objects.append((node.name, []))
248+
preexisting_objects.add((node.name, ()))
249249
elif isinstance(node, ast.ClassDef):
250-
preexisting_objects.append((node.name, []))
250+
preexisting_objects.add((node.name, ()))
251251
for cnode in node.body:
252252
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
253-
preexisting_objects.append((cnode.name, [FunctionParent(node.name, "ClassDef")]))
253+
preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),)))
254254
return preexisting_objects

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ class OptimFunctionCollector(cst.CSTVisitor):
3838

3939
def __init__(
4040
self,
41-
preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None,
41+
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None,
4242
function_names: set[tuple[str | None, str]] | None = None,
4343
) -> None:
4444
super().__init__()
45-
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else []
45+
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set()
4646

4747
self.function_names = function_names # set of (class_name, function_name)
4848
self.modified_functions: dict[
@@ -60,7 +60,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
6060
self.modified_init_functions[self.current_class] = node
6161
elif (
6262
self.preexisting_objects
63-
and (node.name.value, []) not in self.preexisting_objects
63+
and (node.name.value, ()) not in self.preexisting_objects
6464
and self.current_class is None
6565
):
6666
self.new_functions.append(node)
@@ -71,7 +71,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool:
7171
return False # If already in a class, do not recurse deeper
7272
self.current_class = node.name.value
7373

74-
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
74+
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
7575
for child_node in node.body.body:
7676
if (
7777
self.preexisting_objects
@@ -159,7 +159,7 @@ def replace_functions_in_file(
159159
source_code: str,
160160
original_function_names: list[str],
161161
optimized_code: str,
162-
preexisting_objects: list[tuple[str, list[FunctionParent]]],
162+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
163163
) -> str:
164164
parsed_function_names = []
165165
for original_function_name in original_function_names:
@@ -195,7 +195,7 @@ def replace_functions_and_add_imports(
195195
function_names: list[str],
196196
optimized_code: str,
197197
module_abspath: Path,
198-
preexisting_objects: list[tuple[str, list[FunctionParent]]],
198+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
199199
project_root_path: Path,
200200
) -> str:
201201
return add_needed_imports_from_module(
@@ -211,7 +211,7 @@ def replace_function_definitions_in_module(
211211
function_names: list[str],
212212
optimized_code: str,
213213
module_abspath: Path,
214-
preexisting_objects: list[tuple[str, list[FunctionParent]]],
214+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
215215
project_root_path: Path,
216216
) -> bool:
217217
source_code: str = module_abspath.read_text(encoding="utf8")

codeflash/context/code_context_extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def get_code_optimization_context(
6565
if final_read_writable_tokens > optim_token_limit:
6666
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
6767

68-
# Setup preexisting objects for code replacer TODO: should remove duplicates
69-
preexisting_objects = list(
68+
# Setup preexisting objects for code replacer
69+
preexisting_objects = set(
7070
chain(
7171
find_preexisting_objects(final_read_writable_code),
7272
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),

codeflash/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class CodeOptimizationContext(BaseModel):
9999
read_writable_code: str = Field(min_length=1)
100100
read_only_context_code: str = ""
101101
helper_functions: list[FunctionSource]
102-
preexisting_objects: list[tuple[str, list[FunctionParent]]]
102+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]]
103103

104104
class CodeContextType(str, Enum):
105105
READ_WRITABLE = "READ_WRITABLE"

tests/test_code_replacement.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def totally_new_function(value):
7474
"""
7575

7676
function_name: str = "NewClass.new_function"
77-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
77+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
7878
new_code: str = replace_functions_and_add_imports(
7979
source_code=original_code,
8080
function_names=[function_name],
@@ -135,7 +135,7 @@ def other_function(st):
135135
"""
136136

137137
function_name: str = "NewClass.new_function"
138-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
138+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
139139
new_code: str = replace_functions_and_add_imports(
140140
source_code=original_code,
141141
function_names=[function_name],
@@ -196,7 +196,7 @@ def totally_new_function(value):
196196
"""
197197

198198
function_names: list[str] = ["other_function"]
199-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
199+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
200200
new_code: str = replace_functions_and_add_imports(
201201
source_code=original_code,
202202
function_names=function_names,
@@ -260,7 +260,7 @@ def totally_new_function(value):
260260
"""
261261

262262
function_names: list[str] = ["yet_another_function", "other_function"]
263-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
263+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
264264
new_code: str = replace_functions_and_add_imports(
265265
source_code=original_code,
266266
function_names=function_names,
@@ -313,7 +313,7 @@ def supersort(doink):
313313
"""
314314

315315
function_names: list[str] = ["sorter_deps"]
316-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
316+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
317317
new_code: str = replace_functions_and_add_imports(
318318
source_code=original_code,
319319
function_names=function_names,
@@ -388,7 +388,7 @@ def blab(st):
388388
389389
print("Not cool")
390390
"""
391-
preexisting_objects = find_preexisting_objects(original_code_main) + find_preexisting_objects(original_code_helper)
391+
preexisting_objects = find_preexisting_objects(original_code_main) | find_preexisting_objects(original_code_helper)
392392
new_main_code: str = replace_functions_and_add_imports(
393393
source_code=original_code_main,
394394
function_names=["other_function"],
@@ -591,7 +591,7 @@ def from_config(config: Optional[dict[str, Any]]):
591591
)
592592
"""
593593
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
594-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
594+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
595595

596596
new_code: str = replace_functions_and_add_imports(
597597
source_code=original_code,
@@ -662,7 +662,7 @@ def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
662662
return np.sum(a != b) / a.size
663663
'''
664664
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
665-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
665+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
666666
new_code: str = replace_functions_and_add_imports(
667667
source_code=original_code,
668668
function_names=function_names,
@@ -715,7 +715,7 @@ def totally_new_function(value: Optional[str]):
715715
print("Hello world")
716716
"""
717717
function_name: str = "NewClass.__init__"
718-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
718+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
719719
new_code: str = replace_functions_and_add_imports(
720720
source_code=original_code,
721721
function_names=[function_name],
@@ -814,8 +814,8 @@ def real_bar(self) -> int:
814814
'''
815815

816816
function_name: str = "Fu.foo"
817-
parents = [FunctionParent("Fu", "ClassDef")]
818-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)]
817+
parents = (FunctionParent("Fu", "ClassDef"),)
818+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = {("foo", parents), ("real_bar", parents)}
819819
new_code: str = replace_functions_in_file(
820820
source_code=original_code,
821821
original_function_names=[function_name],
@@ -854,7 +854,7 @@ def real_bar(self) -> int:
854854
pass
855855
'''
856856

857-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
857+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
858858
new_code: str = replace_functions_in_file(
859859
source_code=original_code,
860860
original_function_names=["Fu.real_bar"],
@@ -891,7 +891,7 @@ def __call__(self, value):
891891
"""
892892

893893
function_names: list[str] = ["yet_another_function", "other_function"]
894-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
894+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
895895
new_code: str = replace_functions_and_add_imports(
896896
source_code=original_code,
897897
function_names=function_names,
@@ -1278,7 +1278,7 @@ def cosine_similarity_top_k(
12781278
12791279
return ret_idxs, scores
12801280
'''
1281-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
1281+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
12821282

12831283
helper_functions = [
12841284
FakeFunctionSource(
@@ -1579,7 +1579,7 @@ def nested_function(self):
15791579
"NewClass.new_function2",
15801580
"NestedClass.nested_function",
15811581
] # Nested classes should be ignored, even if provided as target
1582-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
1582+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
15831583
new_code: str = replace_functions_and_add_imports(
15841584
source_code=original_code,
15851585
function_names=function_names,
@@ -1615,7 +1615,7 @@ def new_function2(value):
16151615
"""
16161616

16171617
function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
1618-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
1618+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
16191619
new_code: str = replace_functions_and_add_imports(
16201620
source_code=original_code,
16211621
function_names=function_names,

0 commit comments

Comments
 (0)