Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions codeflash/code_utils/config_js_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""JavaScript/TypeScript module resolution validation utilities."""

from __future__ import annotations

from typing import TYPE_CHECKING

from codeflash.code_utils.config_js import detect_module_root, get_package_json_data
from codeflash.languages.javascript.test_runner import find_node_project_root

if TYPE_CHECKING:
from pathlib import Path


def validate_js_module_resolution(source_file: Path, project_root: Path, module_root: Path) -> tuple[bool, str]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Best to place and dedupe with logic present in file 'init_javascript.py'

"""Validate that a JS/TS source file can be resolved within the configured module root.

Checks:
1. Source file exists
2. Source file is within project_root
3. package.json exists in project_root
4. Source file is within module_root

Returns:
(True, "") on success, (False, error_message) on failure.

"""
source_file = source_file.resolve()
project_root = project_root.resolve()
module_root = module_root.resolve()

if not source_file.exists():
return False, f"Source file does not exist: {source_file}"

try:
source_file.relative_to(project_root)
except ValueError:
return False, f"Source file {source_file} is not within project root {project_root}"

package_json = project_root / "package.json"
if not package_json.exists():
return False, f"No package.json found at {project_root}"

try:
source_file.relative_to(module_root)
except ValueError:
return False, (
f"Source file {source_file} is not within module root {module_root}. "
f"Check the 'codeflash.moduleRoot' setting in package.json."
)

return True, ""


def infer_js_module_root(source_file: Path, project_root: Path | None = None) -> Path:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are functions for this purpose we can strength the logic there.
collect_js_setup_info

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or look into detect_module_root in config_js.py

"""Infer the JavaScript/TypeScript module root for a source file.

Uses find_node_project_root to locate package.json, then detect_module_root
to determine the module root from package.json fields and directory conventions.

Falls back to the source file's parent directory if no package.json is found.

Returns:
Absolute path to the inferred module root.

"""
source_file = source_file.resolve()

if project_root is None:
project_root = find_node_project_root(source_file)

if project_root is None:
return source_file.parent

project_root = project_root.resolve()
package_json_path = project_root / "package.json"
package_data = get_package_json_data(package_json_path)

if package_data is None:
return project_root

detected = detect_module_root(project_root, package_data)
return (project_root / detected).resolve()
72 changes: 71 additions & 1 deletion codeflash/languages/javascript/function_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import hashlib
import json
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Any
Expand All @@ -14,6 +15,7 @@
TESTGEN_LIMIT_ERROR,
TOTAL_LOOPING_TIME_EFFECTIVE,
)
from codeflash.code_utils.config_js_validation import infer_js_module_root, validate_js_module_resolution
from codeflash.either import Failure, Success
from codeflash.models.models import (
CodeOptimizationContext,
Expand All @@ -33,9 +35,77 @@


class JavaScriptFunctionOptimizer(FunctionOptimizer):
def try_correct_module_root(self) -> bool:
"""Attempt to auto-correct a misconfigured module root.

Returns True if the module root was corrected, False if no correction was needed.
"""
if self.args is None:
return False
source_file = self.function_to_optimize.file_path
project_root = self.project_root
module_root = Path(self.args.module_root).resolve()

valid, _ = validate_js_module_resolution(source_file, project_root, module_root)
if valid:
return False

inferred = infer_js_module_root(source_file, project_root)

try:
source_file.resolve().relative_to(inferred)
except ValueError:
return False

logger.info(f"Auto-correcting module root from {module_root} to {inferred}")
self.args.module_root = inferred
self.args.project_root = project_root
self.project_root = project_root

package_json_path = project_root / "package.json"
if package_json_path.exists():
try:
with package_json_path.open(encoding="utf-8") as f:
doc = json.load(f)

relative_module_root = inferred.relative_to(project_root).as_posix()
codeflash_section = doc.get("codeflash", {})
if not isinstance(codeflash_section, dict):
codeflash_section = {}
codeflash_section["moduleRoot"] = relative_module_root
doc["codeflash"] = codeflash_section

with package_json_path.open("w", encoding="utf-8") as f:
json.dump(doc, f, indent=2)
f.write("\n")

from codeflash.code_utils.config_js import PACKAGE_JSON_DATA_CACHE

PACKAGE_JSON_DATA_CACHE.pop(package_json_path, None)
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"Could not update package.json with corrected module root: {e}")

return True

def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
self.try_correct_module_root()

if self.args is None:
return super().can_be_optimized()

source_file = self.function_to_optimize.file_path
project_root = self.project_root
module_root = Path(self.args.module_root).resolve()

valid, error = validate_js_module_resolution(source_file, project_root, module_root)
if not valid:
return Failure(f"Cannot optimize '{self.function_to_optimize.function_name}': {error}")

return super().can_be_optimized()

def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
from codeflash.languages.language_enum import Language

language = Language(self.function_to_optimize.language)
lang_support = get_language_support(language)
Expand Down
128 changes: 128 additions & 0 deletions tests/code_utils/test_config_js_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Tests for JavaScript/TypeScript module resolution validation utilities."""

