Skip to content

Commit a1ae191

Browse files
Fix unittest.mock.patch and unittest.mock.patch.object when new_callable is not None (#14358)
1 parent 3f0dce5 commit a1ae191

File tree

2 files changed

+89
-9
lines changed

2 files changed

+89
-9
lines changed

stdlib/@tests/test_cases/check_unittest.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from datetime import datetime, timedelta
66
from decimal import Decimal
77
from fractions import Fraction
8-
from typing import TypedDict
8+
from typing import TypedDict, Union
99
from typing_extensions import assert_type
10-
from unittest.mock import MagicMock, Mock, patch
10+
from unittest.mock import AsyncMock, MagicMock, Mock, patch
1111

1212
case = unittest.TestCase()
1313

@@ -154,10 +154,17 @@ def f_explicit_new(i: int) -> str:
154154
return "asdf"
155155

156156

157+
@patch("sys.exit", new_callable=lambda: 42)
158+
def f_explicit_new_callable(i: int, new_callable_ret: int) -> str:
159+
return "asdf"
160+
161+
157162
assert_type(f_default_new(1), str)
158163
f_default_new("a") # Not an error due to ParamSpec limitations
159164
assert_type(f_explicit_new(1), str)
160165
f_explicit_new("a") # type: ignore[arg-type]
166+
assert_type(f_explicit_new_callable(1), str)
167+
f_explicit_new_callable("a") # Same as default new
161168

162169

163170
@patch("sys.exit", new=Mock())
@@ -171,3 +178,51 @@ def method() -> int:
171178

172179
assert_type(TestXYZ.attr, int)
173180
assert_type(TestXYZ.method(), int)
181+
182+
183+
with patch("sys.exit") as default_new_enter:
184+
assert_type(default_new_enter, Union[MagicMock, AsyncMock])
185+
186+
with patch("sys.exit", new=42) as explicit_new_enter:
187+
assert_type(explicit_new_enter, int)
188+
189+
with patch("sys.exit", new_callable=lambda: 42) as explicit_new_callable_enter:
190+
assert_type(explicit_new_callable_enter, int)
191+
192+
193+
###
194+
# Tests for mock.patch.object
195+
###
196+
197+
198+
@patch.object(Decimal, "exp")
199+
def obj_f_default_new(i: int, mock: MagicMock) -> str:
200+
return "asdf"
201+
202+
203+
@patch.object(Decimal, "exp", new=42)
204+
def obj_f_explicit_new(i: int) -> str:
205+
return "asdf"
206+
207+
208+
@patch.object(Decimal, "exp", new_callable=lambda: 42)
209+
def obj_f_explicit_new_callable(i: int, new_callable_ret: int) -> str:
210+
return "asdf"
211+
212+
213+
assert_type(obj_f_default_new(1), str)
214+
obj_f_default_new("a") # Not an error due to ParamSpec limitations
215+
assert_type(obj_f_explicit_new(1), str)
216+
obj_f_explicit_new("a") # type: ignore[arg-type]
217+
assert_type(obj_f_explicit_new_callable(1), str)
218+
obj_f_explicit_new_callable("a") # Same as default new
219+
220+
221+
with patch.object(Decimal, "exp") as obj_default_new_enter:
222+
assert_type(obj_default_new_enter, Union[MagicMock, AsyncMock])
223+
224+
with patch.object(Decimal, "exp", new=42) as obj_explicit_new_enter:
225+
assert_type(obj_explicit_new_enter, int)
226+
227+
with patch.object(Decimal, "exp", new_callable=lambda: 42) as obj_explicit_new_callable_enter:
228+
assert_type(obj_explicit_new_callable_enter, int)

stdlib/unittest/mock.pyi

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ class _patch(Generic[_T]):
262262
# This class does not exist at runtime, it's a hack to make this work:
263263
# @patch("foo")
264264
# def bar(..., mock: MagicMock) -> None: ...
265-
class _patch_default_new(_patch[MagicMock | AsyncMock]):
265+
class _patch_pass_arg(_patch[_T]):
266266
@overload
267267
def __call__(self, func: _TT) -> _TT: ...
268268
# Can't use the following as ParamSpec is only allowed as last parameter:
@@ -303,7 +303,7 @@ class _patcher:
303303
create: bool = ...,
304304
spec_set: Any | None = ...,
305305
autospec: Any | None = ...,
306-
new_callable: Any | None = ...,
306+
new_callable: Callable[..., Any] | None = ...,
307307
**kwargs: Any,
308308
) -> _patch[_T]: ...
309309
@overload
@@ -315,9 +315,21 @@ class _patcher:
315315
create: bool = ...,
316316
spec_set: Any | None = ...,
317317
autospec: Any | None = ...,
318-
new_callable: Any | None = ...,
318+
new_callable: Callable[..., _T],
319319
**kwargs: Any,
320-
) -> _patch_default_new: ...
320+
) -> _patch_pass_arg[_T]: ...
321+
@overload
322+
def __call__(
323+
self,
324+
target: str,
325+
*,
326+
spec: Any | None = ...,
327+
create: bool = ...,
328+
spec_set: Any | None = ...,
329+
autospec: Any | None = ...,
330+
new_callable: None = ...,
331+
**kwargs: Any,
332+
) -> _patch_pass_arg[MagicMock | AsyncMock]: ...
321333
@overload
322334
@staticmethod
323335
def object(
@@ -328,7 +340,7 @@ class _patcher:
328340
create: bool = ...,
329341
spec_set: Any | None = ...,
330342
autospec: Any | None = ...,
331-
new_callable: Any | None = ...,
343+
new_callable: Callable[..., Any] | None = ...,
332344
**kwargs: Any,
333345
) -> _patch[_T]: ...
334346
@overload
@@ -341,9 +353,22 @@ class _patcher:
341353
create: bool = ...,
342354
spec_set: Any | None = ...,
343355
autospec: Any | None = ...,
344-
new_callable: Any | None = ...,
356+
new_callable: Callable[..., _T],
357+
**kwargs: Any,
358+
) -> _patch_pass_arg[_T]: ...
359+
@overload
360+
@staticmethod
361+
def object(
362+
target: Any,
363+
attribute: str,
364+
*,
365+
spec: Any | None = ...,
366+
create: bool = ...,
367+
spec_set: Any | None = ...,
368+
autospec: Any | None = ...,
369+
new_callable: None = ...,
345370
**kwargs: Any,
346-
) -> _patch[MagicMock | AsyncMock]: ...
371+
) -> _patch_pass_arg[MagicMock | AsyncMock]: ...
347372
@staticmethod
348373
def multiple(
349374
target: Any,

0 commit comments

Comments
 (0)