Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion codeflash/api/cfapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Any, Optional

import requests
import sentry_sdk
from pydantic.json import pydantic_encoder

from codeflash.cli_cmds.console import console, logger
Expand Down Expand Up @@ -194,7 +195,8 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
req.raise_for_status()
content: dict[str, list[str]] = req.json()
except Exception as e:
logger.error(f"Error getting blocklisted functions: {e}", exc_info=True)
logger.error(f"Error getting blocklisted functions: {e}")
sentry_sdk.capture_exception(e)
return {}

return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}
24 changes: 16 additions & 8 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,22 @@ def process_pyproject_config(args: Namespace) -> Namespace:
assert args.tests_root is not None, "--tests-root must be specified"
assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"

assert not (env_utils.get_pr_number() is not None and not env_utils.ensure_codeflash_api_key()), (
"Codeflash API key not found. When running in a Github Actions Context, provide the "
"'CODEFLASH_API_KEY' environment variable as a secret.\n"
"You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
"Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
f"Here's a direct link: {get_github_secrets_page_url()}\n"
"Exiting..."
)
if env_utils.get_pr_number() is not None:
assert env_utils.ensure_codeflash_api_key(), (
"Codeflash API key not found. When running in a Github Actions Context, provide the "
"'CODEFLASH_API_KEY' environment variable as a secret.\n"
"You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
"Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
f"Here's a direct link: {get_github_secrets_page_url()}\n"
"Exiting..."
)

repo = git.Repo(search_parent_directories=True)

owner, repo_name = get_repo_owner_and_name(repo)

require_github_app_or_exit(owner, repo_name)

if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
normalized_ignore_paths = []
for path in args.ignore_paths:
Expand Down
2 changes: 1 addition & 1 deletion codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def check_for_toml_or_setup_file() -> str | None:
return cast(str, project_name)


def install_github_actions(override_formatter_check: bool=False) -> None:
def install_github_actions(override_formatter_check: bool = False) -> None:
try:
config, config_file_path = parse_config_file(override_formatter_check=override_formatter_check)

Expand Down
3 changes: 2 additions & 1 deletion codeflash/cli_cmds/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
)

logger = logging.getLogger("rich")
logging.getLogger('parso').setLevel(logging.WARNING)
logging.getLogger("parso").setLevel(logging.WARNING)


def paneled_text(
text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None
Expand Down
6 changes: 3 additions & 3 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def replace_functions_in_file(
source_code: str,
original_function_names: list[str],
optimized_code: str,
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
) -> str:
parsed_function_names = []
for original_function_name in original_function_names:
Expand Down Expand Up @@ -195,7 +195,7 @@ def replace_functions_and_add_imports(
function_names: list[str],
optimized_code: str,
module_abspath: Path,
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
project_root_path: Path,
) -> str:
return add_needed_imports_from_module(
Expand All @@ -211,7 +211,7 @@ def replace_function_definitions_in_module(
function_names: list[str],
optimized_code: str,
module_abspath: Path,
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
project_root_path: Path,
) -> bool:
source_code: str = module_abspath.read_text(encoding="utf8")
Expand Down
4 changes: 3 additions & 1 deletion codeflash/code_utils/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
raise ValueError(msg)


def parse_config_file(config_file_path: Path | None = None, override_formatter_check: bool=False) -> tuple[dict[str, Any], Path]:
def parse_config_file(
config_file_path: Path | None = None, override_formatter_check: bool = False
) -> tuple[dict[str, Any], Path]:
config_file_path = find_pyproject_toml(config_file_path)
try:
with config_file_path.open("rb") as f:
Expand Down
115 changes: 74 additions & 41 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,43 @@


def get_code_optimization_context(
function_to_optimize: FunctionToOptimize, project_root_path: Path, optim_token_limit: int = 8000, testgen_token_limit: int = 8000
function_to_optimize: FunctionToOptimize,
project_root_path: Path,
optim_token_limit: int = 8000,
testgen_token_limit: int = 8000,
) -> CodeOptimizationContext:
# Get FunctionSource representation of helpers of FTO
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi({function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path)
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(
{function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path
)

# Add function to optimize into helpers of FTO dict, as they'll be processed together
fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path)
helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source)

# Format data to search for helpers of helpers using get_function_sources_from_jedi
helpers_of_fto_qualified_names_dict = {
file_path: {source.qualified_name for source in sources}
for file_path, sources in helpers_of_fto_dict.items()
file_path: {source.qualified_name for source in sources} for file_path, sources in helpers_of_fto_dict.items()
}

# __init__ functions are automatically considered as helpers of FTO, so we add them to the dict (regardless of whether they exist)
# This helps us to search for helpers of __init__ functions of classes that contain helpers of FTO
for qualified_names in helpers_of_fto_qualified_names_dict.values():
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if '.' in qn})
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." in qn})

