Skip to content

Commit 8b4216a

Browse files
authored
Merge pull request #23 from George-Ogden/mock-patch
Typed Patching
2 parents 3bc1551 + 78d7182 commit 8b4216a

31 files changed

+1631
-121
lines changed

.github/workflows/test.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ jobs:
1717
uses: George-Ogden/actions/.github/workflows/python-test.yaml@v3.1.3
1818
with:
1919
python-versions: "['3.13']"
20-
timeout-minutes: 10
20+
timeout-minutes: 15
2121
pytest-flags: -vv mypy_pytest_plugin -n auto --shard-id=${{ matrix.shard }} --num-shards=${{ matrix.total-shards }}
2222

2323
integration_tests:
2424
uses: George-Ogden/actions/.github/workflows/python-test.yaml@v3.1.3
2525
with:
2626
python-versions: "['3.13']"
27-
timeout-minutes: 5
27+
timeout-minutes: 10
2828
pytest-flags: -vv tests

mypy_pytest_plugin/argmapper.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from collections.abc import Collection
2+
import functools
3+
from typing import Final
4+
5+
from mypy.argmap import map_formals_to_actuals
6+
from mypy.checker import TypeChecker
7+
from mypy.nodes import ArgKind, CallExpr, Expression
8+
from mypy.typeops import bind_self
9+
from mypy.types import CallableType, FunctionLike, Instance, Overloaded, Type
10+
11+
type ArgMap = dict[str, Expression]
12+
13+
14+
class ArgMapper:
15+
ACCEPTED_ARG_KINDS: Final[Collection[ArgKind]] = (ArgKind.ARG_POS, ArgKind.ARG_NAMED)
16+
17+
@classmethod
18+
def named_arg_mapping(cls, call: CallExpr, checker: TypeChecker) -> ArgMap:
19+
callee_type = checker.lookup_type(call.callee)
20+
return cls._named_arg_type_mapping(call, callee_type, checker)
21+
22+
@classmethod
23+
def _named_arg_type_mapping(
24+
cls, call: CallExpr, callee_type: Type, checker: TypeChecker
25+
) -> ArgMap:
26+
if isinstance(callee_type, CallableType):
27+
return cls._named_arg_callable_mapping(call, callee_type, checker)
28+
if isinstance(callee_type, Overloaded):
29+
return cls._named_arg_overloaded_mapping(call, callee_type, checker)
30+
if (
31+
isinstance(callee_type, Instance)
32+
and (call_node := callee_type.type.names.get("__call__")) is not None
33+
and call_node.type is not None
34+
):
35+
type_ = call_node.type
36+
if isinstance(type_, FunctionLike):
37+
type_ = bind_self(type_, callee_type)
38+
return cls._named_arg_type_mapping(call, type_, checker)
39+
return {}
40+
41+
@classmethod
42+
def _named_arg_callable_mapping(
43+
cls, call: CallExpr, callee_type: CallableType, checker: TypeChecker
44+
) -> ArgMap:
45+
mapping = map_formals_to_actuals(
46+
actual_kinds=call.arg_kinds,
47+
actual_names=call.arg_names,
48+
formal_kinds=callee_type.arg_kinds,
49+
formal_names=callee_type.arg_names,
50+
actual_arg_type=lambda i: call.args[i].accept(checker.expr_checker),
51+
)
52+
53+
return {
54+
arg_name: call.args[actual_idx]
55+
for actual_idx, formal_idxs in enumerate(mapping)
56+
if len(formal_idxs) == 1
57+
and callee_type.arg_kinds[formal_idx := formal_idxs[0]] in cls.ACCEPTED_ARG_KINDS
58+
and formal_idx < len(callee_type.arg_names)
59+
and (arg_name := callee_type.arg_names[formal_idx]) is not None
60+
}
61+
62+
@classmethod
63+
def _named_arg_overloaded_mapping(
64+
cls, call: CallExpr, callee_type: Overloaded, checker: TypeChecker
65+
) -> ArgMap:
66+
return functools.reduce(
67+
cls._merge_mappings,
68+
(
69+
cls._named_arg_callable_mapping(call, callable_type, checker)
70+
for callable_type in callee_type.items
71+
),
72+
)
73+
74+
@classmethod
75+
def _merge_mappings(cls, this: ArgMap, that: ArgMap) -> ArgMap:
76+
return {key: expr for key, expr in this.items() if that.get(key, None) is expr}
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
from typing import Any, cast
2+
3+
from mypy.nodes import CallExpr, Expression
4+
5+
from .argmapper import ArgMapper
6+
from .test_utils import dump_expr, parse
7+
8+
9+
def _named_arg_mapping_test_body(defs: str, expected_keys: list[str]) -> None:
10+
parse_result = parse(defs)
11+
parse_result.accept_all()
12+
13+
call = parse_result.defs["call"]
14+
assert isinstance(call, CallExpr)
15+
16+
raw_arg_map = ArgMapper.named_arg_mapping(call, parse_result.checker)
17+
18+
def dump_arg_map(
19+
arg_map: dict[str, Expression],
20+
) -> dict[str, tuple[type, dict[str, Any]]]:
21+
return {key: (dump_expr(expr)) for key, expr in arg_map.items()}
22+
23+
assert dump_arg_map(raw_arg_map) == dump_arg_map(
24+
{key: cast(Expression, parse_result.defs[key]) for key in expected_keys}
25+
)
26+
27+
28+
def test_named_arg_mapping_no_args() -> None:
29+
_named_arg_mapping_test_body(
30+
"""
31+
def main() -> int:
32+
return 0
33+
34+
call = main()
35+
""",
36+
[],
37+
)
38+
39+
40+
def test_named_arg_mapping_no_named_args() -> None:
41+
_named_arg_mapping_test_body(
42+
"""
43+
def main(x: int, y: str, /, *args: bool) -> int:
44+
return 0
45+
46+
call = main(3, "2", True, False)
47+
""",
48+
[],
49+
)
50+
51+
52+
def test_named_arg_mapping_named_args_only() -> None:
53+
_named_arg_mapping_test_body(
54+
"""
55+
def main(x: int, y: bool, *, z: str) -> int:
56+
return 0
57+
58+
call = main(0, z="1", y=False)
59+
x = 0
60+
y = False
61+
z = "1"
62+
""",
63+
["x", "y", "z"],
64+
)
65+
66+
67+
def test_named_arg_mapping_arg_mix() -> None:
68+
_named_arg_mapping_test_body(
69+
"""
70+
from typing import Any
71+
72+
def main(a: int, /, b: bool, *args: Any, c: float, **kwargs: dict) -> int:
73+
return 0
74+
75+
call = main(0, True, 2j, c=3.0, d={}, e=dict())
76+
b = True
77+
c = 3.0
78+
""",
79+
["b", "c"],
80+
)
81+
82+
83+
def test_named_arg_mapping_simple_overload() -> None:
84+
_named_arg_mapping_test_body(
85+
"""
86+
from typing import overload
87+
88+
@overload
89+
def foo(x: int) -> int:
90+
...
91+
92+
@overload
93+
def foo(x: str) -> str:
94+
...
95+
96+
def foo(x: str | int) -> str | int:
97+
return x
98+
99+
call = foo(x="str")
100+
x = "str"
101+
""",
102+
["x"],
103+
)
104+
105+
106+
def test_named_arg_mapping_complex_overload() -> None:
107+
_named_arg_mapping_test_body(
108+
"""
109+
from typing import overload
110+
111+
@overload
112+
def foo(x: int, y: str, *, z: bool) -> int:
113+
...
114+
115+
@overload
116+
def foo(x: int, y: str) -> int:
117+
...
118+
119+
def foo(x: int, y: str, **kwargs: bool) -> int:
120+
return 0
121+
122+
call = foo(y="0", z=False, x=2)
123+
x = 2
124+
y = "0"
125+
""",
126+
["x", "y"],
127+
)
128+
129+
130+
def test_named_arg_mapping_varargs_varkwargs_overload() -> None:
131+
_named_arg_mapping_test_body(
132+
"""
133+
from typing import overload
134+
135+
@overload
136+
def foo(x: int, y: str, **kwargs: bool) -> int:
137+
...
138+
139+
@overload
140+
def foo(x: int, y: str, **kwargs: int) -> int:
141+
...
142+
143+
def foo(x: int, y: str, **kwargs: bool | int) -> int:
144+
return 0
145+
146+
call = foo(*(1, "2"), z=True)
147+
""",
148+
[],
149+
)
150+
151+
152+
def test_named_arg_mapping_instance_method() -> None:
153+
_named_arg_mapping_test_body(
154+
"""
155+
from typing import overload
156+
157+
class Foo:
158+
def bar(self, x: str, y: str, *args: str) -> str:
159+
return x + y
160+
161+
foo = Foo()
162+
call = foo.bar("a", "b", "c")
163+
x = "a"
164+
y = "b"
165+
""",
166+
["x", "y"],
167+
)
168+
169+
170+
def test_named_arg_mapping_call_method() -> None:
171+
_named_arg_mapping_test_body(
172+
"""
173+
from typing import overload
174+
175+
class Foo:
176+
def __call__(self, x: str, y: str, *args: str) -> str:
177+
return x + y
178+
179+
foo = Foo()
180+
call = foo("a", "b", "c")
181+
x = "a"
182+
y = "b"
183+
""",
184+
["x", "y"],
185+
)
Lines changed: 15 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from collections.abc import Collection, Sequence
1+
from collections.abc import Sequence
22
from dataclasses import dataclass
33
from typing import Self, TypeGuard
44

