diff --git a/pyproject.toml b/pyproject.toml index 11ef2b9..2719fa8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ include = ["/README.md", "/Makefile", "/pytest_examples", "/tests"] [project] name = "pytest-examples" -version = "0.0.15" +version = "0.0.16" description = "Pytest plugin for testing examples in docstrings and markdown files." authors = [ {name = "Samuel Colvin", email = "s@muelcolvin.com"}, diff --git a/pytest_examples/eval_example.py b/pytest_examples/eval_example.py index ad5e5ca..e2f19e7 100644 --- a/pytest_examples/eval_example.py +++ b/pytest_examples/eval_example.py @@ -9,7 +9,7 @@ from .config import DEFAULT_LINE_LENGTH, ExamplesConfig from .lint import FormatError, black_check, black_format, ruff_check, ruff_format -from .run_code import InsertPrintStatements, run_code +from .run_code import IncludePrint, InsertPrintStatements, run_code if TYPE_CHECKING: from typing import Literal @@ -29,6 +29,7 @@ def __init__(self, *, tmp_path: Path, pytest_request: pytest.FixtureRequest): self.to_update: list[CodeExample] = [] self.config: ExamplesConfig = ExamplesConfig() self.print_callback: Callable[[str], str] | None = None + self.include_print: IncludePrint | None = None def set_config( self, @@ -172,6 +173,7 @@ def _run( config=self.config, enable_print_mock=enable_print_mock, print_callback=self.print_callback, + include_print=self.include_print, module_globals=module_globals, call=call, ) diff --git a/pytest_examples/run_code.py b/pytest_examples/run_code.py index 737cd28..903f414 100644 --- a/pytest_examples/run_code.py +++ b/pytest_examples/run_code.py @@ -7,6 +7,7 @@ import inspect import re import sys +from collections.abc import Sequence from dataclasses import dataclass from importlib.abc import Loader from pathlib import Path @@ -24,9 +25,10 @@ from .config import ExamplesConfig from .find_examples import CodeExample -__all__ = 'run_code', 'InsertPrintStatements' +__all__ = 'run_code', 'InsertPrintStatements', 'IncludePrint' parent_frame_id = 4 if sys.version_info >= (3, 8) else 3 +IncludePrint = Callable[[Path, inspect.FrameInfo, Sequence[Any]], bool] def run_code( @@ -37,6 +39,7 @@ def run_code( config: ExamplesConfig, enable_print_mock: bool, print_callback: Callable[[str], str] | None, + include_print: IncludePrint | None, module_globals: dict[str, Any] | None, call: str | None, ) -> tuple[InsertPrintStatements, dict[str, Any]]: @@ -49,6 +52,7 @@ def run_code( config: The `ExamplesConfig` to use. enable_print_mock: If True, mock the `print` function. print_callback: If not None, a callback to call on `print`. + include_print: If not None, a function to call to determine if the print statement should be included. module_globals: The extra globals to add before calling the module. call: If not None, a (coroutine) function to call in the module. @@ -63,7 +67,7 @@ def run_code( module = importlib.util.module_from_spec(spec) # does nothing if insert_print_statements is False - insert_print = InsertPrintStatements(python_file, config, enable_print_mock, print_callback) + insert_print = InsertPrintStatements(python_file, config, enable_print_mock, print_callback, include_print) if module_globals: module.__dict__.update(module_globals) @@ -141,26 +145,40 @@ def not_print(*args): class MockPrintFunction: - def __init__(self, file: Path) -> None: + __slots__ = 'file', 'statements', 'include_print' + + def __init__(self, file: Path, include_print: IncludePrint | None) -> None: self.file = file self.statements: list[PrintStatement] = [] + self.include_print = include_print def __call__(self, *args: Any, sep: str = ' ', **kwargs: Any) -> None: frame = inspect.stack()[parent_frame_id] - if self.file.samefile(frame.filename): + if self._include_file(frame, args): # -1 to account for the line number being 1-indexed s = PrintStatement(frame.lineno, sep, [Arg(arg) for arg in args]) self.statements.append(s) + def _include_file(self, frame: inspect.FrameInfo, args: Sequence[Any]) -> bool: + if self.include_print: + return self.include_print(self.file, frame, args) + else: + return self.file.samefile(frame.filename) + class InsertPrintStatements: def __init__( - self, python_path: Path, config: ExamplesConfig, enable: bool, print_callback: Callable[[str], str] | None + self, + python_path: Path, + config: ExamplesConfig, + enable: bool, + print_callback: Callable[[str], str] | None, + include_print: IncludePrint | None, ): self.file = python_path self.config = config - self.print_func = MockPrintFunction(python_path) if enable else None + self.print_func = MockPrintFunction(python_path, include_print) if enable else None self.print_callback = print_callback self.patch = None diff --git a/tests/test_insert_print.py b/tests/test_insert_print.py index 51271c7..911c521 100644 --- a/tests/test_insert_print.py +++ b/tests/test_insert_print.py @@ -432,3 +432,22 @@ async def main(): module_dict = eval_example.run_print_check(example, call='main') assert module_dict['main_called'] + + +def test_custom_include_print(tmp_path, eval_example): + # note this file is no written here as it's not required + md_file = tmp_path / 'test.md' + python_code = """ +print('yes') +#> yes +print('no') +""" + example = CodeExample.create(python_code, path=md_file) + eval_example.set_config(line_length=30) + + def custom_include_print(path, frame, args): + return 'yes' in args + + eval_example.include_print = custom_include_print + + eval_example.run_print_check(example, call='main') diff --git a/uv.lock b/uv.lock index 80b43d2..178d1f5 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.8" [[package]] @@ -53,7 +54,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -330,7 +331,7 @@ wheels = [ [[package]] name = "pytest-examples" -version = "0.0.15" +version = "0.0.16" source = { editable = "." } dependencies = [ { name = "black" },