from __future__ import annotations

import json
from typing import TYPE_CHECKING

from codeflash.code_utils.config_js_validation import infer_js_module_root, validate_js_module_resolution

if TYPE_CHECKING:
from pathlib import Path


class TestValidateJsModuleResolution:
def test_valid_source_in_module_root(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "package.json").write_text("{}", encoding="utf-8")
src = project_root / "src"
src.mkdir()
source_file = src / "index.js"
source_file.write_text("export function foo() {}", encoding="utf-8")

valid, error = validate_js_module_resolution(source_file, project_root, src)
assert valid is True
assert error == ""

def test_source_does_not_exist(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "package.json").write_text("{}", encoding="utf-8")
source_file = project_root / "src" / "missing.js"

valid, error = validate_js_module_resolution(source_file, project_root, project_root)
assert valid is False
assert "does not exist" in error

def test_source_outside_project_root(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "package.json").write_text("{}", encoding="utf-8")
outside_file = tmp_path / "outside.js"
outside_file.write_text("export function foo() {}", encoding="utf-8")

valid, error = validate_js_module_resolution(outside_file, project_root, project_root)
assert valid is False
assert "not within project root" in error

def test_no_package_json(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
source_file = project_root / "index.js"
source_file.write_text("export function foo() {}", encoding="utf-8")

valid, error = validate_js_module_resolution(source_file, project_root, project_root)
assert valid is False
assert "No package.json" in error

def test_source_outside_module_root(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "package.json").write_text("{}", encoding="utf-8")
src = project_root / "src"
src.mkdir()
other = project_root / "other"
other.mkdir()
source_file = other / "index.js"
source_file.write_text("export function foo() {}", encoding="utf-8")

valid, error = validate_js_module_resolution(source_file, project_root, src)
assert valid is False
assert "not within module root" in error
assert "moduleRoot" in error

def test_module_root_equals_project_root(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "package.json").write_text("{}", encoding="utf-8")
source_file = project_root / "index.js"
source_file.write_text("export function foo() {}", encoding="utf-8")

valid, error = validate_js_module_resolution(source_file, project_root, project_root)
assert valid is True
assert error == ""


class TestInferJsModuleRoot:
def test_infers_src_directory(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
src = project_root / "src"
src.mkdir()
source_file = src / "index.js"
source_file.write_text("export function foo() {}", encoding="utf-8")
(project_root / "package.json").write_text("{}", encoding="utf-8")

result = infer_js_module_root(source_file, project_root)
assert result == src.resolve()

def test_infers_from_package_json_main_field(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
app = project_root / "app"
app.mkdir()
source_file = app / "index.js"
source_file.write_text("export function foo() {}", encoding="utf-8")
(project_root / "package.json").write_text(json.dumps({"main": "app/index.js"}), encoding="utf-8")

result = infer_js_module_root(source_file, project_root)
assert result == app.resolve()

def test_falls_back_to_project_root(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
source_file = project_root / "index.js"
source_file.write_text("export function foo() {}", encoding="utf-8")
(project_root / "package.json").write_text("{}", encoding="utf-8")

result = infer_js_module_root(source_file, project_root)
assert result == project_root.resolve()

def test_falls_back_to_parent_without_package_json(self, tmp_path: Path) -> None:
source_file = tmp_path / "standalone" / "index.js"
source_file.parent.mkdir(parents=True)
source_file.write_text("export function foo() {}", encoding="utf-8")

result = infer_js_module_root(source_file, project_root=None)
assert result == source_file.parent.resolve()
Loading
Loading