5-
from mypy.argmap import map_actuals_to_formals
65
from mypy.checker import TypeChecker
7-
from mypy.nodes import ArgKind, CallExpr, Expression
8-
from mypy.types import CallableType, Instance, Type
6+
from mypy.nodes import CallExpr, Expression
7+
from mypy.types import Instance
98

10-
from .error_codes import VARIADIC_ARGNAMES_ARGVALUES
9+
from .argmapper import ArgMapper
10+
from .error_codes import UNREADABLE_ARGNAMES_ARGVALUES
1111

1212

1313
@dataclass(frozen=True, slots=True)
@@ -37,54 +37,15 @@ def _is_parametrized_decorator_node(
3737
)
3838
return False
3939

40-
def _get_arg_type(self, i: int) -> Type:
41-
# subtract one for self
42-
i -= 1
43-
return self.call.args[i].accept(self.checker.expr_checker)
44-
4540
@property
4641
def arg_names_and_arg_values(self) -> tuple[Expression, Expression] | None:
47-
mapping = map_actuals_to_formals(
48-
actual_kinds=[ArgKind.ARG_POS, *self.call.arg_kinds],
49-
actual_names=[None, *self.call.arg_names],
50-
formal_kinds=self.fn_type.arg_kinds,
51-
formal_names=self.fn_type.arg_names,
52-
actual_arg_type=self._get_arg_type,
53-
)
54-
return self._check_actuals_formals_mapping(mapping)
55-
56-
@property
57-
def fn_type(self) -> CallableType:
58-
callee_type = self.call.callee.accept(self.checker.expr_checker)
59-
assert isinstance(callee_type, Instance)
60-
fn_type = callee_type.type.names["__call__"].type
61-
assert isinstance(fn_type, CallableType)
62-
return fn_type
63-
64-
def _check_actuals_formals_mapping(
65-
self, mapping: list[list[int]]
66-
) -> tuple[Expression, Expression] | None:
67-
arg_names_idx, arg_values_idx, *_ = self._clean_up_actuals_formals_mapping(mapping)
68-
if (
69-
self.call.arg_kinds[arg_values_idx] in self.accepted_arg_kinds
70-
and self.call.arg_kinds[arg_names_idx] in self.accepted_arg_kinds
71-
):
72-
return self.call.args[arg_names_idx], self.call.args[arg_values_idx]
73-
self.checker.fail(
74-
"Unable to read argnames and argvalues in a variadic argument.",
75-
context=self.call,
76-
code=VARIADIC_ARGNAMES_ARGVALUES,
77-
)
78-
return None
79-
80-
def _clean_up_actuals_formals_mapping(
81-
self, mapping: list[list[int]]
82-
) -> tuple[int, int, list[list[int]]]:
83-
[_, [arg_names_idx], [arg_values_idx], *extras] = mapping
84-
arg_values_idx -= 1
85-
arg_names_idx -= 1
86-
return arg_names_idx, arg_values_idx, extras
87-
88-
@property
89-
def accepted_arg_kinds(self) -> Collection[ArgKind]:
90-
return (ArgKind.ARG_POS, ArgKind.ARG_NAMED)
42+
name_mapping = ArgMapper.named_arg_mapping(self.call, self.checker)
43+
try:
44+
return name_mapping["argnames"], name_mapping["argvalues"]
45+
except KeyError:
46+
self.checker.fail(
47+
"Unable to read argnames and argvalues. Use positional or keyword arguments.",
48+
context=self.call,
49+
code=UNREADABLE_ARGNAMES_ARGVALUES,
50+
)
51+
return None

