Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
from contextlib import contextmanager
from functools import lru_cache
from importlib.util import find_spec
from pathlib import Path
from tempfile import TemporaryDirectory

Expand Down Expand Up @@ -355,6 +356,56 @@ def module_name_from_file_path(file_path: Path, project_root_path: Path, *, trav
raise ValueError(msg) # noqa: B904


def validate_module_import(module_path: str, project_root: Path) -> tuple[bool, str]:
    """Check whether a module can be located via ``importlib.util.find_spec``.

    No subprocess is spawned and the target module itself is not executed.
    Note, however, that for a dotted ``module_path`` ``find_spec`` imports the
    parent packages, so their ``__init__.py`` files do run.

    ``project_root`` is temporarily inserted at the front of ``sys.path`` for
    the lookup (only when it is not already present) and removed afterwards.

    Returns ``(True, "")`` when a spec is found, otherwise ``(False, message)``
    where ``message`` describes why the lookup failed. Exceptions raised by the
    lookup are reported through the message rather than propagated.
    """
    project_root_str = str(project_root)
    # Only undo what we did: a pre-existing sys.path entry must be left alone.
    added = project_root_str not in sys.path
    if added:
        sys.path.insert(0, project_root_str)
    try:
        if find_spec(module_path) is not None:
            return True, ""
        return False, f"Module '{module_path}' not found (find_spec returned None)"
    except ModuleNotFoundError as e:
        # Raised when a parent package of a dotted name cannot be imported.
        return False, str(e)
    except Exception as e:
        # Parent-package initialisation may fail in arbitrary ways; report instead of crashing.
        return False, f"Error checking module '{module_path}': {e}"
    finally:
        # Importing a parent package may itself have edited sys.path. Guard the
        # removal so a missing entry cannot raise here and mask the result.
        if added and project_root_str in sys.path:
            sys.path.remove(project_root_str)


def infer_module_root_from_file(file_path: Path, pyproject_dir: Path) -> Path | None:
    """Infer the module-root for a Python file from its chain of ``__init__.py`` files.

    Both paths are resolved first. Starting at the file's directory, the search
    climbs one directory at a time while each directory contains an
    ``__init__.py``. It stops on reaching ``pyproject_dir`` (which is never
    inspected itself), the filesystem root, or the first directory that has no
    ``__init__.py``.

    Returns the highest package directory found on that climb. When the file's
    own directory is not a package, that directory is returned instead. This
    implementation therefore always returns a path, even though the annotation
    allows ``None``.
    """
    source_file = file_path.resolve()
    boundary = pyproject_dir.resolve()

    outermost_package: Path | None = None
    directory = source_file.parent
    while (
        directory != boundary
        and directory != directory.parent  # filesystem root is its own parent
        and (directory / "__init__.py").exists()
    ):
        outermost_package = directory
        directory = directory.parent

    # Without any enclosing package, the file's own directory is the module-root.
    return source_file.parent if outermost_package is None else outermost_package


def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path:
    """Return the ``.py`` file path under *project_root_path* for a dotted module name."""
    # Each dot in the module name becomes one directory separator.
    relative_file = f"{module_name.replace('.', os.sep)}.py"
    return project_root_path / relative_file
Expand Down
69 changes: 69 additions & 0 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import TYPE_CHECKING, Callable

import libcst as cst
import tomlkit
from git import Repo as GitRepo
from rich.console import Group
from rich.panel import Panel
Expand All @@ -23,6 +24,7 @@
from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success
from codeflash.benchmarking.utils import process_benchmark_data
from codeflash.cli_cmds.cli import project_root_from_module_root
from codeflash.cli_cmds.console import (
code_print,
console,
Expand All @@ -42,9 +44,12 @@
extract_unique_errors,
file_name_from_test_module_name,
get_run_tmp_file,
infer_module_root_from_file,
module_name_from_file_path,
normalize_by_max,
restore_conftest,
unified_diff_strings,
validate_module_import,
)
from codeflash.code_utils.config_consts import (
COVERAGE_THRESHOLD,
Expand All @@ -58,6 +63,7 @@
EffortLevel,
get_effort_value,
)
from codeflash.code_utils.config_parser import find_pyproject_toml
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
from codeflash.code_utils.git_utils import git_root_dir
Expand Down Expand Up @@ -541,6 +547,56 @@ def parse_line_profile_test_results(

# --- End hooks ---

def try_correct_module_root(self) -> bool:
    """Try to infer and apply the correct module-root if the current one is wrong.

    The module-root is inferred with ``infer_module_root_from_file`` and, when
    it differs from ``self.args.module_root``, checked with
    ``validate_module_import``. On success the ``module-root`` entry under
    ``[tool.codeflash]`` in pyproject.toml is rewritten and the in-memory
    configuration on ``self`` / ``self.args`` is updated to match.

    Returns True only when both pyproject.toml and the in-memory config were
    updated. Returns False when no pyproject.toml is found, the inferred root
    equals the current one, the module name cannot be derived, the import
    check fails, or pyproject.toml cannot be rewritten.
    """
    try:
        pyproject_path = find_pyproject_toml(None)
    except ValueError:
        return False

    pyproject_dir = pyproject_path.parent
    inferred_root = infer_module_root_from_file(self.function_to_optimize.file_path, pyproject_dir)
    if inferred_root is None or inferred_root.resolve() == self.args.module_root.resolve():
        return False

    new_module_root = inferred_root.resolve()
    new_project_root = project_root_from_module_root(new_module_root, pyproject_path)
    try:
        new_module_path = module_name_from_file_path(self.function_to_optimize.file_path, new_project_root)
    except ValueError:
        return False

    import_ok, _ = validate_module_import(new_module_path, new_project_root)
    if not import_ok:
        return False

    # The inferred module-root passed the import check — update pyproject.toml
    try:
        # Store the path with forward slashes so the committed pyproject.toml is
        # the same whichever OS wrote it (a no-op where os.sep is already "/").
        relative_root = os.path.relpath(new_module_root, pyproject_dir).replace(os.sep, "/")
        with pyproject_path.open("rb") as f:
            data = tomlkit.parse(f.read())
        data["tool"]["codeflash"]["module-root"] = relative_root  # type: ignore[index]
        with pyproject_path.open("w", encoding="utf-8") as f:
            f.write(tomlkit.dumps(data))
    except Exception as e:
        # Best effort: leave the existing configuration in place, but record why.
        logger.debug(f"Failed to update pyproject.toml with corrected module-root: {e}")
        return False

    # Update in-memory config only after the file write succeeded, so both stay in sync
    self.args.module_root = new_module_root
    self.args.project_root = new_project_root
    self.project_root = new_project_root.resolve()
    self.original_module_path = new_module_path

    logger.info(f"Auto-corrected module-root to '{relative_root}' in pyproject.toml")
    return True

def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
should_run_experiment = self.experiment_id is not None
logger.info(f"!lsp|Function Trace ID: {self.function_trace_id}")
Expand All @@ -555,6 +611,19 @@ def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[P
f"Cannot optimize without tests when --no-gen-tests is set."
)

# Pre-flight: verify module-root consistency and importability before expensive API calls
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of implementing it in the function optimizer here, can you find a good spot for it in https://github.com/codeflash-ai/codeflash/blob/main/codeflash/languages/python/function_optimizer.py?

Copy link
Collaborator

@KRRT7 KRRT7 Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If pre-flight doesn't exist in the language base, add it.

if self.function_to_optimize.language == "python":
# Auto-correct module-root if it doesn't match the inferred root from __init__.py chain
self.try_correct_module_root()
# Now validate the (possibly corrected) module can actually be imported
import_ok, import_error = validate_module_import(self.original_module_path, self.project_root)
if not import_ok:
return Failure(
f"Cannot import module '{self.original_module_path}': {import_error}\n"
"This prevents test execution. Please check that all dependencies are installed "
"and that 'module-root' is correctly configured in pyproject.toml."
)

self.cleanup_leftover_test_return_values()
file_name_from_test_module_name.cache_clear()
ctx_result = self.get_code_optimization_context()
Expand Down
Loading