|
1 | 1 | import asyncio |
2 | 2 | import gc |
| 3 | +import socket |
3 | 4 | import struct |
4 | | -from typing import Any |
| 5 | +import sys |
| 6 | +from typing import Any, AsyncIterable, Iterable, Tuple |
5 | 7 | from unittest import mock |
6 | 8 |
|
7 | 9 | import pytest |
| 10 | +import pytest_asyncio |
8 | 11 |
|
9 | 12 | from aiokafka.conn import AIOKafkaConnection, VersionInfo, create_conn |
10 | 13 | from aiokafka.errors import ( |
@@ -144,7 +147,7 @@ async def test_send_to_closed(self): |
144 | 147 | with self.assertRaises(KafkaConnectionError): |
145 | 148 | await conn.send(request) |
146 | 149 |
|
147 | | - conn._writer = mock.MagicMock() |
| 150 | + conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False)) |
148 | 151 | conn._writer.write.side_effect = OSError("mocked writer is closed") |
149 | 152 |
|
150 | 153 | with self.assertRaises(KafkaConnectionError): |
@@ -173,7 +176,7 @@ async def second_resp(*args: Any, **kw: Any): |
173 | 176 | return resp |
174 | 177 |
|
175 | 178 | reader.readexactly.side_effect = [first_resp(), second_resp()] |
176 | | - writer = mock.MagicMock() |
| 179 | + writer = mock.MagicMock(is_closing=mock.Mock(return_value=False)) |
177 | 180 |
|
178 | 181 | conn._reader = reader |
179 | 182 | conn._writer = writer |
@@ -208,7 +211,7 @@ async def second_resp(*args: Any, **kw: Any): |
208 | 211 | return resp |
209 | 212 |
|
210 | 213 | reader.readexactly.side_effect = [first_resp(), second_resp()] |
211 | | - writer = mock.MagicMock() |
| 214 | + writer = mock.MagicMock(is_closing=mock.Mock(return_value=False)) |
212 | 215 |
|
213 | 216 | conn._reader = reader |
214 | 217 | conn._writer = writer |
@@ -237,7 +240,7 @@ async def invoke_osserror(*a, **kw): |
237 | 240 | # setup reader |
238 | 241 | reader = mock.MagicMock() |
239 | 242 | reader.readexactly.return_value = invoke_osserror() |
240 | | - writer = mock.MagicMock() |
| 243 | + writer = mock.MagicMock(is_closing=mock.Mock(return_value=False)) |
241 | 244 |
|
242 | 245 | conn._reader = reader |
243 | 246 | conn._writer = writer |
@@ -394,7 +397,7 @@ async def test__send_sasl_token(self): |
394 | 397 | # setup connection with mocked transport and protocol |
395 | 398 | conn = AIOKafkaConnection(host="", port=9999) |
396 | 399 | conn.close = mock.MagicMock() |
397 | | - conn._writer = mock.MagicMock() |
| 400 | + conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False)) |
398 | 401 | out_buffer = [] |
399 | 402 | conn._writer.write = mock.Mock(side_effect=out_buffer.append) |
400 | 403 | conn._reader = mock.MagicMock() |
@@ -424,3 +427,84 @@ async def test__send_sasl_token(self): |
424 | 427 | conn._send_sasl_token(b"Super data") |
425 | 428 | # We don't need to close 2ce |
426 | 429 | self.assertEqual(conn.close.call_count, 1) |
| 430 | + |
| 431 | + |
| 432 | +class TestClosedSocket: |
| 433 | + @pytest.fixture( |
| 434 | + params=( |
| 435 | + pytest.param("asyncio", id="asyncio"), |
| 436 | + pytest.param( |
| 437 | + "uvloop", |
| 438 | + marks=pytest.mark.skipif( |
| 439 | + sys.platform == "win32", |
| 440 | + reason="uvloop does not support the windows", |
| 441 | + ), |
| 442 | + id="uvloop", |
| 443 | + ), |
| 444 | + ), |
| 445 | + ) |
| 446 | + def event_loop( |
| 447 | + self, request: pytest.FixtureRequest |
| 448 | + ) -> Iterable[asyncio.AbstractEventLoop]: |
| 449 | + if request.param == "asyncio": |
| 450 | + policy = asyncio.DefaultEventLoopPolicy() |
| 451 | + elif request.param == "uvloop": |
| 452 | + import uvloop |
| 453 | + |
| 454 | + policy = uvloop.EventLoopPolicy() |
| 455 | + |
| 456 | + loop: asyncio.AbstractEventLoop = policy.new_event_loop() |
| 457 | + yield loop |
| 458 | + loop.close() |
| 459 | + |
| 460 | + @pytest.fixture() |
| 461 | + def server(self, unused_tcp_port: int) -> Iterable[Tuple[str, int, socket.socket]]: |
| 462 | + host = "localhost" |
| 463 | + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| 464 | + sock.bind((host, unused_tcp_port)) |
| 465 | + sock.listen(8) |
| 466 | + sock.setblocking(False) |
| 467 | + |
| 468 | + yield host, unused_tcp_port, sock |
| 469 | + |
| 470 | + sock.close() |
| 471 | + |
| 472 | + @pytest_asyncio.fixture() |
| 473 | + async def conn( |
| 474 | + self, server: Tuple[str, int, socket.socket] |
| 475 | + ) -> AsyncIterable[AIOKafkaConnection]: |
| 476 | + host, port, _ = server |
| 477 | + |
| 478 | + conn = AIOKafkaConnection(host=host, port=port, request_timeout_ms=1000) |
| 479 | + conn._create_reader_task = mock.Mock() |
| 480 | + |
| 481 | + yield conn |
| 482 | + |
| 483 | + fut = conn.close() |
| 484 | + if fut: |
| 485 | + await fut |
| 486 | + |
| 487 | + @pytest.mark.asyncio |
| 488 | + async def test_send_to_closed_socket( |
| 489 | + self, server: Tuple[str, int, socket.socket], conn: AIOKafkaConnection |
| 490 | + ) -> None: |
| 491 | + host, port, sock = server |
| 492 | + |
| 493 | + request = MetadataRequest([]) |
| 494 | + |
| 495 | + with pytest.raises( |
| 496 | + KafkaConnectionError, |
| 497 | + match=f"KafkaConnectionError: No connection to broker at {host}:{port}", |
| 498 | + ): |
| 499 | + await conn.send(request) |
| 500 | + |
| 501 | + await conn.connect() |
| 502 | + |
| 503 | + sock.close() |
| 504 | + await asyncio.sleep(0.1) |
| 505 | + |
| 506 | + with pytest.raises( |
| 507 | + KafkaConnectionError, |
| 508 | + match=f"KafkaConnectionError: Connection at {host}:{port} is closing", |
| 509 | + ): |
| 510 | + await conn.send(request) |
0 commit comments