Skip to content

Commit 274f98b

Browse files
committed
changed preexisting objects to be a set. removes duplicates naturally and makes it easier to search for matches when replacing code.
1 parent a651fbf commit 274f98b

File tree

4 files changed

+45
-27
lines changed

4 files changed

+45
-27
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str |
235235
return edited_code, contextual_dunder_methods
236236

237237

238-
def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionParent]]]:
238+
def find_preexisting_object_old(source_code: str) -> list[tuple[str, list[FunctionParent]]]:
239239
"""Find all preexisting functions, classes or class methods in the source code"""
240240
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
241241
try:
@@ -252,3 +252,21 @@ def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionP
252252
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
253253
preexisting_objects.append((cnode.name, [FunctionParent(node.name, "ClassDef")]))
254254
return preexisting_objects
255+
256+
def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionParent, ...]]]:
257+
"""Find all preexisting functions, classes or class methods in the source code"""
258+
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = set()
259+
try:
260+
module_node: ast.Module = ast.parse(source_code)
261+
except SyntaxError:
262+
logger.exception("find_preexisting_objects - Syntax error while parsing code")
263+
return preexisting_objects
264+
for node in module_node.body:
265+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
266+
preexisting_objects.add((node.name, ()))
267+
elif isinstance(node, ast.ClassDef):
268+
preexisting_objects.add((node.name, ()))
269+
for cnode in node.body:
270+
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
271+
preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),)))
272+
return preexisting_objects

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ class OptimFunctionCollector(cst.CSTVisitor):
3838

3939
def __init__(
4040
self,
41-
preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None,
41+
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None,
4242
function_names: set[tuple[str | None, str]] | None = None,
4343
) -> None:
4444
super().__init__()
45-
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else []
45+
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set()
4646

