Skip to content

Commit b6fec2d

Browse files
committed
fix: check writer is closing in AIOKafkaConnection.send
1 parent 01c60cd commit b6fec2d

File tree

3 files changed

+93
-6
lines changed

3 files changed

+93
-6
lines changed

aiokafka/conn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,12 @@ def send(self, request, expect_response=True):
457457
f"No connection to broker at {self._host}:{self._port}"
458458
)
459459

460+
if self._writer.is_closing():
461+
self.close(reason=CloseReason.CONNECTION_BROKEN)
462+
raise Errors.KafkaConnectionError(
463+
f"Connection at {self._host}:{self._port} is closing"
464+
)
465+
460466
correlation_id = self._next_correlation_id()
461467
header = request.build_request_header(
462468
correlation_id=correlation_id, client_id=self._client_id

requirements-ci.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ Pygments==2.15.0
1414
gssapi==1.8.3
1515
async-timeout==4.0.1
1616
cramjam==2.8.0
17+
uvloop==0.19.0

tests/test_conn.py

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import asyncio
22
import gc
3+
import socket
34
import struct
4-
from typing import Any
5+
import sys
6+
from typing import Any, AsyncIterable, Iterable, Tuple
57
from unittest import mock
68

79
import pytest
10+
import pytest_asyncio
811

912
from aiokafka.conn import AIOKafkaConnection, VersionInfo, create_conn
1013
from aiokafka.errors import (
@@ -144,7 +147,7 @@ async def test_send_to_closed(self):
144147
with self.assertRaises(KafkaConnectionError):
145148
await conn.send(request)
146149

147-
conn._writer = mock.MagicMock()
150+
conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
148151
conn._writer.write.side_effect = OSError("mocked writer is closed")
149152

150153
with self.assertRaises(KafkaConnectionError):
@@ -173,7 +176,7 @@ async def second_resp(*args: Any, **kw: Any):
173176
return resp
174177

175178
reader.readexactly.side_effect = [first_resp(), second_resp()]
176-
writer = mock.MagicMock()
179+
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
177180

178181
conn._reader = reader
179182
conn._writer = writer
@@ -208,7 +211,7 @@ async def second_resp(*args: Any, **kw: Any):
208211
return resp
209212

210213
reader.readexactly.side_effect = [first_resp(), second_resp()]
211-
writer = mock.MagicMock()
214+
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
212215

213216
conn._reader = reader
214217
conn._writer = writer
@@ -237,7 +240,7 @@ async def invoke_osserror(*a, **kw):
237240
# setup reader
238241
reader = mock.MagicMock()
239242
reader.readexactly.return_value = invoke_osserror()
240-
writer = mock.MagicMock()
243+
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
241244

242245
conn._reader = reader
243246
conn._writer = writer
@@ -394,7 +397,7 @@ async def test__send_sasl_token(self):
394397
# setup connection with mocked transport and protocol
395398
conn = AIOKafkaConnection(host="", port=9999)
396399
conn.close = mock.MagicMock()
397-
conn._writer = mock.MagicMock()
400+
conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
398401
out_buffer = []
399402
conn._writer.write = mock.Mock(side_effect=out_buffer.append)
400403
conn._reader = mock.MagicMock()
@@ -424,3 +427,80 @@ async def test__send_sasl_token(self):
424427
conn._send_sasl_token(b"Super data")
425428
# We don't need to close 2ce
426429
self.assertEqual(conn.close.call_count, 1)
430+
431+
432+
@pytest.mark.skipif(sys.platform == "win32", reason="Uvloop doesn't support Windows")
433+
class TestClosedSocket:
434+
@pytest.fixture(
435+
params=(
436+
pytest.param("asyncio", id="asyncio"),
437+
pytest.param("uvloop", id="uvloop"),
438+
),
439+
)
440+
def event_loop(
441+
self, request: pytest.FixtureRequest
442+
) -> Iterable[asyncio.AbstractEventLoop]:
443+
if request.param == "asyncio":
444+
policy = asyncio.DefaultEventLoopPolicy()
445+
elif request.param == "uvloop":
446+
import uvloop
447+
448+
policy = uvloop.EventLoopPolicy()
449+
else:
450+
raise ValueError(f"loop {request.param} is not supported")
451+
452+
loop: asyncio.AbstractEventLoop = policy.new_event_loop()
453+
yield loop
454+
loop.close()
455+
456+
@pytest.fixture()
457+
def server(self, unused_tcp_port: int) -> Iterable[Tuple[str, int, socket.socket]]:
458+
host = "localhost"
459+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
460+
sock.bind((host, unused_tcp_port))
461+
sock.listen(8)
462+
sock.setblocking(False)
463+
464+
yield host, unused_tcp_port, sock
465+
466+
sock.close()
467+
468+
@pytest_asyncio.fixture()
469+
async def conn(
470+
self, server: Tuple[str, int, socket.socket]
471+
) -> AsyncIterable[AIOKafkaConnection]:
472+
host, port, _ = server
473+
474+
conn = AIOKafkaConnection(host=host, port=port, request_timeout_ms=1000)
475+
conn._create_reader_task = mock.Mock()
476+
477+
yield conn
478+
479+
fut = conn.close()
480+
if fut:
481+
await fut
482+
483+
@pytest.mark.asyncio
484+
async def test_send_to_closed_socket(
485+
self, server: Tuple[str, int, socket.socket], conn: AIOKafkaConnection
486+
) -> None:
487+
host, port, sock = server
488+
489+
request = MetadataRequest([])
490+
491+
with pytest.raises(
492+
KafkaConnectionError,
493+
match=f"KafkaConnectionError: No connection to broker at {host}:{port}",
494+
):
495+
await conn.send(request)
496+
497+
await conn.connect()
498+
499+
sock.close()
500+
await asyncio.sleep(0.1)
501+
502+
with pytest.raises(
503+
KafkaConnectionError,
504+
match=f"KafkaConnectionError: Connection at {host}:{port} is closing",
505+
):
506+
await conn.send(request)

0 commit comments

Comments
 (0)