mypy_pytest_plugin/error_codes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@
3939
"missing-argname", "Argument not included in Pytest parametrization.", category="Pytest"
4040
)
4141

42-
VARIADIC_ARGNAMES_ARGVALUES: Final[ErrorCode] = ErrorCode(
43-
"variadic-argnames-argvalues",
44-
"Unable to parse variadic argnames or argvalues.",
42+
UNREADABLE_ARGNAMES_ARGVALUES: Final[ErrorCode] = ErrorCode(
43+
"unreadable-argnames-argvalues",
44+
"Unable to read argnames or argvalues.",
4545
category="Pytest",
4646
)
4747

mypy_pytest_plugin/excluded_test_checker_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ def test_6() -> None: ...
4141
"""
4242
)
4343

44-
for statement in parse_result.raw_defs:
45-
statement.accept(parse_result.checker)
44+
parse_result.accept_all()
4645
ignored_tests = ExcludedTestChecker.ignored_test_names(
4746
parse_result.raw_defs, parse_result.checker
4847
)

mypy_pytest_plugin/fixture_manager_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,7 @@ def _fixture_manager_resolve_requests_and_fixtures_test_body(
9494
checker = parse_result.checkers[last_module_name]
9595
fixture_def = parse_result.defs[fullname]
9696
assert isinstance(fixture_def, FuncDef)
97-
for def_ in parse_result.raw_defs:
98-
def_.accept(checker)
97+
parse_result.checker_accept_all(checker)
9998

10099
start = TestArgument.from_fn_def(fixture_def, checker=checker, source="test")
101100
assert start is not None

0 commit comments

Comments
 (0)