4747
self.function_names = function_names # set of (class_name, function_name)
4848
self.modified_functions: dict[
@@ -60,7 +60,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
6060
self.modified_init_functions[self.current_class] = node
6161
elif (
6262
self.preexisting_objects
63-
and (node.name.value, []) not in self.preexisting_objects
63+
and (node.name.value, ()) not in self.preexisting_objects
6464
and self.current_class is None
6565
):
6666
self.new_functions.append(node)
@@ -71,7 +71,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool:
7171
return False # If already in a class, do not recurse deeper
7272
self.current_class = node.name.value
7373

74-
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
74+
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
7575
for child_node in node.body.body:
7676
if (
7777
self.preexisting_objects
@@ -159,7 +159,7 @@ def replace_functions_in_file(
159159
source_code: str,
160160
original_function_names: list[str],
161161
optimized_code: str,
162-
preexisting_objects: list[tuple[str, list[FunctionParent]]],
162+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
163163
) -> str:
164164
parsed_function_names = []
165165
for original_function_name in original_function_names:
@@ -195,7 +195,7 @@ def replace_functions_and_add_imports(
195195
function_names: list[str],
196196
optimized_code: str,
197197
module_abspath: Path,
198-
preexisting_objects: list[tuple[str, list[FunctionParent]]],
198+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
199199
project_root_path: Path,
200200
) -> str:
201201
return add_needed_imports_from_module(
@@ -211,7 +211,7 @@ def replace_function_definitions_in_module(
211211
function_names: list[str],
212212
optimized_code: str,
213213
module_abspath: Path,
214-
preexisting_objects: list[tuple[str, list[FunctionParent]]],
214+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
215215
project_root_path: Path,
216216
) -> bool:
217217
source_code: str = module_abspath.read_text(encoding="utf8")

codeflash/context/code_context_extractor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ def get_code_optimization_context(
6565
if final_read_writable_tokens > optim_token_limit:
6666
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
6767

68-
# Setup preexisting objects for code replacer TODO: should remove duplicates
69-
preexisting_objects = list(
68+
# Setup preexisting objects for code replacer
69+
preexisting_objects = list(set(
7070
chain(
7171
find_preexisting_objects(final_read_writable_code),
7272
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
7373
)
74-
)
74+
))
7575
read_only_context_code = read_only_code_markdown.markdown
7676

7777
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code))

tests/test_code_replacement.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def totally_new_function(value):
7474
"""
7575

7676
function_name: str = "NewClass.new_function"
77-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
77+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
7878
new_code: str = replace_functions_and_add_imports(
7979
source_code=original_code,
8080
function_names=[function_name],
@@ -135,7 +135,7 @@ def other_function(st):
135135
"""
136136

137137
function_name: str = "NewClass.new_function"
138-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
138+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
139139
new_code: str = replace_functions_and_add_imports(
140140
source_code=original_code,
141141
function_names=[function_name],
@@ -196,7 +196,7 @@ def totally_new_function(value):
196196
"""
197197

198198
function_names: list[str] = ["other_function"]
199-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
199+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
200200
new_code: str = replace_functions_and_add_imports(
201201
source_code=original_code,
202202
function_names=function_names,
@@ -260,7 +260,7 @@ def totally_new_function(value):
260260
"""
261261

262262
function_names: list[str] = ["yet_another_function", "other_function"]
263-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
263+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
264264
new_code: str = replace_functions_and_add_imports(
265265
source_code=original_code,
266266
function_names=function_names,
@@ -313,7 +313,7 @@ def supersort(doink):
313313
"""
314314

315315
function_names: list[str] = ["sorter_deps"]
316-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
316+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
317317
new_code: str = replace_functions_and_add_imports(
318318
source_code=original_code,
319319
function_names=function_names,
@@ -388,7 +388,7 @@ def blab(st):
388388
389389
print("Not cool")
390390
"""
391-
preexisting_objects = find_preexisting_objects(original_code_main) + find_preexisting_objects(original_code_helper)
391+
preexisting_objects = find_preexisting_objects(original_code_main) | find_preexisting_objects(original_code_helper)
392392
new_main_code: str = replace_functions_and_add_imports(
393393
source_code=original_code_main,
394394
function_names=["other_function"],
@@ -591,7 +591,7 @@ def from_config(config: Optional[dict[str, Any]]):
591591
)
592592
"""
593593
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
594-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
594+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
595595

596596
new_code: str = replace_functions_and_add_imports(
597597
source_code=original_code,
@@ -662,7 +662,7 @@ def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
662662
return np.sum(a != b) / a.size
663663
'''
664664
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
665-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
665+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
666666
new_code: str = replace_functions_and_add_imports(
667667
source_code=original_code,
668668
function_names=function_names,
@@ -715,7 +715,7 @@ def totally_new_function(value: Optional[str]):
715715
print("Hello world")
716716
"""
717717
function_name: str = "NewClass.__init__"
718-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
718+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
719719
new_code: str = replace_functions_and_add_imports(
720720
source_code=original_code,
721721
function_names=[function_name],
@@ -814,8 +814,8 @@ def real_bar(self) -> int:
814814
'''
815815

816816
function_name: str = "Fu.foo"
817-
parents = [FunctionParent("Fu", "ClassDef")]
818-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)]
817+
parents = (FunctionParent("Fu", "ClassDef"),)
818+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = {("foo", parents), ("real_bar", parents)}
819819
new_code: str = replace_functions_in_file(
820820
source_code=original_code,
821821
original_function_names=[function_name],
@@ -854,7 +854,7 @@ def real_bar(self) -> int:
854854
pass
855855
'''
856856

857-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
857+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
858858
new_code: str = replace_functions_in_file(
859859
source_code=original_code,
860860
original_function_names=["Fu.real_bar"],
@@ -891,7 +891,7 @@ def __call__(self, value):
891891
"""
892892

893893
function_names: list[str] = ["yet_another_function", "other_function"]
894-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
894+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
895895
new_code: str = replace_functions_and_add_imports(
896896
source_code=original_code,
897897
function_names=function_names,
@@ -1278,7 +1278,7 @@ def cosine_similarity_top_k(
12781278
12791279
return ret_idxs, scores
12801280
'''
1281-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
1281+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
12821282

12831283
helper_functions = [
12841284
FakeFunctionSource(
@@ -1579,7 +1579,7 @@ def nested_function(self):
15791579
"NewClass.new_function2",
15801580
"NestedClass.nested_function",
15811581
] # Nested classes should be ignored, even if provided as target
1582-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
1582+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
15831583
new_code: str = replace_functions_and_add_imports(
15841584
source_code=original_code,
15851585
function_names=function_names,
@@ -1615,7 +1615,7 @@ def new_function2(value):
16151615
"""
16161616

16171617
function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
1618-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
1618+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
16191619
new_code: str = replace_functions_and_add_imports(
16201620
source_code=original_code,
16211621
function_names=function_names,

0 commit comments

Comments
 (0)