diff --git a/CHANGES/1129.feature b/CHANGES/1129.feature new file mode 100644 index 00000000..e5ed08d8 --- /dev/null +++ b/CHANGES/1129.feature @@ -0,0 +1 @@ +Implement asynchronous context manager protocol on ``AIOKafkaAdminClient`` (PR #1129 by @PeterJCLaw) diff --git a/aiokafka/admin/client.py b/aiokafka/admin/client.py index 42e5e666..07526a96 100644 --- a/aiokafka/admin/client.py +++ b/aiokafka/admin/client.py @@ -3,9 +3,11 @@ from collections import defaultdict from collections.abc import Sequence from ssl import SSLContext +from types import TracebackType from typing import Any import async_timeout +from typing_extensions import Self from aiokafka import __version__ from aiokafka.client import AIOKafkaClient @@ -160,6 +162,18 @@ async def start(self): log.debug("AIOKafkaAdminClient started") self._started = True + async def __aenter__(self) -> Self: + await self.start() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + await self.close() + def _matching_api_version(self, operation: Sequence[type[Request]]) -> int: """Find the latest version of the protocol operation supported by both this library and the broker. diff --git a/tests/test_admin.py b/tests/test_admin.py index a81b9663..02edc755 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -25,6 +25,38 @@ async def test_metadata(self): assert metadata.topics is not None assert len(metadata.brokers) == 1 + @kafka_versions(">=0.10.0.0") + @run_until_complete + async def test_context_manager(self): + async with AIOKafkaAdminClient(bootstrap_servers=self.hosts) as admin: + assert admin._started + + # Arbitrary testing + metadata = await admin._get_cluster_metadata() + assert metadata.brokers is not None + assert metadata.topics is not None + assert len(metadata.brokers) == 1 + + assert admin._closed + + # Test error case too + class FakeError: + pass + + with pytest.raises(FakeError): + async with AIOKafkaAdminClient(bootstrap_servers=self.hosts) as admin: + assert admin._started + + # Arbitrary testing + metadata = await admin._get_cluster_metadata() + assert metadata.brokers is not None + assert metadata.topics is not None + assert len(metadata.brokers) == 1 + + raise FakeError + + assert admin._closed + @kafka_versions(">=0.10.1.0") @run_until_complete async def test_create_topics(self):