diff --git a/CHANGES/11876.misc.rst b/CHANGES/11876.misc.rst new file mode 100644 index 00000000000..ab5dc9cfa7e --- /dev/null +++ b/CHANGES/11876.misc.rst @@ -0,0 +1 @@ +Refactored tests to use ``create_autospec()`` for more robust mocking -- by :user:`soheil-star01`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 7f81d3e5dd6..3c0c4cf71fb 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -337,6 +337,7 @@ Serhiy Storchaka Shubh Agarwal Simon Kennedy Sin-Woo Bang +Soheil Dolatabadi Stanislas Plum Stanislav Prokop Stefan Tjarks diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 74670cbc9f7..662d9261e92 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -981,7 +981,10 @@ async def test_content_encoding( # type: ignore[misc] "post", URL("http://python.org/"), data="foo", compress="deflate", loop=loop ) with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: - m_writer.return_value.write_headers = mock.AsyncMock() + mock_writer_instance = mock.create_autospec( + StreamWriter, instance=True, spec_set=True + ) + m_writer.return_value = mock_writer_instance resp = await req._send(conn) assert req.headers["TRANSFER-ENCODING"] == "chunked" assert req.headers["CONTENT-ENCODING"] == "deflate" @@ -1019,7 +1022,10 @@ async def test_content_encoding_header( # type: ignore[misc] loop=loop, ) with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: - m_writer.return_value.write_headers = mock.AsyncMock() + mock_writer_instance = mock.create_autospec( + StreamWriter, instance=True, spec_set=True + ) + m_writer.return_value = mock_writer_instance resp = await req._send(conn) assert not m_writer.return_value.enable_compression.called @@ -1108,8 +1114,10 @@ async def test_chunked_explicit( "post", URL("http://python.org/"), chunked=True, loop=loop ) with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: - m_writer.return_value.write_headers = mock.AsyncMock() - m_writer.return_value.write_eof = mock.AsyncMock() + mock_writer_instance = mock.create_autospec( + StreamWriter, instance=True, spec_set=True + ) + m_writer.return_value = mock_writer_instance resp = await req._send(conn) assert "chunked" == req.headers["TRANSFER-ENCODING"] @@ -1977,8 +1985,7 @@ async def test_update_body_closes_previous_payload( req = make_client_request("POST", URL("http://python.org/")) # Create a mock payload that tracks if it was closed - mock_payload = mock.Mock(spec=payload.Payload) - mock_payload.close = mock.AsyncMock() + mock_payload = mock.create_autospec(payload.Payload, spec_set=True, instance=True) # Set initial payload req._body = mock_payload @@ -2101,7 +2108,7 @@ async def test_expect100_with_body_becomes_empty( ) -> None: """Test that write_bytes handles body becoming empty after expect100 handling.""" # Create a mock writer and connection - mock_writer = mock.AsyncMock() + mock_writer = mock.create_autospec(StreamWriter, instance=True, spec_set=True) mock_conn = mock.Mock() # Create a request diff --git a/tests/test_client_response.py b/tests/test_client_response.py index e9f7da127ee..9d7025efaca 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -18,6 +18,7 @@ from aiohttp.connector import Connection from aiohttp.helpers import TimerNoop from aiohttp.multipart import BadContentDispositionHeader +from aiohttp.tracing import Trace class WriterMock(mock.AsyncMock): @@ -1262,8 +1263,7 @@ def test_redirect_history_in_exception() -> None: async def test_response_read_triggers_callback( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: - trace = mock.Mock() - trace.send_response_chunk_received = mock.AsyncMock() + trace = mock.create_autospec(Trace, instance=True, spec_set=True) response_method = "get" response_url = URL("http://def-cl-resp.org") response_body = b"This is response" diff --git a/tests/test_client_session.py b/tests/test_client_session.py index e9106c3443d..df85456d2c1 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -8,6 +8,7 @@ from collections import deque from collections.abc import Awaitable, Callable, Iterator from http.cookies import BaseCookie, SimpleCookie +from types import SimpleNamespace from typing import Any, NoReturn, TypedDict, cast from unittest import mock from uuid import uuid4 @@ -25,9 +26,19 @@ from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector from aiohttp.cookiejar import CookieJar from aiohttp.http import RawResponseMessage +from aiohttp.payload import Payload from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.test_utils import TestServer -from aiohttp.tracing import Trace +from aiohttp.tracing import ( + Trace, + TraceRequestChunkSentParams, + TraceRequestEndParams, + TraceRequestExceptionParams, + TraceRequestHeadersSentParams, + TraceRequestRedirectParams, + TraceRequestStartParams, + TraceResponseChunkReceivedParams, +) class _Params(TypedDict): @@ -557,9 +568,8 @@ async def test_reraise_os_error( err = OSError(1, "permission error") req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req_factory = mock.Mock(return_value=req) - req._send = mock.AsyncMock(side_effect=err) - req._body = mock.Mock() - req._body.close = mock.AsyncMock() + req._send.side_effect = err + req._body = mock.create_autospec(Payload, spec_set=True, instance=True) session = await create_session(request_class=req_factory) async def create_connection( @@ -589,9 +599,8 @@ class UnexpectedException(BaseException): err = UnexpectedException("permission error") req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req_factory = mock.Mock(return_value=req) - req._send = mock.AsyncMock(side_effect=err) - req._body = mock.Mock() - req._body.close = mock.AsyncMock() + req._send.side_effect = err + req._body = mock.create_autospec(Payload, spec_set=True, instance=True) session = await create_session(request_class=req_factory) connections = [] @@ -637,7 +646,7 @@ async def test_ws_connect_allowed_protocols( # type: ignore[misc] ws_key: str, key_data: bytes, ) -> None: - resp = mock.create_autospec(aiohttp.ClientResponse) + resp = mock.create_autospec(aiohttp.ClientResponse, spec_set=True, instance=True) resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", @@ -646,7 +655,6 @@ async def test_ws_connect_allowed_protocols( # type: ignore[misc] } resp.url = URL(f"{protocol}://example") resp.cookies = SimpleCookie() - resp.start = mock.AsyncMock() req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req._body = None # No body for WebSocket upgrade requests @@ -700,7 +708,7 @@ async def test_ws_connect_unix_socket_allowed_protocols( # type: ignore[misc] ws_key: str, key_data: bytes, ) -> None: - resp = mock.create_autospec(aiohttp.ClientResponse) + resp = mock.create_autospec(aiohttp.ClientResponse, spec_set=True, instance=True) resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", @@ -709,7 +717,6 @@ async def test_ws_connect_unix_socket_allowed_protocols( # type: ignore[misc] } resp.url = URL(f"{protocol}://example") resp.cookies = SimpleCookie() - resp.start = mock.AsyncMock() req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req._body = None # No body for WebSocket upgrade requests @@ -923,15 +930,41 @@ async def test_request_tracing( async def handler(request: web.Request) -> web.Response: return web.json_response({"ok": True}) + # Define callback signatures + async def on_request_start_callback( + session: ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceRequestStartParams, + ) -> None: + pass + + async def on_request_end_callback( + session: ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceRequestEndParams, + ) -> None: + pass + + async def on_request_redirect_callback( + session: ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceRequestRedirectParams, + ) -> None: + pass + app = web.Application() app.router.add_post("/", handler) trace_config_ctx = mock.Mock() body = "This is request body" gathered_req_headers: CIMultiDict[str] = CIMultiDict() - on_request_start = mock.AsyncMock() - on_request_redirect = mock.AsyncMock() - on_request_end = mock.AsyncMock() + + # Create mocks with signatures(above) + on_request_start = mock.create_autospec(on_request_start_callback, spec_set=True) + on_request_end = mock.create_autospec(on_request_end_callback, spec_set=True) + on_request_redirect = mock.create_autospec( + on_request_redirect_callback, spec_set=True + ) with io.BytesIO() as gathered_req_body, io.BytesIO() as gathered_res_body: @@ -1006,12 +1039,78 @@ async def root_handler(request: web.Request) -> web.Response: async def redirect_handler(request: web.Request) -> NoReturn: raise web.HTTPFound("/") + # Define callback signatures + async def on_request_start_callback( + session: ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceRequestStartParams, + ) -> None: + pass + + async def on_request_end_callback( + session: ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceRequestEndParams, + ) -> None: + pass + + async def on_request_redirect_callback( + session: ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceRequestRedirectParams, + ) -> None: + pass + + async def on_request_exception_callback( + session: ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceRequestExceptionParams, + ) -> None: + pass + + async def on_request_chunk_sent_callback( + session: ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceRequestChunkSentParams, + ) -> None: + pass + + async def on_response_chunk_received_callback( + session: ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceResponseChunkReceivedParams, + ) -> None: + pass + + async def on_request_headers_sent_callback( + session: ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceRequestHeadersSentParams, + ) -> None: + pass + app = web.Application() app.router.add_get("/", root_handler) app.router.add_get("/redirect", redirect_handler) - mocks = [mock.AsyncMock() for _ in range(7)] - ( + on_request_start = mock.create_autospec(on_request_start_callback, spec_set=True) + on_request_redirect = mock.create_autospec( + on_request_redirect_callback, spec_set=True + ) + on_request_end = mock.create_autospec(on_request_end_callback, spec_set=True) + on_request_exception = mock.create_autospec( + on_request_exception_callback, spec_set=True + ) + on_request_chunk_sent = mock.create_autospec( + on_request_chunk_sent_callback, spec_set=True + ) + on_response_chunk_received = mock.create_autospec( + on_response_chunk_received_callback, spec_set=True + ) + on_request_headers_sent = mock.create_autospec( + on_request_headers_sent_callback, spec_set=True + ) + mocks = [ on_request_start, on_request_redirect, on_request_end, @@ -1019,7 +1118,7 @@ async def redirect_handler(request: web.Request) -> NoReturn: on_request_chunk_sent, on_response_chunk_received, on_request_headers_sent, - ) = mocks + ] trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=mock.Mock()) diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 9e1d8457586..655240d3eaa 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -15,6 +15,7 @@ client, hdrs, ) +from aiohttp._websocket.writer import WebSocketWriter as RealWebSocketWriter from aiohttp.http import WS_KEY from aiohttp.http_websocket import WSMessageClose from aiohttp.streams import EofStream @@ -380,9 +381,10 @@ async def test_close( m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(mresp) - writer = mock.Mock() + writer = mock.create_autospec( + RealWebSocketWriter, instance=True, spec_set=True + ) WebSocketWriter.return_value = writer - writer.close = mock.AsyncMock() session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") @@ -489,9 +491,10 @@ async def test_close_exc( m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(mresp) - writer = mock.Mock() + writer = mock.create_autospec( + RealWebSocketWriter, instance=True, spec_set=True + ) WebSocketWriter.return_value = writer - writer.close = mock.AsyncMock() session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") @@ -625,9 +628,10 @@ async def test_reader_read_exception( m_req.return_value = loop.create_future() m_req.return_value.set_result(hresp) - writer = mock.Mock() + writer = mock.create_autospec( + RealWebSocketWriter, instance=True, spec_set=True + ) WebSocketWriter.return_value = writer - writer.close = mock.AsyncMock() session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") @@ -778,29 +782,34 @@ async def test_ws_connect_deflate_per_message( m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(mresp) - writer = WebSocketWriter.return_value = mock.Mock() - send_frame = writer.send_frame = mock.AsyncMock() + writer = mock.create_autospec( + RealWebSocketWriter, instance=True, spec_set=True + ) + + WebSocketWriter.return_value = writer session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") await resp.send_str("string", compress=-1) - send_frame.assert_called_with( + writer.send_frame.assert_called_with( b"string", aiohttp.WSMsgType.TEXT, compress=-1 ) await resp.send_bytes(b"bytes", compress=15) - send_frame.assert_called_with( + writer.send_frame.assert_called_with( b"bytes", aiohttp.WSMsgType.BINARY, compress=15 ) await resp.send_json([{}], compress=-9) - send_frame.assert_called_with( + writer.send_frame.assert_called_with( b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9 ) await resp.send_frame(b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9) - send_frame.assert_called_with(b"[{}]", aiohttp.WSMsgType.TEXT, -9) + writer.send_frame.assert_called_with( + b"[{}]", aiohttp.WSMsgType.TEXT, -9 + ) await session.close() diff --git a/tests/test_connector.py b/tests/test_connector.py index ae5e2e068b0..f73862e4d4c 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -30,7 +30,7 @@ hdrs, web, ) -from aiohttp.abc import ResolveResult +from aiohttp.abc import AbstractResolver, ResolveResult from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientRequestArgs, ConnectionKey from aiohttp.connector import ( @@ -43,6 +43,7 @@ _DNSCacheTable, ) from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer +from aiohttp.resolver import AsyncResolver from aiohttp.test_utils import unused_port from aiohttp.tracing import Trace @@ -1311,9 +1312,14 @@ async def test_tcp_connector_dns_cache_not_expired( loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + + mock_default_resolver = mock.create_autospec( + AsyncResolver, instance=True, spec_set=True + ) + mock_default_resolver.resolve.return_value = await dns_response() + m_resolver.return_value = mock_default_resolver + conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) - m_resolver().resolve.return_value = dns_response() - m_resolver().close = mock.AsyncMock() await conn._resolve_host("localhost", 8080) await conn._resolve_host("localhost", 8080) m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0) @@ -1325,12 +1331,18 @@ async def test_tcp_connector_dns_cache_forever( loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + mock_default_resolver = mock.create_autospec( + AsyncResolver, instance=True, spec_set=True + ) + mock_default_resolver.resolve.return_value = await dns_response() + m_resolver.return_value = mock_default_resolver + conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) - m_resolver().resolve.return_value = dns_response() - m_resolver().close = mock.AsyncMock() await conn._resolve_host("localhost", 8080) await conn._resolve_host("localhost", 8080) - m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0) + mock_default_resolver.resolve.assert_called_once_with( + "localhost", 8080, family=0 + ) await conn.close() @@ -1339,12 +1351,21 @@ async def test_tcp_connector_use_dns_cache_disabled( loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + + mock_default_resolver = mock.create_autospec( + AsyncResolver, instance=True, spec_set=True + ) + mock_default_resolver.resolve.side_effect = [ + await dns_response(), + await dns_response(), + ] + m_resolver.return_value = mock_default_resolver + conn = aiohttp.TCPConnector(use_dns_cache=False) - m_resolver().resolve.side_effect = [dns_response(), dns_response()] - m_resolver().close = mock.AsyncMock() + await conn._resolve_host("localhost", 8080) await conn._resolve_host("localhost", 8080) - m_resolver().resolve.assert_has_calls( + mock_default_resolver.resolve.assert_has_calls( [ mock.call("localhost", 8080, family=0), mock.call("localhost", 8080, family=0), @@ -1358,14 +1379,25 @@ async def test_tcp_connector_dns_throttle_requests( loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + mock_default_resolver = mock.create_autospec( + AbstractResolver, instance=True, spec_set=True + ) + + async def mock_resolve(*_args: object, **_kwargs: object) -> list[str]: + return await dns_response() + + mock_default_resolver.resolve.side_effect = mock_resolve + m_resolver.return_value = mock_default_resolver + conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) - m_resolver().resolve.return_value = dns_response() - m_resolver().close = mock.AsyncMock() + t = loop.create_task(conn._resolve_host("localhost", 8080)) t2 = loop.create_task(conn._resolve_host("localhost", 8080)) await asyncio.sleep(0) await asyncio.sleep(0) - m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0) + mock_default_resolver.resolve.assert_called_once_with( + "localhost", 8080, family=0 + ) t.cancel() t2.cancel() with pytest.raises(asyncio.CancelledError): @@ -1378,10 +1410,16 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread( loop: asyncio.AbstractEventLoop, ) -> None: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: - conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) e = Exception() - m_resolver().resolve.side_effect = e - m_resolver().close = mock.AsyncMock() + + mock_resolver_instance = mock.create_autospec( + AbstractResolver, instance=True, spec_set=True + ) + mock_resolver_instance.resolve.side_effect = e + m_resolver.return_value = mock_resolver_instance + + conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) + r1 = loop.create_task(conn._resolve_host("localhost", 8080)) r2 = loop.create_task(conn._resolve_host("localhost", 8080)) await asyncio.sleep(0) @@ -1398,9 +1436,16 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close( loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + + async def mock_resolve(*_args: object, **_kwargs: object) -> list[str]: + return await dns_response() + + mock_default_resolver = mock.create_autospec( + AbstractResolver, instance=True, spec_set=True + ) + mock_default_resolver.resolve.side_effect = mock_resolve + m_resolver.return_value = mock_default_resolver conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) - m_resolver().resolve.return_value = dns_response() - m_resolver().close = mock.AsyncMock() t = loop.create_task(conn._resolve_host("localhost", 8080)) f = loop.create_task(conn._resolve_host("localhost", 8080)) @@ -1605,7 +1650,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests( async def test_tcp_connector_close_resolver() -> None: - m_resolver = mock.AsyncMock() + m_resolver = mock.create_autospec(AbstractResolver, instance=True, spec_set=True) with mock.patch("aiohttp.connector.DefaultResolver", return_value=m_resolver): conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) await conn.close() @@ -4368,13 +4413,18 @@ async def send_dns_cache_miss(self, *args: object, **kwargs: object) -> None: if request_count <= 1: raise Exception("first attempt") - async def resolve_response() -> list[ResolveResult]: + async def resolve_response( + *_args: object, **_kwargs: object + ) -> list[ResolveResult]: await asyncio.sleep(0) return [token] with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: - m_resolver().resolve.return_value = resolve_response() - m_resolver().close = mock.AsyncMock() + mock_default_resolver = mock.create_autospec( + AsyncResolver, instance=True, spec_set=True + ) + mock_default_resolver.resolve.side_effect = resolve_response + m_resolver.return_value = mock_default_resolver connector = TCPConnector() traces = [DummyTracer()] diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index e6b0bac97a1..77ef001fa02 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -876,7 +876,12 @@ async def test_write_calls_callback( transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: - on_chunk_sent = mock.AsyncMock() + + async def on_chunk_sent(chunk: bytes) -> None: + # define callback signature + pass + + on_chunk_sent = mock.create_autospec(on_chunk_sent, spec_set=True) msg = http.StreamWriter(protocol, loop, on_chunk_sent=on_chunk_sent) chunk = b"1" await msg.write(chunk) @@ -889,7 +894,13 @@ async def test_write_eof_calls_callback( transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: - on_chunk_sent = mock.AsyncMock() + + async def on_chunk_sent(chunk: bytes) -> None: + # define callback signature + pass + + on_chunk_sent = mock.create_autospec(on_chunk_sent, spec_set=True) + msg = http.StreamWriter(protocol, loop, on_chunk_sent=on_chunk_sent) chunk = b"1" await msg.write_eof(chunk=chunk) diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 32975f562d1..bc729506e0d 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -1262,9 +1262,7 @@ async def test_send_set_cookie_header( async def test_consecutive_write_eof() -> None: - writer = mock.Mock() - writer.write_eof = mock.AsyncMock() - writer.write_headers = mock.AsyncMock() + writer = mock.create_autospec(AbstractStreamWriter, spec_set=True, instance=True) req = make_request("GET", "/", writer=writer) data = b"data" resp = web.Response(body=data) @@ -1515,10 +1513,7 @@ async def test_passing_cimultidict_to_web_response_not_mutated( async def test_stream_response_sends_headers_immediately() -> None: """Test that StreamResponse sends headers immediately.""" - writer = mock.create_autospec(StreamWriter, spec_set=True) - writer.write_headers = mock.AsyncMock() - writer.send_headers = mock.Mock() - writer.write_eof = mock.AsyncMock() + writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) req = make_request("GET", "/", writer=writer) resp = web.StreamResponse() @@ -1535,10 +1530,7 @@ async def test_stream_response_sends_headers_immediately() -> None: async def test_response_buffers_headers() -> None: """Test that Response buffers headers for packet coalescing.""" - writer = mock.create_autospec(StreamWriter, spec_set=True) - writer.write_headers = mock.AsyncMock() - writer.send_headers = mock.Mock() - writer.write_eof = mock.AsyncMock() + writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) req = make_request("GET", "/", writer=writer) resp = web.Response(body=b"hello") diff --git a/tests/test_web_sendfile.py b/tests/test_web_sendfile.py index 81308af4d54..4a3545b35e9 100644 --- a/tests/test_web_sendfile.py +++ b/tests/test_web_sendfile.py @@ -139,10 +139,7 @@ def test_status_controlled_by_user(loop: asyncio.AbstractEventLoop) -> None: async def test_file_response_sends_headers_immediately() -> None: """Test that FileResponse sends headers immediately (inherits from StreamResponse).""" - writer = mock.create_autospec(StreamWriter, spec_set=True) - writer.write_headers = mock.AsyncMock() - writer.send_headers = mock.Mock() - writer.write_eof = mock.AsyncMock() + writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) request = make_mocked_request("GET", "http://python.org/logo.png", writer=writer) diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index 14032f42e83..a333198b108 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -16,8 +16,7 @@ @pytest.fixture def protocol() -> mock.Mock: - ret = mock.Mock() - ret._drain_helper = mock.AsyncMock() + ret = mock.create_autospec(BaseProtocol, spec_set=True, instance=True) return ret