Skip to content

Commit 7b390b8

Browse files
committed
fix: check writer is closing in AIOKafkaConnection.send
1 parent 01c60cd commit 7b390b8

File tree

3 files changed

+97
-6
lines changed

3 files changed

+97
-6
lines changed

aiokafka/conn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,12 @@ def send(self, request, expect_response=True):
457457
f"No connection to broker at {self._host}:{self._port}"
458458
)
459459

460+
if self._writer.is_closing():
461+
self.close(reason=CloseReason.CONNECTION_BROKEN)
462+
raise Errors.KafkaConnectionError(
463+
f"Connection at {self._host}:{self._port} is closing"
464+
)
465+
460466
correlation_id = self._next_correlation_id()
461467
header = request.build_request_header(
462468
correlation_id=correlation_id, client_id=self._client_id

requirements-ci.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ Pygments==2.15.0
1414
gssapi==1.8.3
1515
async-timeout==4.0.1
1616
cramjam==2.8.0
17+
uvloop==0.19.0

tests/test_conn.py

Lines changed: 90 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import asyncio
22
import gc
3+
import socket
34
import struct
4-
from typing import Any
5+
import sys
6+
from typing import Any, AsyncIterable, Iterable, Tuple
57
from unittest import mock
68

79
import pytest
10+
import pytest_asyncio
811

912
from aiokafka.conn import AIOKafkaConnection, VersionInfo, create_conn
1013
from aiokafka.errors import (
@@ -144,7 +147,7 @@ async def test_send_to_closed(self):
144147
with self.assertRaises(KafkaConnectionError):
145148
await conn.send(request)
146149

147-
conn._writer = mock.MagicMock()
150+
conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
148151
conn._writer.write.side_effect = OSError("mocked writer is closed")
149152

150153
with self.assertRaises(KafkaConnectionError):
@@ -173,7 +176,7 @@ async def second_resp(*args: Any, **kw: Any):
173176
return resp
174177

175178
reader.readexactly.side_effect = [first_resp(), second_resp()]
176-
writer = mock.MagicMock()
179+
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
177180

178181
conn._reader = reader
179182
conn._writer = writer
@@ -208,7 +211,7 @@ async def second_resp(*args: Any, **kw: Any):
208211
return resp
209212

210213
reader.readexactly.side_effect = [first_resp(), second_resp()]
211-
writer = mock.MagicMock()
214+
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
212215

213216
conn._reader = reader
214217
conn._writer = writer
@@ -237,7 +240,7 @@ async def invoke_osserror(*a, **kw):
237240
# setup reader
238241
reader = mock.MagicMock()
239242
reader.readexactly.return_value = invoke_osserror()
240-
writer = mock.MagicMock()
243+
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
241244

242245
conn._reader = reader
243246
conn._writer = writer
@@ -394,7 +397,7 @@ async def test__send_sasl_token(self):
394397
# setup connection with mocked transport and protocol
395398
conn = AIOKafkaConnection(host="", port=9999)
396399
conn.close = mock.MagicMock()
397-
conn._writer = mock.MagicMock()
400+
conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
398401
out_buffer = []
399402
conn._writer.write = mock.Mock(side_effect=out_buffer.append)
400403
conn._reader = mock.MagicMock()
@@ -424,3 +427,84 @@ async def test__send_sasl_token(self):
424427
conn._send_sasl_token(b"Super data")
425428
# We don't need to close 2ce
426429
self.assertEqual(conn.close.call_count, 1)
430+
431+
432+
class TestClosedSocket:
433+
@pytest.fixture(
434+
params=(
435+
pytest.param("asyncio", id="asyncio"),
436+
pytest.param(
437+
"uvloop",
438+
marks=pytest.mark.skipif(
439+
sys.platform == "win32",
440+
reason="uvloop does not support the windows",
441+
),
442+
id="uvloop",
443+
),
444+
),
445+
)
446+
def event_loop(
447+
self, request: pytest.FixtureRequest
448+
) -> Iterable[asyncio.AbstractEventLoop]:
449+
if request.param == "asyncio":
450+
policy = asyncio.DefaultEventLoopPolicy()
451+
elif request.param == "uvloop":
452+
import uvloop
453+
454+
policy = uvloop.EventLoopPolicy()
455+
456+
loop: asyncio.AbstractEventLoop = policy.new_event_loop()
457+
yield loop
458+
loop.close()
459+
460+
@pytest.fixture()
461+
def server(self, unused_tcp_port: int) -> Iterable[Tuple[str, int, socket.socket]]:
462+
host = "localhost"
463+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
464+
sock.bind((host, unused_tcp_port))
465+
sock.listen(8)
466+
sock.setblocking(False)
467+
468+
yield host, unused_tcp_port, sock
469+
470+
sock.close()
471+
472+
@pytest_asyncio.fixture()
473+
async def conn(
474+
self, server: Tuple[str, int, socket.socket]
475+
) -> AsyncIterable[AIOKafkaConnection]:
476+
host, port, _ = server
477+
478+
conn = AIOKafkaConnection(host=host, port=port, request_timeout_ms=1000)
479+
conn._create_reader_task = mock.Mock()
480+
481+
yield conn
482+
483+
fut = conn.close()
484+
if fut:
485+
await fut
486+
487+
@pytest.mark.asyncio
488+
async def test_send_to_closed_socket(
489+
self, server: Tuple[str, int, socket.socket], conn: AIOKafkaConnection
490+
) -> None:
491+
host, port, sock = server
492+
493+
request = MetadataRequest([])
494+
495+
with pytest.raises(
496+
KafkaConnectionError,
497+
match=f"KafkaConnectionError: No connection to broker at {host}:{port}",
498+
):
499+
await conn.send(request)
500+
501+
await conn.connect()
502+
503+
sock.close()
504+
await asyncio.sleep(0.1)
505+
506+
with pytest.raises(
507+
KafkaConnectionError,
508+
match=f"KafkaConnectionError: Connection at {host}:{port} is closing",
509+
):
510+
await conn.send(request)

0 commit comments

Comments
 (0)