# Get FunctionSource representation of helpers of helpers of FTO
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(helpers_of_fto_qualified_names_dict, project_root_path)
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(
helpers_of_fto_qualified_names_dict, project_root_path
)

# Extract code context for optimization
final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto_dict,{}, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE).code
final_read_writable_code = extract_code_string_context_from_files(
helpers_of_fto_dict,
{},
project_root_path,
remove_docstrings=False,
code_context_type=CodeContextType.READ_WRITABLE,
).code
read_only_code_markdown = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
Expand Down Expand Up @@ -80,10 +92,7 @@ def get_code_optimization_context(
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
# Extract read only code without docstrings
read_only_code_no_docstring_markdown = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True
)
read_only_context_code = read_only_code_no_docstring_markdown.markdown
read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_context_code))
Expand Down Expand Up @@ -116,13 +125,14 @@ def get_code_optimization_context(
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")

return CodeOptimizationContext(
testgen_context_code = testgen_context_code,
testgen_context_code=testgen_context_code,
read_writable_code=final_read_writable_code,
read_only_context_code=read_only_context_code,
helper_functions=helpers_of_fto_list,
preexisting_objects=preexisting_objects,
)


def extract_code_string_context_from_files(
helpers_of_fto: dict[Path, set[FunctionSource]],
helpers_of_helpers: dict[Path, set[FunctionSource]],
Expand Down Expand Up @@ -169,9 +179,15 @@ def extract_code_string_context_from_files(
continue
try:
qualified_function_names = {func.qualified_name for func in function_sources}
helpers_of_helpers_qualified_names = {func.qualified_name for func in helpers_of_helpers.get(file_path, set())}
helpers_of_helpers_qualified_names = {
func.qualified_name for func in helpers_of_helpers.get(file_path, set())
}
code_context = parse_code_and_prune_cst(
original_code, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, remove_docstrings
original_code,
code_context_type,
qualified_function_names,
helpers_of_helpers_qualified_names,
remove_docstrings,
)

except ValueError as e:
Expand All @@ -180,12 +196,12 @@ def extract_code_string_context_from_files(
if code_context.strip():
final_code_string_context += f"\n{code_context}"
final_code_string_context = add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=final_code_string_context,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions= list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set()))
src_module_code=original_code,
dst_module_code=final_code_string_context,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions=list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())),
)
if code_context_type == CodeContextType.READ_WRITABLE:
return CodeString(code=final_code_string_context)
Expand All @@ -199,7 +215,7 @@ def extract_code_string_context_from_files(
try:
qualified_helper_function_names = {func.qualified_name for func in helper_function_sources}
code_context = parse_code_and_prune_cst(
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
Expand All @@ -208,15 +224,16 @@ def extract_code_string_context_from_files(
if code_context.strip():
final_code_string_context += f"\n{code_context}"
final_code_string_context = add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=final_code_string_context,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
src_module_code=original_code,
dst_module_code=final_code_string_context,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
)
return CodeString(code=final_code_string_context)


def extract_code_markdown_context_from_files(
helpers_of_fto: dict[Path, set[FunctionSource]],
helpers_of_helpers: dict[Path, set[FunctionSource]],
Expand Down Expand Up @@ -263,9 +280,15 @@ def extract_code_markdown_context_from_files(
continue
try:
qualified_function_names = {func.qualified_name for func in function_sources}
helpers_of_helpers_qualified_names = {func.qualified_name for func in helpers_of_helpers.get(file_path, set())}
helpers_of_helpers_qualified_names = {
func.qualified_name for func in helpers_of_helpers.get(file_path, set())
}
code_context = parse_code_and_prune_cst(
original_code, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, remove_docstrings
original_code,
code_context_type,
qualified_function_names,
helpers_of_helpers_qualified_names,
remove_docstrings,
)

except ValueError as e:
Expand All @@ -280,7 +303,8 @@ def extract_code_markdown_context_from_files(
dst_path=file_path,
project_root=project_root_path,
helper_functions=list(
helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set()))
helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())
),
),
file_path=file_path.relative_to(project_root_path),
)
Expand All @@ -295,7 +319,7 @@ def extract_code_markdown_context_from_files(
try:
qualified_helper_function_names = {func.qualified_name for func in helper_function_sources}
code_context = parse_code_and_prune_cst(
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings,
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
Expand All @@ -317,8 +341,9 @@ def extract_code_markdown_context_from_files(
return code_context_markdown


def get_function_to_optimize_as_function_source(function_to_optimize: FunctionToOptimize,
project_root_path: Path) -> FunctionSource:
def get_function_to_optimize_as_function_source(
function_to_optimize: FunctionToOptimize, project_root_path: Path
) -> FunctionSource:
# Use jedi to find function to optimize
script = jedi.Script(path=function_to_optimize.file_path, project=jedi.Project(path=project_root_path))

Expand All @@ -327,11 +352,12 @@ def get_function_to_optimize_as_function_source(function_to_optimize: FunctionTo

# Find the name that matches our function
for name in names:
if (name.type == "function" and
name.full_name and
name.name == function_to_optimize.function_name and
get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name):

if (
name.type == "function"
and name.full_name
and name.name == function_to_optimize.function_name
and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name
):
function_source = FunctionSource(
file_path=function_to_optimize.file_path,
qualified_name=function_to_optimize.qualified_name,
Expand All @@ -343,7 +369,8 @@ def get_function_to_optimize_as_function_source(function_to_optimize: FunctionTo
return function_source

raise ValueError(
f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}")
f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
)


def get_function_sources_from_jedi(
Expand Down Expand Up @@ -417,8 +444,13 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode
return indented_block.with_changes(body=indented_block.body[1:])
return indented_block


def parse_code_and_prune_cst(
code: str, code_context_type: CodeContextType, target_functions: set[str], helpers_of_helper_functions: set[str] = set(), remove_docstrings: bool = False
code: str,
code_context_type: CodeContextType,
target_functions: set[str],
helpers_of_helper_functions: set[str] = set(),
remove_docstrings: bool = False,
) -> str:
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
module = cst.parse_module(code)
Expand All @@ -441,6 +473,7 @@ def parse_code_and_prune_cst(
return str(filtered_node.code)
return ""


def prune_cst_for_read_writable_code(
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
Expand Down Expand Up @@ -520,6 +553,7 @@ def prune_cst_for_read_writable_code(

return (node.with_changes(**updates) if updates else node), True


def prune_cst_for_read_only_code(
node: cst.CSTNode,
target_functions: set[str],
Expand Down Expand Up @@ -624,7 +658,6 @@ def prune_cst_for_read_only_code(
return None, False



def prune_cst_for_testgen_code(
node: cst.CSTNode,
target_functions: set[str],
Expand Down
3 changes: 2 additions & 1 deletion codeflash/discovery/pytest_new_process_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def pytest_collection_modifyitems(config, items):
if "benchmark" in item.fixturenames:
item.add_marker(skip_benchmark)


def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
test_results = []
for test in pytest_tests:
Expand All @@ -39,7 +40,7 @@ def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, s

try:
exitcode = pytest.main(
[tests_root, "-p no:logging", "--collect-only", "-m", "not skip",], plugins=[PytestCollectionPlugin()]
[tests_root, "-p no:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()]
)
except Exception as e: # noqa: BLE001
print(f"Failed to collect tests: {e!s}") # noqa: T201
Expand Down
Loading
Loading