diff --git a/codeflash-benchmark/codeflash_benchmark/__init__.py b/codeflash-benchmark/codeflash_benchmark/__init__.py
new file mode 100644
index 000000000..5ce074bab
--- /dev/null
+++ b/codeflash-benchmark/codeflash_benchmark/__init__.py
@@ -0,0 +1,3 @@
+"""CodeFlash Benchmark - Pytest benchmarking plugin for codeflash.ai."""
+
+__version__ = "0.1.0"
diff --git a/codeflash-benchmark/codeflash_benchmark/plugin.py b/codeflash-benchmark/codeflash_benchmark/plugin.py
new file mode 100644
index 000000000..7d6c82927
--- /dev/null
+++ b/codeflash-benchmark/codeflash_benchmark/plugin.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+import importlib.util
+
+import pytest
+
+from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
+
+PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
+
+
+def pytest_configure(config: pytest.Config) -> None:
+    """Register the benchmark marker and disable conflicting plugins."""
+    config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
+
+    if config.getoption("--codeflash-trace") and PYTEST_BENCHMARK_INSTALLED:
+        config.option.benchmark_disable = True
+        config.pluginmanager.set_blocked("pytest_benchmark")
+        config.pluginmanager.set_blocked("pytest-benchmark")
+
+
+def pytest_addoption(parser: pytest.Parser) -> None:
+    parser.addoption(
+        "--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
+    )
+
+
+@pytest.fixture
+def benchmark(request: pytest.FixtureRequest) -> object:
+    """Benchmark fixture that works with or without pytest-benchmark installed."""
+    config = request.config
+
+    # If --codeflash-trace is enabled, use our implementation
+    if config.getoption("--codeflash-trace"):
+        return codeflash_benchmark_plugin.Benchmark(request)
+
+    # If pytest-benchmark is installed and --codeflash-trace is not enabled,
+    # return the normal pytest-benchmark fixture
+    if PYTEST_BENCHMARK_INSTALLED:
+        from pytest_benchmark.fixture import BenchmarkFixture as BSF  # noqa: N814
+
+        bs = getattr(config, "_benchmarksession", None)
+        if bs and bs.skip:
+            pytest.skip("Benchmarks are skipped (--benchmark-skip was used).")
+
+        node = request.node
+        marker = node.get_closest_marker("benchmark")
+        options = dict(marker.kwargs) if marker else {}
+
+        if bs:
+            return BSF(
+                node,
+                add_stats=bs.benchmarks.append,
+                logger=bs.logger,
+                warner=request.node.warn,
+                disabled=bs.disabled,
+                **dict(bs.options, **options),
+            )
+        return lambda func, *args, **kwargs: func(*args, **kwargs)
+
+    return lambda func, *args, **kwargs: func(*args, **kwargs)
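Editor's note: for context, a benchmark test exercising this fixture could look like the minimal sketch below (the sorter function and file name are illustrative, not from this PR). With --codeflash-trace the fixture returns codeflash's tracing Benchmark; with pytest-benchmark installed and tracing off it falls back to the stock BenchmarkFixture; otherwise it degrades to a pass-through that simply calls the function.

    # tests/benchmarks/test_sort.py -- hypothetical usage sketch
    import pytest

    def sorter(data):  # stand-in for the real function under test
        return sorted(data)

    @pytest.mark.benchmark  # marker registered by pytest_configure above
    def test_sort(benchmark):
        # All three fixture variants are invoked the same way: benchmark(func, *args, **kwargs)
        result = benchmark(sorter, [5, 4, 3, 2, 1])
        assert result == [1, 2, 3, 4, 5]
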
["setuptools>=45", "wheel", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["codeflash_benchmark"] \ No newline at end of file diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py index f6fba65af..58053bc4f 100644 --- a/codeflash/benchmarking/plugin/plugin.py +++ b/codeflash/benchmarking/plugin/plugin.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from codeflash.models.models import BenchmarkKey -IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None +PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None class CodeFlashBenchmarkPlugin: @@ -251,8 +251,12 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202 def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202 """Actual benchmark implementation.""" + node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None) + if node_path is None: + raise RuntimeError("Unable to determine test file path from pytest node") + benchmark_module_path = module_name_from_file_path( - Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True + Path(str(node_path)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True ) benchmark_function_name = self.request.node.name @@ -282,26 +286,3 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003 codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin() - - -def pytest_configure(config: pytest.Config) -> None: - """Register the benchmark marker and disable conflicting plugins.""" - config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing") - - if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED: - config.option.benchmark_disable = True - config.pluginmanager.set_blocked("pytest_benchmark") - config.pluginmanager.set_blocked("pytest-benchmark") - - -def pytest_addoption(parser: pytest.Parser) -> None: - parser.addoption( - "--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks" - ) - - -@pytest.fixture -def benchmark(request: pytest.FixtureRequest) -> object: - if not request.config.getoption("--codeflash-trace"): - return lambda func, *args, **kwargs: func(*args, **kwargs) - return codeflash_benchmark_plugin.Benchmark(request) diff --git a/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py b/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py index f8650d9fb..ccbdb34b7 100644 --- a/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py +++ b/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py @@ -7,11 +7,14 @@ benchmarks_root = sys.argv[1] tests_root = sys.argv[2] trace_file = sys.argv[3] -# current working directory project_root = Path.cwd() + if __name__ == "__main__": import pytest + orig_recursion_limit = sys.getrecursionlimit() + sys.setrecursionlimit(orig_recursion_limit * 2) + try: codeflash_benchmark_plugin.setup(trace_file, project_root) codeflash_trace.setup(trace_file) @@ -32,9 +35,12 @@ "addopts=", ], plugins=[codeflash_benchmark_plugin], - ) # Errors will be printed to stdout, not stderr - + ) except Exception as e: print(f"Failed to collect tests: {e!s}", file=sys.stderr) exitcode = -1 + finally: + # Restore the original recursion limit + sys.setrecursionlimit(orig_recursion_limit) + sys.exit(exitcode) diff --git 
diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py
index c2e1889db..f925f19d8 100644
--- a/codeflash/benchmarking/replay_test.py
+++ b/codeflash/benchmarking/replay_test.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import re
 import sqlite3
 import textwrap
 from pathlib import Path
@@ -14,6 +15,8 @@
 if TYPE_CHECKING:
     from collections.abc import Generator
 
+benchmark_context_cleaner = re.compile(r"[^a-zA-Z0-9_]+")
+
 
 def get_next_arg_and_return(
     trace_file: str,
@@ -46,6 +49,16 @@ def get_function_alias(module: str, function_name: str) -> str:
     return "_".join(module.split(".")) + "_" + function_name
 
 
+def get_unique_test_name(module: str, function_name: str, benchmark_name: str, class_name: str | None = None) -> str:
+    clean_benchmark = benchmark_context_cleaner.sub("_", benchmark_name).strip("_")
+
+    base_alias = get_function_alias(module, function_name)
+    if class_name:
+        class_alias = get_function_alias(module, class_name)
+        return f"{class_alias}_{function_name}_{clean_benchmark}"
+    return f"{base_alias}_{clean_benchmark}"
+
+
 def create_trace_replay_test_code(
     trace_file: str,
     functions_data: list[dict[str, Any]],
@@ -209,7 +222,8 @@ def create_trace_replay_test_code(
         formatted_test_body = textwrap.indent(test_body, "        " if test_framework == "unittest" else "    ")
 
         test_template += "    " if test_framework == "unittest" else ""
-        test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"
+        unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name)
+        test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n"
 
     return imports + "\n" + metadata + "\n" + test_template
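Editor's note: to illustrate the naming scheme, here is how get_unique_test_name behaves on two hand-derived inputs (module and benchmark names are made up; results follow directly from the code above):

    from codeflash.benchmarking.replay_test import get_unique_test_name

    # Plain function: module alias + function name + cleaned benchmark name.
    # Runs of non-alphanumeric characters collapse to "_", then edges are stripped.
    assert (
        get_unique_test_name("code_to_optimize.bubble_sort", "sorter", "test_sort[case-1]")
        == "code_to_optimize_bubble_sort_sorter_test_sort_case_1"
    )

    # Method: the class alias replaces the bare function alias when class_name is given.
    assert (
        get_unique_test_name("code_to_optimize.bubble_sort", "sorter", "test_class_sort", class_name="Sorter")
        == "code_to_optimize_bubble_sort_Sorter_sorter_test_class_sort"
    )
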
diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py
index 545ea4a0a..d4994d0c3 100644
--- a/codeflash/discovery/discover_unit_tests.py
+++ b/codeflash/discovery/discover_unit_tests.py
@@ -150,11 +150,13 @@ def __init__(self, function_names_to_find: set[str]) -> None:
         self.imported_modules: set[str] = set()
         self.has_dynamic_imports: bool = False
         self.wildcard_modules: set[str] = set()
+        # Track aliases: alias_name -> original_name
+        self.alias_mapping: dict[str, str] = {}
 
         # Precompute function_names for prefix search
         # For prefix match, store mapping from prefix-root to candidates for O(1) matching
        self._exact_names = function_names_to_find
-        self._prefix_roots = {}
+        self._prefix_roots: dict[str, list[str]] = {}
         for name in function_names_to_find:
             if "." in name:
                 root = name.split(".", 1)[0]
@@ -206,6 +208,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
             imported_name = alias.asname if alias.asname else aname
             self.imported_modules.add(imported_name)
 
+            if alias.asname:
+                self.alias_mapping[imported_name] = aname
+
             # Fast check for dynamic import
             if mod == "importlib" and aname == "import_module":
                 self.has_dynamic_imports = True
@@ -222,7 +227,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
                 self.found_qualified_name = qname
                 return
 
-            # Fast prefix match: only for relevant roots
             prefix = qname + "."
             # Only bother if one of the targets startswith the prefix-root
             candidates = proots.get(qname, ())
@@ -247,6 +251,18 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
             self.found_qualified_name = node.attr
             return
 
+        if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules:
+            for target_func in self.function_names_to_find:
+                if "." in target_func:
+                    class_name, method_name = target_func.rsplit(".", 1)
+                    if node.attr == method_name:
+                        imported_name = node.value.id
+                        original_name = self.alias_mapping.get(imported_name, imported_name)
+                        if original_name == class_name:
+                            self.found_any_target_function = True
+                            self.found_qualified_name = target_func
+                            return
+
         # Check if this is accessing a target function through a dynamically imported module
         # Only if we've detected dynamic imports are being used
         if self.has_dynamic_imports and node.attr in self.function_names_to_find:
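Editor's note: the effect of the new alias map, condensed into a runnable sketch (assuming analyze_imports_in_test_file is importable from codeflash.discovery.discover_unit_tests, as the new unit tests later in this diff suggest; the imported module and alias are illustrative):

    import tempfile
    from pathlib import Path

    from codeflash.discovery.discover_unit_tests import analyze_imports_in_test_file

    with tempfile.TemporaryDirectory() as tmp:
        test_file = Path(tmp) / "test_example.py"
        test_file.write_text(
            "from pydantic_ai.profiles.google import (\n"
            "    GoogleJsonSchemaTransformer as g_alias,\n"
            ")\n\n"
            "def test_target():\n"
            "    g_alias.transform({})\n"
        )
        # visit_ImportFrom records alias_mapping["g_alias"] = "GoogleJsonSchemaTransformer";
        # visit_Attribute then resolves g_alias.transform back to the original class name
        # and matches the dotted target below.
        assert analyze_imports_in_test_file(test_file, {"GoogleJsonSchemaTransformer.transform"})
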
diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py
index 4fcafc50b..7ca1ee6da 100644
--- a/codeflash/discovery/functions_to_optimize.py
+++ b/codeflash/discovery/functions_to_optimize.py
@@ -204,6 +204,7 @@ def get_functions_to_optimize(
             functions[file] = [found_function]
     else:
         logger.info("Finding all functions modified in the current git diff ...")
+        console.rule()
         ph("cli-optimizing-git-diff")
         functions = get_functions_within_git_diff()
     filtered_modified_functions, functions_count = filter_functions(
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index 369fd51bd..529b76980 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -515,6 +515,7 @@ def group_by_benchmarks(
                 benchmark_replay_test_dir.resolve()
                 / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_",
                 project_root,
+                traverse_up=True,
             )
             for test_result in self.test_results:
                 if test_result.test_type == TestType.REPLAY_TEST:
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 1844819cc..6794e626e 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -840,6 +840,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
                 f"{concolic_coverage_test_files_count} concolic coverage test file"
                 f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}"
             )
+            console.rule()
         return unique_instrumented_test_files
 
     def generate_tests_and_optimizations(
diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py
index 01aa515c3..d66c3fcf0 100644
--- a/codeflash/optimization/optimizer.py
+++ b/codeflash/optimization/optimizer.py
@@ -106,7 +106,7 @@ def run_benchmarks(
             for file in file_path_to_source_code:
                 with file.open("w", encoding="utf8") as f:
                     f.write(file_path_to_source_code[file])
-
+        console.rule()
         return function_benchmark_timings, total_benchmark_timings
 
     def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
diff --git a/pyproject.toml b/pyproject.toml
index 3dba0e759..58bdeae4f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
     "line_profiler>=4.2.0",
     "platformdirs>=4.3.7",
     "pygls>=1.3.1",
+    "codeflash-benchmark",
 ]
 
 [project.urls]
@@ -262,6 +263,12 @@ skip-magic-trailing-comma = true
 [tool.hatch.version]
 source = "uv-dynamic-versioning"
 
+[tool.uv]
+workspace = { members = ["codeflash-benchmark"] }
+
+[tool.uv.sources]
+codeflash-benchmark = { workspace = true }
+
 [tool.uv-dynamic-versioning]
 enable = true
 style = "pep440"
@@ -300,5 +307,3 @@ markers = [
 
 requires = ["hatchling", "uv-dynamic-versioning"]
 build-backend = "hatchling.build"
-
-[project.entry-points.pytest11]
-codeflash = "codeflash.benchmarking.plugin.plugin"
diff --git a/tests/test_pickle_patcher.py b/tests/test_pickle_patcher.py
index 3d2f21b66..346153674 100644
--- a/tests/test_pickle_patcher.py
+++ b/tests/test_pickle_patcher.py
@@ -364,7 +364,7 @@ def test_run_and_parse_picklepatch() -> None:
     test_env["CODEFLASH_TEST_ITERATION"] = "0"
     test_env["CODEFLASH_LOOP_INDEX"] = "1"
     test_type = TestType.REPLAY_TEST
-    replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
+    replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
     func_optimizer = opt.create_function_optimizer(func)
     func_optimizer.test_files = TestFiles(
         test_files=[
@@ -388,7 +388,7 @@ def test_run_and_parse_picklepatch() -> None:
     )
     assert len(test_results_unused_socket) == 1
     assert test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
-    assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
+    assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
     assert test_results_unused_socket.test_results[0].did_pass == True
 
     # Replace with optimized candidate
@@ -439,7 +439,7 @@ def bubble_sort_with_unused_socket(data_container):
     test_type = TestType.REPLAY_TEST
 
     func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path))
-    replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
+    replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
     func_optimizer = opt.create_function_optimizer(func)
     func_optimizer.test_files = TestFiles(
         test_files=[
@@ -467,7 +467,7 @@ def bubble_sort_with_unused_socket(data_container):
     assert test_results_used_socket.test_results[
         0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
     assert test_results_used_socket.test_results[
-        0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
+        0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
     assert test_results_used_socket.test_results[0].did_pass is False
     print("test results used socket")
     print(test_results_used_socket)
@@ -498,7 +498,7 @@ def bubble_sort_with_used_socket(data_container):
     assert test_results_used_socket.test_results[
         0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
     assert test_results_used_socket.test_results[
-        0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
+        0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
     assert test_results_used_socket.test_results[0].did_pass is False
     # Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined.
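Editor's note: the renamed expected strings above follow mechanically from get_unique_test_name: the benchmark function's name is appended to the old alias-based name. For the first assertion, for example (the benchmark name test_socket_picklepatch is inferred from the appended suffix):

    from codeflash.benchmarking.replay_test import get_unique_test_name

    name = get_unique_test_name(
        "code_to_optimize.bubble_sort_picklepatch_test_unused_socket",
        "bubble_sort_with_unused_socket",
        "test_socket_picklepatch",
    )
    # The replay-test generator prefixes "test_", yielding the new expected string.
    assert "test_" + name == (
        "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket"
        "_bubble_sort_with_unused_socket_test_socket_picklepatch"
    )
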
diff --git a/tests/test_trace_benchmarks.py b/tests/test_trace_benchmarks.py
index adc00847b..c16150fba 100644
--- a/tests/test_trace_benchmarks.py
+++ b/tests/test_trace_benchmarks.py
@@ -33,7 +33,7 @@ def test_trace_benchmarks() -> None:
         function_calls = cursor.fetchall()
 
         # Assert the length of function calls
-        assert len(function_calls) == 7, f"Expected 7 function calls, but got {len(function_calls)}"
+        assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"
 
         bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
         process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
@@ -95,13 +95,13 @@ def test_trace_benchmarks() -> None:
 functions = ['sort_class', 'sort_static', 'sorter']
 trace_file_path = r"{output_file.as_posix()}"
 
-def test_code_to_optimize_bubble_sort_codeflash_trace_sorter():
+def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_sort():
     for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_sort", function_name="sorter", file_path=r"{bubble_sort_path}", num_to_get=100):
         args = pickle.loads(args_pkl)
         kwargs = pickle.loads(kwargs_pkl)
         ret = code_to_optimize_bubble_sort_codeflash_trace_sorter(*args, **kwargs)
 
-def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sorter():
+def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sorter_test_class_sort():
     for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort", function_name="sorter", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
         args = pickle.loads(args_pkl)
         kwargs = pickle.loads(kwargs_pkl)
@@ -113,7 +113,7 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sorter():
     else:
         ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sorter(*args, **kwargs)
 
-def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_class():
+def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_class_test_class_sort2():
     for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort2", function_name="sort_class", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
         args = pickle.loads(args_pkl)
         kwargs = pickle.loads(kwargs_pkl)
@@ -121,13 +121,13 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_class():
         if not args:
             raise ValueError("No arguments provided for the method.")
         ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_class(*args[1:], **kwargs)
 
-def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_static():
+def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_static_test_class_sort3():
     for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort3", function_name="sort_static", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
         args = pickle.loads(args_pkl)
         kwargs = pickle.loads(kwargs_pkl)
         ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_static(*args, **kwargs)
 
-def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init__():
+def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init___test_class_sort4():
function_name="__init__", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100): args = pickle.loads(args_pkl) kwargs = pickle.loads(kwargs_pkl) @@ -156,13 +156,13 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init__(): functions = ['compute_and_sort', 'sorter'] trace_file_path = r"{output_file}" -def test_code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort(): +def test_code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort_test_compute_and_sort(): for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_compute_and_sort", function_name="compute_and_sort", file_path=r"{process_and_bubble_sort_path}", num_to_get=100): args = pickle.loads(args_pkl) kwargs = pickle.loads(kwargs_pkl) ret = code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort(*args, **kwargs) -def test_code_to_optimize_bubble_sort_codeflash_trace_sorter(): +def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func(): for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_no_func", function_name="sorter", file_path=r"{bubble_sort_path}", num_to_get=100): args = pickle.loads(args_pkl) kwargs = pickle.loads(kwargs_pkl) diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index 101aa2671..7d453d624 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -1298,3 +1298,44 @@ def test_unrelated(): assert len(filtered_tests) == 1 assert "target_module.target_function" in filtered_tests assert "unrelated_module.unrelated_function" not in filtered_tests + + +def test_analyze_imports_aliased_class_method(): + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from pydantic_ai.profiles.google import ( + GoogleJsonSchemaTransformer as pydantic_ai_profiles_google_GoogleJsonSchemaTransformer, +) + +def test_target(): + ret = pydantic_ai_profiles_google_GoogleJsonSchemaTransformer.transform(*args, **kwargs) + assert ret is not None +""" + test_file.write_text(test_content) + + target_functions = {"GoogleJsonSchemaTransformer.transform"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + + +def test_analyze_imports_aliased_class_method_negative(): + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from pydantic_ai.profiles.google import ( + GoogleJsonSchemaTransformer as pydantic_ai_profiles_google_GoogleJsonSchemaTransformer, +) + +def test_target(): + ret = pydantic_ai_profiles_google_GoogleJsonSchemaTransformer.validate(*args, **kwargs) + assert ret is not None +""" + test_file.write_text(test_content) + + # Looking for transform but code uses validate - should not match + target_functions = {"GoogleJsonSchemaTransformer.transform"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is False diff --git a/uv.lock b/uv.lock index d8a2b77c6..2395b196f 100644 --- a/uv.lock +++ b/uv.lock @@ -8,6 +8,12 @@ resolution-markers = [ "python_full_version < '3.10'", ] +[manifest] +members = [ + "codeflash", + "codeflash-benchmark", +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -210,6 +216,7 @@ source = { editable = "." 
diff --git a/uv.lock b/uv.lock
index d8a2b77c6..2395b196f 100644
--- a/uv.lock
+++ b/uv.lock
@@ -8,6 +8,12 @@ resolution-markers = [
     "python_full_version < '3.10'",
 ]
 
+[manifest]
+members = [
+    "codeflash",
+    "codeflash-benchmark",
+]
+
 [[package]]
 name = "annotated-types"
 version = "0.7.0"
@@ -210,6 +216,7 @@ source = { editable = "." }
 dependencies = [
     { name = "click", version = "8.1.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
     { name = "click", version = "8.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
+    { name = "codeflash-benchmark" },
     { name = "coverage" },
     { name = "crosshair-tool" },
     { name = "dill" },
@@ -239,6 +246,7 @@ dependencies = [
 
 [package.dev-dependencies]
 dev = [
+    { name = "codeflash-benchmark" },
     { name = "ipython", version = "8.18.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
     { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.10.*'" },
     { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
@@ -268,6 +276,7 @@ dev = [
 [package.metadata]
 requires-dist = [
     { name = "click", specifier = ">=8.1.0" },
+    { name = "codeflash-benchmark", editable = "codeflash-benchmark" },
     { name = "coverage", specifier = ">=7.6.4" },
     { name = "crosshair-tool", specifier = ">=0.0.78" },
     { name = "dill", specifier = ">=0.3.8" },
@@ -297,6 +306,7 @@ requires-dist = [
 
 [package.metadata.requires-dev]
 dev = [
+    { name = "codeflash-benchmark", editable = "codeflash-benchmark" },
     { name = "ipython", specifier = ">=8.12.0" },
     { name = "lxml-stubs", specifier = ">=0.5.1" },
     { name = "mypy", specifier = ">=1.13" },
@@ -320,6 +330,17 @@ dev = [
     { name = "uv", specifier = ">=0.6.2" },
 ]
 
+[[package]]
+name = "codeflash-benchmark"
+version = "0.1.0"
+source = { editable = "codeflash-benchmark" }
+dependencies = [
+    { name = "pytest" },
+]
+
+[package.metadata]
+requires-dist = [{ name = "pytest", specifier = ">=7.0.0,!=8.3.4" }]
+
 [[package]]
 name = "colorama"
 version = "0.4.6"