Skip to content

Commit afb225f

Browse files
[PR #9757/b2c845bb backport][3.11] Avoid creating body writer task when there is no body (#9759)
Co-authored-by: J. Nick Koston <[email protected]>
1 parent 1d761e1 commit afb225f

9 files changed

+75
-24
lines changed

CHANGES/9757.misc.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improved performance of sending HTTP requests when there is no body -- by :user:`bdraco`.

aiohttp/base_protocol.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ def connected(self) -> bool:
2929
"""Return True if the connection is open."""
3030
return self.transport is not None
3131

32+
@property
33+
def writing_paused(self) -> bool:
34+
return self._paused
35+
3236
def pause_writing(self) -> None:
3337
assert not self._paused
3438
self._paused = True

aiohttp/client_reqrep.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -736,22 +736,28 @@ async def send(self, conn: "Connection") -> "ClientResponse":
736736
v = self.version
737737
status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}"
738738
await writer.write_headers(status_line, self.headers)
739-
coro = self.write_bytes(writer, conn)
740-
741739
task: Optional["asyncio.Task[None]"]
742-
if sys.version_info >= (3, 12):
743-
# Optimization for Python 3.12, try to write
744-
# bytes immediately to avoid having to schedule
745-
# the task on the event loop.
746-
task = asyncio.Task(coro, loop=self.loop, eager_start=True)
740+
if self.body or self._continue is not None or protocol.writing_paused:
741+
coro = self.write_bytes(writer, conn)
742+
if sys.version_info >= (3, 12):
743+
# Optimization for Python 3.12, try to write
744+
# bytes immediately to avoid having to schedule
745+
# the task on the event loop.
746+
task = asyncio.Task(coro, loop=self.loop, eager_start=True)
747+
else:
748+
task = self.loop.create_task(coro)
749+
if task.done():
750+
task = None
751+
else:
752+
self._writer = task
747753
else:
748-
task = self.loop.create_task(coro)
749-
750-
if task.done():
754+
# We have nothing to write because
755+
# - there is no body
756+
# - the protocol does not have writing paused
757+
# - we are not waiting for a 100-continue response
758+
protocol.start_timeout()
759+
writer.set_eof()
751760
task = None
752-
else:
753-
self._writer = task
754-
755761
response_class = self.response_class
756762
assert response_class is not None
757763
self.response = response_class(

aiohttp/http_writer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ async def write_headers(
130130
buf = _serialize_headers(status_line, headers)
131131
self._write(buf)
132132

133+
def set_eof(self) -> None:
134+
"""Indicate that the message is complete."""
135+
self._eof = True
136+
133137
async def write_eof(self, chunk: bytes = b"") -> None:
134138
if self._eof:
135139
return

tests/test_base_protocol.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@ async def test_pause_writing() -> None:
1818
loop = asyncio.get_event_loop()
1919
pr = BaseProtocol(loop)
2020
assert not pr._paused
21+
assert pr.writing_paused is False
2122
pr.pause_writing()
2223
assert pr._paused
24+
assert pr.writing_paused is True # type: ignore[unreachable]
2325

2426

2527
async def test_pause_reading_no_transport() -> None:

tests/test_benchmarks_client_request.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ class MockProtocol(asyncio.BaseProtocol):
8989
def __init__(self) -> None:
9090
self.transport = MockTransport()
9191

92+
@property
93+
def writing_paused(self) -> bool:
94+
return False
95+
9296
async def _drain_helper(self) -> None:
9397
"""Swallow drain."""
9498

tests/test_client_functional.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tarfile
1313
import time
1414
import zipfile
15-
from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Type
15+
from typing import Any, AsyncIterator, Awaitable, Callable, List, Type
1616
from unittest import mock
1717

1818
import pytest
@@ -1503,10 +1503,7 @@ async def handler(request):
15031503
assert 200 == resp.status
15041504

15051505

1506-
@pytest.mark.parametrize("data", (None, b""))
1507-
async def test_GET_DEFLATE(
1508-
aiohttp_client: AiohttpClient, data: Optional[bytes]
1509-
) -> None:
1506+
async def test_GET_DEFLATE(aiohttp_client: AiohttpClient) -> None:
15101507
async def handler(request: web.Request) -> web.Response:
15111508
return web.json_response({"ok": True})
15121509

@@ -1529,7 +1526,7 @@ async def write_bytes(
15291526
app.router.add_get("/", handler)
15301527
client = await aiohttp_client(app)
15311528

1532-
async with client.get("/", data=data, compress=True) as resp:
1529+
async with client.get("/", data=b"", compress=True) as resp:
15331530
assert resp.status == 200
15341531
content = await resp.json()
15351532
assert content == {"ok": True}
@@ -1539,6 +1536,24 @@ async def write_bytes(
15391536
write_mock.assert_not_called()
15401537

15411538

1539+
async def test_GET_DEFLATE_no_body(aiohttp_client: AiohttpClient) -> None:
1540+
async def handler(request: web.Request) -> web.Response:
1541+
return web.json_response({"ok": True})
1542+
1543+
with mock.patch.object(ClientRequest, "write_bytes") as mock_write_bytes:
1544+
app = web.Application()
1545+
app.router.add_get("/", handler)
1546+
client = await aiohttp_client(app)
1547+
1548+
async with client.get("/", data=None, compress=True) as resp:
1549+
assert resp.status == 200
1550+
content = await resp.json()
1551+
assert content == {"ok": True}
1552+
1553+
# No chunks should have been sent for an empty body.
1554+
mock_write_bytes.assert_not_called()
1555+
1556+
15421557
async def test_POST_DATA_DEFLATE(aiohttp_client: AiohttpClient) -> None:
15431558
async def handler(request: web.Request) -> web.Response:
15441559
data = await request.post()

tests/test_client_session.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ def to_url(path: str) -> URL:
836836
assert to_trace_urls(on_request_redirect) == []
837837
assert to_trace_urls(on_request_end) == [to_url("/?x=0")]
838838
assert to_trace_urls(on_request_exception) == []
839-
assert to_trace_urls(on_request_chunk_sent) == [to_url("/?x=0")]
839+
assert to_trace_urls(on_request_chunk_sent) == []
840840
assert to_trace_urls(on_response_chunk_received) == [to_url("/?x=0")]
841841
assert to_trace_urls(on_request_headers_sent) == [to_url("/?x=0")]
842842

@@ -852,10 +852,7 @@ def to_url(path: str) -> URL:
852852
assert to_trace_urls(on_request_redirect) == [to_url("/redirect?x=0")]
853853
assert to_trace_urls(on_request_end) == [to_url("/")]
854854
assert to_trace_urls(on_request_exception) == []
855-
assert to_trace_urls(on_request_chunk_sent) == [
856-
to_url("/redirect?x=0"),
857-
to_url("/"),
858-
]
855+
assert to_trace_urls(on_request_chunk_sent) == []
859856
assert to_trace_urls(on_response_chunk_received) == [to_url("/")]
860857
assert to_trace_urls(on_request_headers_sent) == [
861858
to_url("/redirect?x=0"),

tests/test_http_writer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Tests for aiohttp/http_writer.py
22
import array
3+
import asyncio
34
from unittest import mock
45

56
import pytest
67
from multidict import CIMultiDict
78

89
from aiohttp import ClientConnectionResetError, http
10+
from aiohttp.base_protocol import BaseProtocol
911
from aiohttp.test_utils import make_mocked_coro
1012

1113

@@ -275,3 +277,19 @@ async def test_write_headers_prevents_injection(protocol, transport, loop) -> No
275277
wrong_headers = CIMultiDict({"Content-Length": "256\r\nSet-Cookie: abc=123"})
276278
with pytest.raises(ValueError):
277279
await msg.write_headers(status_line, wrong_headers)
280+
281+
282+
async def test_set_eof_after_write_headers(
283+
protocol: BaseProtocol,
284+
transport: mock.Mock,
285+
loop: asyncio.AbstractEventLoop,
286+
) -> None:
287+
msg = http.StreamWriter(protocol, loop)
288+
status_line = "HTTP/1.1 200 OK"
289+
good_headers = CIMultiDict({"Set-Cookie": "abc=123"})
290+
await msg.write_headers(status_line, good_headers)
291+
assert transport.write.called
292+
transport.write.reset_mock()
293+
msg.set_eof()
294+
await msg.write_eof()
295+
assert not transport.write.called

0 commit comments

Comments
 (0)