Skip to content

Commit d1f28b2

Browse files
authored
Merge pull request #30 from George-Ogden/typed-marks
Typed Marks
2 parents c45b9a0 + 5e83620 commit d1f28b2

File tree

15 files changed

+217
-63
lines changed

15 files changed

+217
-63
lines changed

.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ jobs:
2525
with:
2626
python-versions: "['3.13']"
2727
timeout-minutes: 10
28-
pytest-flags: -vv tests
28+
pytest-flags: -vv tests -n auto

mypy_pytest_plugin/error_codes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@
120120
category="Pytest",
121121
)
122122

123+
UNKNOWN_MARK: Final[ErrorCode] = ErrorCode(
124+
"unknown-mark",
125+
"Mark name is not recognized as a pre-defined or user-defined mark.",
126+
category="Pytest",
127+
)
128+
123129
ITERABLE_SEQUENCE: Final[ErrorCode] = ErrorCode(
124130
"iterable-sequence",
125131
""""Sequence" passed into a function expecting "Iterable" in a test.""",

mypy_pytest_plugin/fixture.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .error_codes import DUPLICATE_FIXTURE, INVALID_FIXTURE_SCOPE, MARKED_FIXTURE, REQUEST_KEYWORD
2424
from .fullname import Fullname
2525
from .test_argument import TestArgument
26+
from .types_module import TYPES_MODULE
2627

2728
FixtureScope = enum.IntEnum(
2829
"FixtureScope", ["function", "class", "module", "package", "session", "unknown"]
@@ -93,6 +94,20 @@ def name(self) -> str:
9394
def module_name(self) -> Fullname:
9495
return self.fullname.module_name
9596

97+
def as_fixture_type(self, *, decorator: Decorator, checker: TypeChecker) -> Type:
98+
assert decorator.func.type is not None
99+
return checker.named_generic_type(
100+
f"{TYPES_MODULE}.FixtureType",
101+
[
102+
LiteralType(self.scope, fallback=checker.named_type("builtins.object")),
103+
decorator.func.type,
104+
LiteralType(
105+
decorator.func.is_generator, fallback=checker.named_type("builtins.object")
106+
),
107+
LiteralType(decorator.fullname, fallback=checker.named_type("builtins.object")),
108+
],
109+
)
110+
96111

97112
@dataclass(frozen=True, slots=True)
98113
class FixtureParser:

mypy_pytest_plugin/mark_checker.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from collections.abc import Sequence
2+
from dataclasses import dataclass
3+
import functools
4+
import itertools
5+
6+
from mypy.checker import TypeChecker
7+
from mypy.nodes import MemberExpr
8+
from mypy.subtypes import is_same_type
9+
10+
from .error_codes import UNKNOWN_MARK
11+
from .pytest_config_manager import PytestConfigManager
12+
13+
14+
@dataclass(frozen=True)
15+
class MarkChecker:
16+
checker: TypeChecker
17+
18+
def check_attribute(self, expr: MemberExpr) -> None:
19+
if is_same_type(
20+
self.checker.lookup_type(expr.expr), self.checker.named_type("pytest.MarkGenerator")
21+
) and not self.is_valid_mark(expr.name):
22+
error_msg = f"Invalid mark name {expr.name!r}."
23+
note_prefix = f"Expected a predefined mark (one of {self.predefined_names!r}) or "
24+
if self.user_defined_names:
25+
note_suffix = f"a user defined mark (one of {self.user_defined_names!r})."
26+
else:
27+
note_suffix = "see https://docs.pytest.org/en/stable/how-to/mark.html for how to register marks."
28+
self.checker.fail(error_msg, context=expr, code=UNKNOWN_MARK)
29+
self.checker.note(note_prefix + note_suffix, context=expr, code=UNKNOWN_MARK)
30+
31+
def is_valid_mark(self, name: str) -> bool:
32+
return not name.startswith("_") and (name in self._mark_names_index)
33+
34+
@functools.cached_property
35+
def _mark_names_index(self) -> set[str]:
36+
return set(itertools.chain(self.predefined_names, self.user_defined_names))
37+
38+
@functools.cached_property
39+
def predefined_names(self) -> Sequence[str]:
40+
return [
41+
name
42+
for name in self.checker.named_type("pytest.MarkGenerator").type.names
43+
if not name.startswith("_")
44+
]
45+
46+
@functools.cached_property
47+
def user_defined_names(self) -> list[str]:
48+
return [
49+
name
50+
for line in PytestConfigManager.markers()
51+
if (name := line.split(":")[0].split("(")[0].strip()) and not name.startswith("_")
52+
]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from .mark_checker import MarkChecker
2+
from .test_utils import parse
3+
4+
5+
def _mark_name_test_body(name: str, valid: bool) -> None:
6+
pytest_source_mock = """
7+
from typing import Any
8+
9+
class MarkGenerator:
10+
skip: Any
11+
parametrize: Any
12+
def __getattr__(self, name: str) -> Any:
13+
raise NotImplementedError()
14+
def _config(self) -> None: ...
15+
"""
16+
parse_result = parse(pytest_source_mock, module_name="pytest")
17+
parse_result.accept_all()
18+
19+
mark_checker = MarkChecker(parse_result.checker)
20+
assert mark_checker.is_valid_mark(name) == valid
21+
22+
23+
def test_mark_name_starts_with_underscore() -> None:
24+
_mark_name_test_body("_parametrize", False)
25+
26+
27+
def test_mark_name_valid() -> None:
28+
_mark_name_test_body("parametrize", True)
29+
30+
31+
def test_mark_name_not_predefined() -> None:
32+
_mark_name_test_body("invalid", False)
33+
34+
35+
def test_mark_name_user_defined() -> None:
36+
_mark_name_test_body("used_for_testing", True)

mypy_pytest_plugin/plugin.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,18 @@
33
from typing import cast
44

55
from mypy.checker import TypeChecker
6-
from mypy.nodes import CallExpr, Decorator, Expression, MypyFile
6+
from mypy.nodes import CallExpr, Decorator, Expression, MemberExpr, MypyFile
77
from mypy.options import Options
8-
from mypy.plugin import (
9-
FunctionContext,
10-
MethodContext,
11-
Plugin,
12-
)
13-
from mypy.types import (
14-
CallableType,
15-
LiteralType,
16-
Type,
17-
)
8+
from mypy.plugin import AttributeContext, FunctionContext, MethodContext, Plugin
9+
from mypy.types import CallableType, Type
1810

1911
from .defer import DeferralError
2012
from .excluded_test_checker import ExcludedTestChecker
2113
from .fixture import Fixture
2214
from .fixture_manager import FixtureManager
2315
from .fullname import Fullname
2416
from .iterable_sequence_checker import IterableSequenceChecker
17+
from .mark_checker import MarkChecker
2518
from .mock_call_checker import FunctionMockCallChecker, MethodMockCallChecker
2619
from .test_body_ranges import TestBodyRanges
2720
from .test_info import TestInfo
@@ -65,6 +58,23 @@ def module_to_dep(cls, module: str | Fullname) -> tuple[int, str, int]:
6558
module = str(module)
6659
return (10, module, -1)
6760

61+
def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
62+
if fullname.startswith("_pytest.mark.structures.MarkGenerator"):
63+
return self.check_mark
64+
return None
65+
66+
@classmethod
67+
def check_mark(cls, ctx: AttributeContext) -> Type:
68+
if ctx.api.path == "test_samples/mark_test.py":
69+
...
70+
if (
71+
not ctx.is_lvalue
72+
and isinstance(checker := ctx.api, TypeChecker)
73+
and isinstance(expr := ctx.context, MemberExpr)
74+
):
75+
MarkChecker(checker).check_attribute(expr)
76+
return ctx.default_attr_type
77+
6878
def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
6979
if fullname.startswith("unittest.mock"):
7080
return functools.partial(FunctionMockCallChecker.check_mock_calls, fullname=fullname)
@@ -126,7 +136,7 @@ def _check_pytest_structure(cls, ctx: MethodContext | FunctionContext) -> Type:
126136
cls._update_return_type(ctx.default_return_type, ctx.api)
127137
if not Fixture.is_fixture_and_mark(ctx.context, checker=ctx.api):
128138
if fixture := Fixture.from_decorator(ctx.context, checker=ctx.api):
129-
return cls._fixture_type(fixture, decorator=ctx.context, checker=ctx.api)
139+
return fixture.as_fixture_type(decorator=ctx.context, checker=ctx.api)
130140
ignored_testnames = ExcludedTestChecker.ignored_test_names(
131141
ctx.api.tree.defs, ctx.api
132142
)
@@ -136,27 +146,6 @@ def _check_pytest_structure(cls, ctx: MethodContext | FunctionContext) -> Type:
136146
cls._check_decorators(ctx.context, ctx.api)
137147
return ctx.default_return_type
138148

139-
@classmethod
140-
def _check_decorators(cls, node: Decorator, checker: TypeChecker) -> None:
141-
test_info = TestInfo.from_fn_def(node, checker=checker)
142-
if test_info is not None:
143-
test_info.check()
144-
145-
@classmethod
146-
def _fixture_type(cls, fixture: Fixture, *, decorator: Decorator, checker: TypeChecker) -> Type:
147-
assert decorator.func.type is not None
148-
return checker.named_generic_type(
149-
f"{TYPES_MODULE}.FixtureType",
150-
[
151-
LiteralType(fixture.scope, fallback=checker.named_type("builtins.object")),
152-
decorator.func.type,
153-
LiteralType(
154-
decorator.func.is_generator, fallback=checker.named_type("builtins.object")
155-
),
156-
LiteralType(decorator.fullname, fallback=checker.named_type("builtins.object")),
157-
],
158-
)
159-
160149
@classmethod
161150
def _update_return_type(cls, return_type: Type, checker: TypeChecker) -> None:
162151
if (
@@ -165,6 +154,12 @@ def _update_return_type(cls, return_type: Type, checker: TypeChecker) -> None:
165154
):
166155
return_type.fallback = checker.named_type(f"{TYPES_MODULE}.Testable")
167156

157+
@classmethod
158+
def _check_decorators(cls, node: Decorator, checker: TypeChecker) -> None:
159+
test_info = TestInfo.from_fn_def(node, checker=checker)
160+
if test_info is not None:
161+
test_info.check()
162+
168163

169164
def plugin(version: str) -> type[PytestPlugin]:
170165
return PytestPlugin
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import functools
2+
3+
import _pytest.config
4+
from pytest import Config, Session # noqa: PT013
5+
6+
7+
class PytestConfigManager:
8+
@classmethod
9+
@functools.cache
10+
def session(cls) -> Session:
11+
config = _pytest.config.get_config()
12+
config.parse(["-s", "--noconftest"])
13+
return Session.from_config(config)
14+
15+
@classmethod
16+
def config(cls) -> Config:
17+
return cls.session().config
18+
19+
@classmethod
20+
@functools.cache
21+
def file_patterns(cls) -> list[str]:
22+
return cls.config().getini("python_files")
23+
24+
@classmethod
25+
@functools.cache
26+
def fn_patterns(cls) -> list[str]:
27+
return cls.config().getini("python_functions")
28+
29+
@classmethod
30+
@functools.cache
31+
def markers(cls) -> list[str]:
32+
return cls.config().getini("markers")

mypy_pytest_plugin/test_name_checker.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import functools
44
from pathlib import Path
55

6-
import _pytest.config
7-
from _pytest.main import Session
86
from _pytest.pathlib import fnmatch_ex
97

8+
from .pytest_config_manager import PytestConfigManager
9+
1010

1111
class TestNameChecker:
1212
@classmethod
@@ -25,23 +25,6 @@ def _path_from_sections(cls, sections: MutableSequence[str]) -> Path:
2525
sections[-1] += ".py"
2626
return Path(*sections)
2727

28-
@classmethod
29-
@functools.cache
30-
def _session(cls) -> Session:
31-
config = _pytest.config.get_config()
32-
config.parse(["-s", "--noconftest"])
33-
return Session.from_config(config)
34-
35-
@classmethod
36-
@functools.cache
37-
def _file_patterns(cls) -> list[str]:
38-
return cls._session().config.getini("python_files")
39-
40-
@classmethod
41-
@functools.cache
42-
def _fn_patterns(cls) -> list[str]:
43-
return cls._session().config.getini("python_functions")
44-
4528
@classmethod
4629
@functools.cache
4730
def is_test_file_name(cls, name: str) -> bool:
@@ -51,11 +34,11 @@ def is_test_file_name(cls, name: str) -> bool:
5134
@classmethod
5235
@functools.cache
5336
def is_test_path(cls, path: Path) -> bool:
54-
return any(fnmatch_ex(pattern, path) for pattern in cls._file_patterns())
37+
return any(fnmatch_ex(pattern, path) for pattern in PytestConfigManager.file_patterns())
5538

5639
@classmethod
5740
def is_test_fn_name(cls, fn_name: str) -> bool:
5841
return any(
5942
fn_name.startswith(pattern) or fnmatch.fnmatch(pattern, fn_name)
60-
for pattern in cls._fn_patterns()
43+
for pattern in PytestConfigManager.fn_patterns()
6144
)

mypy_pytest_plugin/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def parse_multiple(modules: Sequence[tuple[str, str]], *, header: bool = False)
101101
]
102102

103103
options = mypy.options.Options()
104-
options.incremental = False
104+
options.show_traceback = True
105+
options.incremental = True
105106
options.show_traceback = True
106107
options.preserve_asts = True
107108
options.disallow_untyped_defs = False
@@ -159,8 +160,7 @@ def parse_multiple(modules: Sequence[tuple[str, str]], *, header: bool = False)
159160

160161

161162
@functools.lru_cache(maxsize=1)
162-
def parse(code: str, *, header: bool = True) -> ParseResult:
163-
module_name = "test_module"
163+
def parse(code: str, *, header: bool = True, module_name: str = "test_module") -> ParseResult:
164164
parse_result = parse_multiple([(module_name, code)], header=header)
165165
return parse_result.single(module_name)
166166

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "mypy-pytest-plugin"
33
requires-python = ">=3.12,<3.15"
4-
version = "0.7.2"
4+
version = "0.8.0"
55
dynamic = ["dependencies"]
66

77
[tool.setuptools]
@@ -54,3 +54,4 @@ python_files = "*_test.py"
5454
addopts = '--capture=sys --timeout 60 --timeout-method=thread'
5555
norecursedirs = ["test_samples"]
5656
pythonpath = ["."]
57+
markers = ["used_for_testing: used to verify config in mark attribute tests"]

0 commit comments

Comments
 (0)