2 changes: 2 additions & 0 deletions CHANGES.rst
@@ -16,6 +16,8 @@ Bugfixes:

* Make KafkaStorageError retriable after metadata refresh like in other
implementations (pr #1115 by @omerhadari)
* Fix producer and consumer requiring an explicit close after `start` or
  `__aenter__` raises an exception. (issue #1130, pr #1131 by @calgray)


Misc:
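The fix relies on `contextlib.AsyncExitStack`: each resource brought up during `start()` registers its `close` callback on the stack, so an exception later in `start()` unwinds the stack and closes everything already started, while a successful start hands the stack to `stop()` via `pop_all()`. Below is a minimal sketch of that pattern; `Component` and `Service` are hypothetical stand-ins, not aiokafka classes.

```python
import asyncio
from contextlib import AsyncExitStack


class Component:
    """Stand-in for a client, fetcher, or coordinator with an async close()."""

    def __init__(self, name):
        self.name = name

    async def close(self):
        print(f"closed {self.name}")


class Service:
    def __init__(self):
        self._closed = True

    async def start(self):
        async with AsyncExitStack() as stack:
            self._client = Component("client")
            stack.push_async_callback(self._client.close)
            self._fetcher = Component("fetcher")
            stack.push_async_callback(self._fetcher.close)
            # If anything after a push_async_callback() raises, leaving this
            # `async with` block closes the already-registered components.
            self._exit_stack = stack.pop_all()  # success: defer cleanup to stop()
        self._closed = False

    async def stop(self):
        if self._closed:
            return
        self._closed = True
        await self._exit_stack.aclose()  # closes components in reverse start order


async def main():
    service = Service()
    await service.start()
    await service.stop()  # prints "closed fetcher" then "closed client"


asyncio.run(main())
```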
153 changes: 80 additions & 73 deletions aiokafka/consumer/consumer.py
@@ -4,6 +4,7 @@
import sys
import traceback
import warnings
from contextlib import AsyncExitStack

from aiokafka import __version__
from aiokafka.abc import ConsumerRebalanceListener
@@ -335,7 +336,6 @@ def __init__(

if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
self._closed = False

if topics:
topics = self._validate_topics(topics)
@@ -368,83 +368,94 @@ async def start(self):
self._loop is get_running_loop()
), "Please create objects with the same loop as running with"
assert self._fetcher is None, "Did you call `start` twice?"
await self._client.bootstrap()
await self._wait_topics()

if self._client.api_version < (0, 9):
raise ValueError(f"Unsupported Kafka version: {self._client.api_version}")
async with AsyncExitStack() as stack:
await self._client.bootstrap()
stack.push_async_callback(self._client.close)
await self._wait_topics()

if (
self._isolation_level == "read_committed"
and self._client.api_version < (0, 11) # fmt: skip
):
raise UnsupportedVersionError(
"`read_committed` isolation_level available only for Brokers "
"0.11 and above"
)
if self._client.api_version < (0, 9):
raise ValueError(
f"Unsupported Kafka version: {self._client.api_version}"
)

self._fetcher = Fetcher(
self._client,
self._subscription,
key_deserializer=self._key_deserializer,
value_deserializer=self._value_deserializer,
fetch_min_bytes=self._fetch_min_bytes,
fetch_max_bytes=self._fetch_max_bytes,
fetch_max_wait_ms=self._fetch_max_wait_ms,
max_partition_fetch_bytes=self._max_partition_fetch_bytes,
check_crcs=self._check_crcs,
fetcher_timeout=self._consumer_timeout,
retry_backoff_ms=self._retry_backoff_ms,
auto_offset_reset=self._auto_offset_reset,
isolation_level=self._isolation_level,
)
if (
self._isolation_level == "read_committed"
and self._client.api_version < (0, 11) # fmt: skip
):
raise UnsupportedVersionError(
"`read_committed` isolation_level available only for Brokers "
"0.11 and above"
)

if self._group_id is not None:
# using group coordinator for automatic partitions assignment
self._coordinator = GroupCoordinator(
self._fetcher = Fetcher(
self._client,
self._subscription,
group_id=self._group_id,
group_instance_id=self._group_instance_id,
heartbeat_interval_ms=self._heartbeat_interval_ms,
session_timeout_ms=self._session_timeout_ms,
key_deserializer=self._key_deserializer,
value_deserializer=self._value_deserializer,
fetch_min_bytes=self._fetch_min_bytes,
fetch_max_bytes=self._fetch_max_bytes,
fetch_max_wait_ms=self._fetch_max_wait_ms,
max_partition_fetch_bytes=self._max_partition_fetch_bytes,
check_crcs=self._check_crcs,
fetcher_timeout=self._consumer_timeout,
retry_backoff_ms=self._retry_backoff_ms,
enable_auto_commit=self._enable_auto_commit,
auto_commit_interval_ms=self._auto_commit_interval_ms,
assignors=self._partition_assignment_strategy,
exclude_internal_topics=self._exclude_internal_topics,
rebalance_timeout_ms=self._rebalance_timeout_ms,
max_poll_interval_ms=self._max_poll_interval_ms,
auto_offset_reset=self._auto_offset_reset,
isolation_level=self._isolation_level,
)
if self._subscription.subscription is not None:
if self._subscription.partitions_auto_assigned():
stack.push_async_callback(self._fetcher.close)

if self._group_id is not None:
# using group coordinator for automatic partitions assignment
self._coordinator = GroupCoordinator(
self._client,
self._subscription,
group_id=self._group_id,
group_instance_id=self._group_instance_id,
heartbeat_interval_ms=self._heartbeat_interval_ms,
session_timeout_ms=self._session_timeout_ms,
retry_backoff_ms=self._retry_backoff_ms,
enable_auto_commit=self._enable_auto_commit,
auto_commit_interval_ms=self._auto_commit_interval_ms,
assignors=self._partition_assignment_strategy,
exclude_internal_topics=self._exclude_internal_topics,
rebalance_timeout_ms=self._rebalance_timeout_ms,
max_poll_interval_ms=self._max_poll_interval_ms,
)
stack.push_async_callback(self._coordinator.close)

if self._subscription.subscription is not None:
if self._subscription.partitions_auto_assigned():
# Either we passed `topics` to constructor or `subscribe`
# was called before `start`
await self._subscription.wait_for_assignment()
else:
# `assign` was called before `start`. We did not start
# this task on that call, as coordinator was yet to be
# created
self._coordinator.start_commit_offsets_refresh_task(
self._subscription.subscription.assignment
)
else:
# Using a simple assignment coordinator for reassignment on
# metadata changes
self._coordinator = NoGroupCoordinator(
self._client,
self._subscription,
exclude_internal_topics=self._exclude_internal_topics,
)
stack.push_async_callback(self._coordinator.close)

if (
self._subscription.subscription is not None
and self._subscription.partitions_auto_assigned()
):
# Either we passed `topics` to constructor or `subscribe`
# was called before `start`
await self._subscription.wait_for_assignment()
else:
# `assign` was called before `start`. We did not start
# this task on that call, as coordinator was yet to be
# created
self._coordinator.start_commit_offsets_refresh_task(
self._subscription.subscription.assignment
)
else:
# Using a simple assignment coordinator for reassignment on
# metadata changes
self._coordinator = NoGroupCoordinator(
self._client,
self._subscription,
exclude_internal_topics=self._exclude_internal_topics,
)

if (
self._subscription.subscription is not None
and self._subscription.partitions_auto_assigned()
):
# Either we passed `topics` to constructor or `subscribe`
# was called before `start`
await self._client.force_metadata_update()
self._coordinator.assign_all_partitions(check_unknown=True)
await self._client.force_metadata_update()
self._coordinator.assign_all_partitions(check_unknown=True)
self._exit_stack = stack.pop_all()
self._closed = False

async def _wait_topics(self):
if self._subscription.subscription is not None:
@@ -514,11 +525,7 @@ async def stop(self):
return
log.debug("Closing the KafkaConsumer.")
self._closed = True
if self._coordinator:
await self._coordinator.close()
if self._fetcher:
await self._fetcher.close()
await self._client.close()
await self._exit_stack.aclose()
log.debug("The KafkaConsumer has closed.")

async def commit(self, offsets=None):
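With `start()` now registering cleanup on the exit stack, a consumer whose `start()` or `__aenter__` fails part-way no longer needs an explicit `stop()`. A hedged usage sketch of the resulting behavior; the topic name and bootstrap servers are placeholders:

```python
from aiokafka import AIOKafkaConsumer


async def consume(bootstrap_servers):
    consumer = AIOKafkaConsumer("my-topic", bootstrap_servers=bootstrap_servers)
    try:
        async with consumer:
            async for msg in consumer:
                print(msg.offset, msg.value)
    except Exception:
        # Even if start()/__aenter__ raised after bootstrapping the client,
        # the exit stack has already closed whatever was started; no extra
        # consumer.stop() is required and __del__ emits no warning.
        raise
```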
52 changes: 28 additions & 24 deletions aiokafka/producer/producer.py
@@ -3,6 +3,7 @@
import sys
import traceback
import warnings
from contextlib import AsyncExitStack

from aiokafka.client import AIOKafkaClient
from aiokafka.codec import has_gzip, has_lz4, has_snappy, has_zstd
@@ -324,8 +325,6 @@ def __init__(
request_timeout_ms=request_timeout_ms,
)

self._closed = False

# Warn if producer was not closed properly
# We don't attempt to close the Consumer, as __del__ is synchronous
def __del__(self, _warnings=warnings):
@@ -349,26 +348,32 @@ async def start(self):
self._loop is get_running_loop()
), "Please create objects with the same loop as running with"
log.debug("Starting the Kafka producer") # trace
await self.client.bootstrap()

if self._compression_type == "lz4":
assert self.client.api_version >= (0, 8, 2), (
"LZ4 Requires >= Kafka 0.8.2 Brokers"
) # fmt: skip
elif self._compression_type == "zstd":
assert self.client.api_version >= (2, 1, 0), (
"Zstd Requires >= Kafka 2.1.0 Brokers"
) # fmt: skip

if self._txn_manager is not None and self.client.api_version < (0, 11):
raise UnsupportedVersionError(
"Idempotent producer available only for Broker version 0.11"
" and above"
)
async with AsyncExitStack() as stack:
await self.client.bootstrap()
stack.push_async_callback(self.client.close)

if self._compression_type == "lz4":
assert self.client.api_version >= (0, 8, 2), (
"LZ4 Requires >= Kafka 0.8.2 Brokers"
) # fmt: skip
elif self._compression_type == "zstd":
assert self.client.api_version >= (2, 1, 0), (
"Zstd Requires >= Kafka 2.1.0 Brokers"
) # fmt: skip

if self._txn_manager is not None and self.client.api_version < (0, 11):
raise UnsupportedVersionError(
"Idempotent producer available only for Broker version 0.11"
" and above"
)

await self._sender.start()
self._message_accumulator.set_api_version(self.client.api_version)
self._producer_magic = 0 if self.client.api_version < (0, 10) else 1
await self._sender.start()
stack.push_async_callback(self._sender.close)

self._message_accumulator.set_api_version(self.client.api_version)
self._producer_magic = 0 if self.client.api_version < (0, 10) else 1
self._exit_stack = stack.pop_all()
self._closed = False
log.debug("Kafka producer started")

async def flush(self):
@@ -379,6 +384,7 @@ async def stop(self):
"""Flush all pending data and close all connections to kafka cluster"""
if self._closed:
return
log.debug("Closing the KafkaProducer.")
self._closed = True

# If the sender task is down there is no way for accumulator to flush
@@ -391,9 +397,7 @@ async def stop(self):
return_when=asyncio.FIRST_COMPLETED,
)

await self._sender.close()

await self.client.close()
await self._exit_stack.aclose()
log.debug("The Kafka producer has closed.")

async def partitions_for(self, topic):
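The producer gets the same treatment: `stop()` now simply closes the exit stack, which tears down the sender and client in reverse start order. A hedged usage sketch; topic and servers are placeholders:

```python
from aiokafka import AIOKafkaProducer


async def produce(bootstrap_servers):
    async with AIOKafkaProducer(bootstrap_servers=bootstrap_servers) as producer:
        await producer.send_and_wait("my-topic", b"payload")
    # If __aenter__ failed (for example the Sender did not start), the exit
    # stack has already closed whatever was started; no manual producer.stop()
    # is needed and __del__ emits no warning.
```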
31 changes: 31 additions & 0 deletions tests/test_consumer.py
@@ -2,6 +2,7 @@
import gc
import json
import time
import warnings
from contextlib import contextmanager
from unittest import mock

@@ -188,6 +189,36 @@ async def test_consumer_context_manager(self):
raise ValueError
assert consumer._closed

@run_until_complete
async def test_consumer_context_manager_start_error(self):
for target, group_id in [
("aiokafka.consumer.consumer.AIOKafkaClient.bootstrap", None),
("aiokafka.consumer.consumer.Fetcher.__init__", None),
("aiokafka.consumer.consumer.NoGroupCoordinator.__init__", None),
(
"aiokafka.consumer.consumer.GroupCoordinator.__init__",
f"group-{self.id()}",
),
]:
consumer = AIOKafkaConsumer(
self.topic,
group_id=group_id,
bootstrap_servers=self.hosts,
enable_auto_commit=False,
auto_offset_reset="earliest",
)

# make consumer.start() raise
with mock.patch(target, side_effect=RuntimeError("error")):
with self.assertRaises(RuntimeError):
async with consumer:
self.fail(f"{target} did not raise")

# should not require calling consumer.close()
with warnings.catch_warnings(record=True) as w:
del consumer
self.assertEqual(len(w), 0, f"{target} got unexpected warnings: {w}")

@run_until_complete
async def test_consumer_api_version(self):
await self.send_messages(0, list(range(10)))
20 changes: 20 additions & 0 deletions tests/test_producer.py
@@ -2,6 +2,7 @@
import gc
import json
import time
import warnings
import weakref
from unittest import mock

@@ -179,6 +180,25 @@ async def test_producer_context_manager(self):
raise ValueError()
assert producer._closed

@run_until_complete
async def test_producer_context_manager_start_error(self):
for target in [
"aiokafka.producer.producer.AIOKafkaClient.bootstrap",
"aiokafka.producer.producer.Sender.start",
]:
producer = AIOKafkaProducer(bootstrap_servers=self.hosts)

# make producer.start() raise
with mock.patch(target, side_effect=RuntimeError("error")):
with self.assertRaises(RuntimeError):
async with producer:
self.fail(f"{target} did not raise")

# should not require calling producer.close()
with warnings.catch_warnings(record=True) as w:
del producer
self.assertEqual(len(w), 0, f"Unexpected warnings: {w}")

@run_until_complete
async def test_producer_send_noack(self):
producer = AIOKafkaProducer(bootstrap_servers=self.hosts, acks=0)