Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from jedi.api.classes import Name
from libcst import CSTNode

from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.console import code_print, logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
Expand Down Expand Up @@ -73,6 +73,7 @@ def get_code_optimization_context(

# Handle token limits
tokenizer = tiktoken.encoding_for_model("gpt-4o")
code_print(final_read_writable_code)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to print this to the output?

final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
if final_read_writable_tokens > optim_token_limit:
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
Expand Down Expand Up @@ -356,6 +357,7 @@ def get_function_to_optimize_as_function_source(
name.type == "function"
and name.full_name
and name.name == function_to_optimize.function_name
and name.full_name.startswith(name.module_name)
and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name
):
function_source = FunctionSource(
Expand Down Expand Up @@ -410,6 +412,7 @@ def get_function_sources_from_jedi(
and definition.full_name
and definition.type == "function"
and not belongs_to_function_qualified(definition, qualified_function_name)
and definition.full_name.startswith(definition.module_name)
# Avoid nested functions or classes. Only class.function is allowed
and len((qualified_name := get_qualified_name(definition.module_name, definition.full_name)).split(".")) <= 2
):
Expand Down
2 changes: 1 addition & 1 deletion codeflash/optimization/function_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def belongs_to_class(name: Name, class_name: str) -> bool:
def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> bool:
"""Check if the given jedi Name is a direct child of the specified function, matched by qualified function name."""
try:
if get_qualified_name(name.module_name, name.full_name) == qualified_function_name:
if name.full_name.startswith(name.module_name) and get_qualified_name(name.module_name, name.full_name) == qualified_function_name:
# Handles function definition and recursive function calls
return False
if name := name.parent():
Expand Down
Loading