Skip to content

Commit 46aa000

Browse files
committed
fix: check writer is closing in AIOKafkaConnection.send
1 parent 01c60cd commit 46aa000

File tree

3 files changed

+83
-6
lines changed

3 files changed

+83
-6
lines changed

aiokafka/conn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,12 @@ def send(self, request, expect_response=True):
457457
f"No connection to broker at {self._host}:{self._port}"
458458
)
459459

460+
if self._writer.is_closing():
461+
self.close(reason=CloseReason.CONNECTION_BROKEN)
462+
raise Errors.KafkaConnectionError(
463+
f"Connection at {self._host}:{self._port} is closing"
464+
)
465+
460466
correlation_id = self._next_correlation_id()
461467
header = request.build_request_header(
462468
correlation_id=correlation_id, client_id=self._client_id

requirements-ci.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ Pygments==2.15.0
1414
gssapi==1.8.3
1515
async-timeout==4.0.1
1616
cramjam==2.8.0
17+
uvloop==0.19.0

tests/test_conn.py

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import asyncio
22
import gc
3+
import socket
34
import struct
4-
from typing import Any
5+
from typing import Any, AsyncIterable, Iterable, Tuple
56
from unittest import mock
67

78
import pytest
9+
import pytest_asyncio
10+
import uvloop
811

912
from aiokafka.conn import AIOKafkaConnection, VersionInfo, create_conn
1013
from aiokafka.errors import (
@@ -144,7 +147,7 @@ async def test_send_to_closed(self):
144147
with self.assertRaises(KafkaConnectionError):
145148
await conn.send(request)
146149

147-
conn._writer = mock.MagicMock()
150+
conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
148151
conn._writer.write.side_effect = OSError("mocked writer is closed")
149152

150153
with self.assertRaises(KafkaConnectionError):
@@ -173,7 +176,7 @@ async def second_resp(*args: Any, **kw: Any):
173176
return resp
174177

175178
reader.readexactly.side_effect = [first_resp(), second_resp()]
176-
writer = mock.MagicMock()
179+
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
177180

178181
conn._reader = reader
179182
conn._writer = writer
@@ -208,7 +211,7 @@ async def second_resp(*args: Any, **kw: Any):
208211
return resp
209212

210213
reader.readexactly.side_effect = [first_resp(), second_resp()]
211-
writer = mock.MagicMock()
214+
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
212215

213216
conn._reader = reader
214217
conn._writer = writer
@@ -237,7 +240,7 @@ async def invoke_osserror(*a, **kw):
237240
# setup reader
238241
reader = mock.MagicMock()
239242
reader.readexactly.return_value = invoke_osserror()
240-
writer = mock.MagicMock()
243+
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
241244

242245
conn._reader = reader
243246
conn._writer = writer
@@ -394,7 +397,7 @@ async def test__send_sasl_token(self):
394397
# setup connection with mocked transport and protocol
395398
conn = AIOKafkaConnection(host="", port=9999)
396399
conn.close = mock.MagicMock()
397-
conn._writer = mock.MagicMock()
400+
conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
398401
out_buffer = []
399402
conn._writer.write = mock.Mock(side_effect=out_buffer.append)
400403
conn._reader = mock.MagicMock()
@@ -424,3 +427,70 @@ async def test__send_sasl_token(self):
424427
conn._send_sasl_token(b"Super data")
425428
# We don't need to close 2ce
426429
self.assertEqual(conn.close.call_count, 1)
430+
431+
432+
class TestClosedSocket:
433+
@pytest.fixture(
434+
params=(
435+
asyncio.DefaultEventLoopPolicy(),
436+
uvloop.EventLoopPolicy(),
437+
),
438+
)
439+
def event_loop(
440+
self, request: pytest.FixtureRequest
441+
) -> Iterable[asyncio.AbstractEventLoop]:
442+
loop: asyncio.AbstractEventLoop = request.param.new_event_loop()
443+
yield loop
444+
loop.close()
445+
446+
@pytest.fixture()
447+
def server(self, unused_tcp_port: int) -> Iterable[Tuple[str, int, socket.socket]]:
448+
host = "localhost"
449+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
450+
sock.bind((host, unused_tcp_port))
451+
sock.listen(8)
452+
sock.setblocking(False)
453+
454+
yield host, unused_tcp_port, sock
455+
456+
sock.close()
457+
458+
@pytest_asyncio.fixture()
459+
async def conn(
460+
self, server: Tuple[str, int, socket.socket]
461+
) -> AsyncIterable[AIOKafkaConnection]:
462+
host, port, _ = server
463+
464+
conn = AIOKafkaConnection(host=host, port=port, request_timeout_ms=1000)
465+
conn._create_reader_task = mock.Mock()
466+
467+
yield conn
468+
469+
fut = conn.close()
470+
if fut:
471+
await fut
472+
473+
@pytest.mark.asyncio
474+
async def test_send_to_closed_socket(
475+
self, server: Tuple[str, int, socket.socket], conn: AIOKafkaConnection
476+
) -> None:
477+
host, port, sock = server
478+
479+
request = MetadataRequest([])
480+
481+
with pytest.raises(
482+
KafkaConnectionError,
483+
match=f"KafkaConnectionError: No connection to broker at {host}:{port}",
484+
):
485+
await conn.send(request)
486+
487+
await conn.connect()
488+
489+
sock.close()
490+
await asyncio.sleep(0.1)
491+
492+
with pytest.raises(
493+
KafkaConnectionError,
494+
match=f"KafkaConnectionError: Connection at {host}:{port} is closing",
495+
):
496+
await conn.send(request)

0 commit comments

Comments
 (0)