Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,20 +235,20 @@ def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str |
return edited_code, contextual_dunder_methods


def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionParent]]]:
"""Find all preexisting functions, classes or class methods in the source code"""
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionParent, ...]]]:
"""Find all preexisting functions, classes or class methods in the source code."""
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = set()
try:
module_node: ast.Module = ast.parse(source_code)
except SyntaxError:
logger.exception("find_preexisting_objects - Syntax error while parsing code")
return preexisting_objects
for node in module_node.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
preexisting_objects.append((node.name, []))
preexisting_objects.add((node.name, ()))
elif isinstance(node, ast.ClassDef):
preexisting_objects.append((node.name, []))
preexisting_objects.add((node.name, ()))
for cnode in node.body:
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
preexisting_objects.append((cnode.name, [FunctionParent(node.name, "ClassDef")]))
preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),)))
return preexisting_objects
14 changes: 7 additions & 7 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ class OptimFunctionCollector(cst.CSTVisitor):

def __init__(
self,
preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None,
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None,
function_names: set[tuple[str | None, str]] | None = None,
) -> None:
super().__init__()
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else []
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set()

self.function_names = function_names # set of (class_name, function_name)
self.modified_functions: dict[
Expand All @@ -60,7 +60,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
self.modified_init_functions[self.current_class] = node
elif (
self.preexisting_objects
and (node.name.value, []) not in self.preexisting_objects
and (node.name.value, ()) not in self.preexisting_objects
and self.current_class is None
):
self.new_functions.append(node)
Expand All @@ -71,7 +71,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool:
return False # If already in a class, do not recurse deeper
self.current_class = node.name.value

parents = [FunctionParent(name=node.name.value, type="ClassDef")]
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
for child_node in node.body.body:
if (
self.preexisting_objects
Expand Down Expand Up @@ -159,7 +159,7 @@ def replace_functions_in_file(
source_code: str,
original_function_names: list[str],
optimized_code: str,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
) -> str:
parsed_function_names = []
for original_function_name in original_function_names:
Expand Down Expand Up @@ -195,7 +195,7 @@ def replace_functions_and_add_imports(
function_names: list[str],
optimized_code: str,
module_abspath: Path,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
project_root_path: Path,
) -> str:
return add_needed_imports_from_module(
Expand All @@ -211,7 +211,7 @@ def replace_function_definitions_in_module(
function_names: list[str],
optimized_code: str,
module_abspath: Path,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
project_root_path: Path,
) -> bool:
source_code: str = module_abspath.read_text(encoding="utf8")
Expand Down
4 changes: 2 additions & 2 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def get_code_optimization_context(
if final_read_writable_tokens > optim_token_limit:
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")

# Setup preexisting objects for code replacer TODO: should remove duplicates
preexisting_objects = list(
# Setup preexisting objects for code replacer
preexisting_objects = set(
chain(
find_preexisting_objects(final_read_writable_code),
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
Expand Down
2 changes: 1 addition & 1 deletion codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class CodeOptimizationContext(BaseModel):
read_writable_code: str = Field(min_length=1)
read_only_context_code: str = ""
helper_functions: list[FunctionSource]
preexisting_objects: list[tuple[str, list[FunctionParent]]]
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]]

class CodeContextType(str, Enum):
READ_WRITABLE = "READ_WRITABLE"
Expand Down
32 changes: 16 additions & 16 deletions tests/test_code_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def totally_new_function(value):
"""

function_name: str = "NewClass.new_function"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
Expand Down Expand Up @@ -135,7 +135,7 @@ def other_function(st):
"""

function_name: str = "NewClass.new_function"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
Expand Down Expand Up @@ -196,7 +196,7 @@ def totally_new_function(value):
"""

function_names: list[str] = ["other_function"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
Expand Down Expand Up @@ -260,7 +260,7 @@ def totally_new_function(value):
"""

function_names: list[str] = ["yet_another_function", "other_function"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
Expand Down Expand Up @@ -313,7 +313,7 @@ def supersort(doink):
"""

function_names: list[str] = ["sorter_deps"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
Expand Down Expand Up @@ -388,7 +388,7 @@ def blab(st):

print("Not cool")
"""
preexisting_objects = find_preexisting_objects(original_code_main) + find_preexisting_objects(original_code_helper)
preexisting_objects = find_preexisting_objects(original_code_main) | find_preexisting_objects(original_code_helper)
new_main_code: str = replace_functions_and_add_imports(
source_code=original_code_main,
function_names=["other_function"],
Expand Down Expand Up @@ -591,7 +591,7 @@ def from_config(config: Optional[dict[str, Any]]):
)
"""
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)

new_code: str = replace_functions_and_add_imports(
source_code=original_code,
Expand Down Expand Up @@ -662,7 +662,7 @@ def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
return np.sum(a != b) / a.size
'''
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
Expand Down Expand Up @@ -715,7 +715,7 @@ def totally_new_function(value: Optional[str]):
print("Hello world")
"""
function_name: str = "NewClass.__init__"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
Expand Down Expand Up @@ -814,8 +814,8 @@ def real_bar(self) -> int:
'''

function_name: str = "Fu.foo"
parents = [FunctionParent("Fu", "ClassDef")]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)]
parents = (FunctionParent("Fu", "ClassDef"),)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = {("foo", parents), ("real_bar", parents)}
new_code: str = replace_functions_in_file(
source_code=original_code,
original_function_names=[function_name],
Expand Down Expand Up @@ -854,7 +854,7 @@ def real_bar(self) -> int:
pass
'''

preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
new_code: str = replace_functions_in_file(
source_code=original_code,
original_function_names=["Fu.real_bar"],
Expand Down Expand Up @@ -891,7 +891,7 @@ def __call__(self, value):
"""

function_names: list[str] = ["yet_another_function", "other_function"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
Expand Down Expand Up @@ -1278,7 +1278,7 @@ def cosine_similarity_top_k(

return ret_idxs, scores
'''
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)

helper_functions = [
FakeFunctionSource(
Expand Down Expand Up @@ -1579,7 +1579,7 @@ def nested_function(self):
"NewClass.new_function2",
"NestedClass.nested_function",
] # Nested classes should be ignored, even if provided as target
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
Expand Down Expand Up @@ -1615,7 +1615,7 @@ def new_function2(value):
"""

function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
Expand Down
Loading