Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.config_parser import find_pyproject_toml

ImportErrorPattern = re.compile(r"^.*ModuleNotFoundError.*$", re.MULTILINE)


@contextmanager
def custom_addopts() -> None:
Expand Down
13 changes: 12 additions & 1 deletion codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@

import pytest
from pydantic.dataclasses import dataclass
from rich.panel import Panel
from rich.text import Text

from codeflash.cli_cmds.console import console, logger, test_files_progress_bar
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file, module_name_from_file_path
from codeflash.code_utils.code_utils import (
ImportErrorPattern,
custom_addopts,
get_run_tmp_file,
module_name_from_file_path,
)
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType

Expand Down Expand Up @@ -180,6 +187,10 @@ def discover_tests_pytest(
logger.warning(
f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}\n {error_section}"
)
if "ModuleNotFoundError" in result.stdout:
match = ImportErrorPattern.search(result.stdout).group()
panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False)
console.print(panel)

elif 0 <= exitcode <= 5:
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}")
Expand Down
7 changes: 7 additions & 0 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.code_utils.code_utils import (
ImportErrorPattern,
cleanup_paths,
file_name_from_test_module_name,
get_run_tmp_file,
Expand Down Expand Up @@ -1192,6 +1193,12 @@ def run_and_parse_tests(
f"stdout: {run_result.stdout}\n"
f"stderr: {run_result.stderr}\n"
)
if "ModuleNotFoundError" in run_result.stdout:
from rich.text import Text

match = ImportErrorPattern.search(run_result.stdout).group()
panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False)
console.print(panel)
if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}:
results, coverage_results = parse_test_results(
test_xml_path=result_file_path,
Expand Down
51 changes: 51 additions & 0 deletions tests/test_test_runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import re

import os
import tempfile
from pathlib import Path

from codeflash.code_utils.code_utils import ImportErrorPattern
from codeflash.models.models import TestFile, TestFiles, TestType
from codeflash.verification.parse_test_output import parse_test_xml
from codeflash.verification.test_runner import run_behavioral_tests
Expand Down Expand Up @@ -96,3 +99,51 @@ def test_sort():
)
assert results[0].did_pass, "Test did not pass as expected"
result_file.unlink(missing_ok=True)

code = """import torch
def sorter(arr):
print(torch.ones(1))
arr.sort()
return arr

def test_sort():
arr = [5, 4, 3, 2, 1, 0]
output = sorter(arr)
assert output == [0, 1, 2, 3, 4, 5]
"""
cur_dir_path = Path(__file__).resolve().parent
config = TestConfig(
tests_root=cur_dir_path,
project_root_path=cur_dir_path,
test_framework="pytest",
tests_project_rootdir=cur_dir_path.parent,
)

test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_TRACER_DISABLE"] = "1"
if "PYTHONPATH" not in test_env:
test_env["PYTHONPATH"] = str(config.project_root_path)
else:
test_env["PYTHONPATH"] += os.pathsep + str(config.project_root_path)

with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
test_files = TestFiles(
test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)]
)
fp.write(code.encode("utf-8"))
fp.flush()
result_file, process, _, _ = run_behavioral_tests(
test_files,
test_framework=config.test_framework,
cwd=Path(config.project_root_path),
test_env=test_env,
pytest_timeout=1,
pytest_target_runtime_seconds=1,
)
results = parse_test_xml(
test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process
)
match = ImportErrorPattern.search(process.stdout).group()
assert match=="E ModuleNotFoundError: No module named 'torch'"
result_file.unlink(missing_ok=True)
Loading