From 8dfdc96a92b4c9f9dbaf827efd57801f92ee0fc4 Mon Sep 17 00:00:00 2001
From: Dana Powers
Date: Fri, 2 May 2025 12:48:05 -0700
Subject: [PATCH 1/3] Fix timeout handling so consumer.poll(timeout_ms=0) makes progress

---
 kafka/consumer/group.py       | 36 +++++++++++------------
 kafka/coordinator/base.py     | 55 ++++++++++++++++++++-------------
 kafka/coordinator/consumer.py | 33 ++++++++++++++++-----
 kafka/util.py                 | 44 ++++++++++++++++++++++++++--
 4 files changed, 119 insertions(+), 49 deletions(-)

diff --git a/kafka/consumer/group.py b/kafka/consumer/group.py
index 471ae5cda..ce3cf9203 100644
--- a/kafka/consumer/group.py
+++ b/kafka/consumer/group.py
@@ -18,7 +18,7 @@
 from kafka.metrics import MetricConfig, Metrics
 from kafka.protocol.list_offsets import OffsetResetStrategy
 from kafka.structs import OffsetAndMetadata, TopicPartition
-from kafka.util import timeout_ms_fn
+from kafka.util import Timer
 from kafka.version import __version__
 
 log = logging.getLogger(__name__)
@@ -679,41 +679,40 @@ def poll(self, timeout_ms=0, max_records=None, update_offsets=True):
         assert not self._closed, 'KafkaConsumer is closed'
 
         # Poll for new data until the timeout expires
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
+        timer = Timer(timeout_ms)
         while not self._closed:
-            records = self._poll_once(inner_timeout_ms(), max_records, update_offsets=update_offsets)
+            records = self._poll_once(timer, max_records, update_offsets=update_offsets)
             if records:
                 return records
-
-            if inner_timeout_ms() <= 0:
+            elif timer.expired:
                 break
-
         return {}
 
-    def _poll_once(self, timeout_ms, max_records, update_offsets=True):
+    def _poll_once(self, timer, max_records, update_offsets=True):
         """Do one round of polling. In addition to checking for new data, this does
         any needed heart-beating, auto-commits, and offset updates.
 
         Arguments:
-            timeout_ms (int): The maximum time in milliseconds to block.
+            timer (Timer): Timer bounding how long this call may block.
 
         Returns:
             dict: Map of topic to list of records (may be empty).
         """
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
-        if not self._coordinator.poll(timeout_ms=inner_timeout_ms()):
+        if not self._coordinator.poll(timeout_ms=timer.timeout_ms):
             return {}
 
-        has_all_fetch_positions = self._update_fetch_positions(timeout_ms=inner_timeout_ms())
+        has_all_fetch_positions = self._update_fetch_positions(timeout_ms=timer.timeout_ms)
 
         # If data is available already, e.g. from a previous network client
         # poll() call to commit, then just return it immediately
         records, partial = self._fetcher.fetched_records(max_records, update_offsets=update_offsets)
+        log.debug('Fetched records: %s, %s', records, partial)
         # Before returning the fetched records, we can send off the
         # next round of fetches and avoid block waiting for their
         # responses to enable pipelining while the user is handling the
         # fetched records.
        if not partial:
+            log.debug("Sending fetches")
             futures = self._fetcher.send_fetches()
             if len(futures):
                 self._client.poll(timeout_ms=0)
@@ -723,7 +722,7 @@ def _poll_once(self, timeout_ms, max_records, update_offsets=True):
 
         # We do not want to be stuck blocking in poll if we are missing some positions
         # since the offset lookup may be backing off after a failure
-        poll_timeout_ms = inner_timeout_ms(self._coordinator.time_to_next_poll() * 1000)
+        poll_timeout_ms = min(timer.timeout_ms, self._coordinator.time_to_next_poll() * 1000)
         if not has_all_fetch_positions:
             poll_timeout_ms = min(poll_timeout_ms, self.config['retry_backoff_ms'])
 
@@ -749,15 +748,14 @@ def position(self, partition, timeout_ms=None):
             raise TypeError('partition must be a TopicPartition namedtuple')
         assert self._subscription.is_assigned(partition), 'Partition is not assigned'
 
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout retrieving partition position')
+        timer = Timer(timeout_ms)
         position = self._subscription.assignment[partition].position
-        try:
-            while position is None:
-                # batch update fetch positions for any partitions without a valid position
-                self._update_fetch_positions(timeout_ms=inner_timeout_ms())
+        while position is None:
+            # batch update fetch positions for any partitions without a valid position
+            if self._update_fetch_positions(timeout_ms=timer.timeout_ms):
                 position = self._subscription.assignment[partition].position
-        except KafkaTimeoutError:
-            return None
+            elif timer.expired:
+                return None
         else:
             return position.offset

diff --git a/kafka/coordinator/base.py b/kafka/coordinator/base.py
index 4aa5c89bc..e2e4fba95 100644
--- a/kafka/coordinator/base.py
+++ b/kafka/coordinator/base.py
@@ -16,7 +16,7 @@
 from kafka.metrics.stats import Avg, Count, Max, Rate
 from kafka.protocol.find_coordinator import FindCoordinatorRequest
 from kafka.protocol.group import HeartbeatRequest, JoinGroupRequest, LeaveGroupRequest, SyncGroupRequest, DEFAULT_GENERATION_ID, UNKNOWN_MEMBER_ID
-from kafka.util import timeout_ms_fn
+from kafka.util import timeout_ms_fn, Timer
 
 log = logging.getLogger('kafka.coordinator')
 
@@ -256,9 +256,9 @@ def ensure_coordinator_ready(self, timeout_ms=None):
             timeout_ms (numeric, optional): Maximum number of milliseconds to
                 block waiting to find coordinator. Default: None.
-        Raises: KafkaTimeoutError if timeout_ms is not None
+        Returns: True if coordinator found before timeout_ms, else False
         """
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to find group coordinator')
+        timer = Timer(timeout_ms)
         with self._client._lock, self._lock:
             while self.coordinator_unknown():
@@ -272,27 +272,34 @@
                     else:
                         self.coordinator_id = maybe_coordinator_id
                         self._client.maybe_connect(self.coordinator_id)
-                        continue
+                        if timer.expired:
+                            return False
+                        else:
+                            continue
                 else:
                     future = self.lookup_coordinator()
-                    self._client.poll(future=future, timeout_ms=inner_timeout_ms())
+                    self._client.poll(future=future, timeout_ms=timer.timeout_ms)
 
                     if not future.is_done:
-                        raise Errors.KafkaTimeoutError()
+                        return False
 
                     if future.failed():
                         if future.retriable():
                             if getattr(future.exception, 'invalid_metadata', False):
                                 log.debug('Requesting metadata for group coordinator request: %s', future.exception)
                                 metadata_update = self._client.cluster.request_update()
-                                self._client.poll(future=metadata_update, timeout_ms=inner_timeout_ms())
+                                self._client.poll(future=metadata_update, timeout_ms=timer.timeout_ms)
                                 if not metadata_update.is_done:
-                                    raise Errors.KafkaTimeoutError()
+                                    return False
                             else:
-                                time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
+                                time.sleep(min(timer.timeout_ms, self.config['retry_backoff_ms']) / 1000)
                         else:
                             raise future.exception  # pylint: disable-msg=raising-bad-type
+            if timer.expired:
+                return False
+            else:
+                return True
 
     def _reset_find_coordinator_future(self, result):
         self._find_coordinator_future = None
@@ -407,21 +414,23 @@ def ensure_active_group(self, timeout_ms=None):
             timeout_ms (numeric, optional): Maximum number of milliseconds to
                 block waiting to join group. Default: None.
 
-        Raises: KafkaTimeoutError if timeout_ms is not None
+        Returns: True if group initialized before timeout_ms, else False
         """
         if self.config['api_version'] < (0, 9):
             raise Errors.UnsupportedVersionError('Group Coordinator APIs require 0.9+ broker')
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to join consumer group')
-        self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
+        timer = Timer(timeout_ms)
+        if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms):
+            return False
         self._start_heartbeat_thread()
-        self.join_group(timeout_ms=inner_timeout_ms())
+        return self.join_group(timeout_ms=timer.timeout_ms)
 
     def join_group(self, timeout_ms=None):
         if self.config['api_version'] < (0, 9):
             raise Errors.UnsupportedVersionError('Group Coordinator APIs require 0.9+ broker')
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to join consumer group')
+        timer = Timer(timeout_ms)
         while self.need_rejoin():
-            self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
+            if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms):
+                return False
 
             # call on_join_prepare if needed. We set a flag
             # to make sure that we do not call it a second
@@ -434,7 +443,7 @@ def join_group(self, timeout_ms=None):
             if not self.rejoining:
                 self._on_join_prepare(self._generation.generation_id,
                                       self._generation.member_id,
-                                      timeout_ms=inner_timeout_ms())
+                                      timeout_ms=timer.timeout_ms)
                 self.rejoining = True
 
             # fence off the heartbeat thread explicitly so that it cannot
@@ -449,16 +458,18 @@ def join_group(self, timeout_ms=None):
             while not self.coordinator_unknown():
                 if not self._client.in_flight_request_count(self.coordinator_id):
                     break
-                self._client.poll(timeout_ms=inner_timeout_ms(200))
+                self._client.poll(timeout_ms=min(timer.timeout_ms, 200))
+                if timer.expired:
+                    return False
             else:
                 continue
 
             future = self._initiate_join_group()
-            self._client.poll(future=future, timeout_ms=inner_timeout_ms())
+            self._client.poll(future=future, timeout_ms=timer.timeout_ms)
             if future.is_done:
                 self._reset_join_group_future()
             else:
-                raise Errors.KafkaTimeoutError()
+                return False
 
             if future.succeeded():
                 self.rejoining = False
@@ -467,6 +478,7 @@ def join_group(self, timeout_ms=None):
                                           self._generation.member_id,
                                           self._generation.protocol,
                                           future.value)
+                return True
             else:
                 exception = future.exception
                 if isinstance(exception, (Errors.UnknownMemberIdError,
@@ -476,7 +488,10 @@
                     continue
                 elif not future.retriable():
                     raise exception  # pylint: disable-msg=raising-bad-type
-                time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
+                elif timer.expired:
+                    return False
+                else:
+                    time.sleep(min(timer.timeout_ms, self.config['retry_backoff_ms']) / 1000)
 
     def _send_join_group_request(self):
         """Join the group and return the assignment for the next generation.

diff --git a/kafka/coordinator/consumer.py b/kafka/coordinator/consumer.py
index d4943da31..854d5e14b 100644
--- a/kafka/coordinator/consumer.py
+++ b/kafka/coordinator/consumer.py
@@ -19,7 +19,7 @@
 from kafka.metrics.stats import Avg, Count, Max, Rate
 from kafka.protocol.commit import OffsetCommitRequest, OffsetFetchRequest
 from kafka.structs import OffsetAndMetadata, TopicPartition
-from kafka.util import timeout_ms_fn, WeakMethod
+from kafka.util import timeout_ms_fn, Timer, WeakMethod
 
 log = logging.getLogger(__name__)
 
@@ -95,6 +95,7 @@ def __init__(self, client, subscription, **configs):
         self.auto_commit_interval = self.config['auto_commit_interval_ms'] / 1000
         self.next_auto_commit_deadline = None
         self.completed_offset_commits = collections.deque()
+        self._offset_fetch_futures = dict()
 
         if self.config['default_offset_commit_callback'] is None:
             self.config['default_offset_commit_callback'] = self._default_offset_commit_callback
@@ -269,10 +270,11 @@ def poll(self, timeout_ms=None):
         if self.group_id is None:
             return True
 
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout in coordinator.poll')
+        timer = Timer(timeout_ms)
         try:
             self._invoke_completed_offset_commit_callbacks()
-            self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
+            if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms):
+                return False
 
             if self.config['api_version'] >= (0, 9) and self._subscription.partitions_auto_assigned():
                 if self.need_rejoin():
@@ -289,9 +291,12 @@
                     # description of the problem.
                    if self._subscription.subscribed_pattern:
                        metadata_update = self._client.cluster.request_update()
-                        self._client.poll(future=metadata_update, timeout_ms=inner_timeout_ms())
+                        self._client.poll(future=metadata_update, timeout_ms=timer.timeout_ms)
+                        if not metadata_update.is_done:
+                            return False
 
-                self.ensure_active_group(timeout_ms=inner_timeout_ms())
+                if not self.ensure_active_group(timeout_ms=timer.timeout_ms):
+                    return False
 
                 self.poll_heartbeat()
 
@@ -395,10 +400,14 @@ def need_rejoin(self):
 
     def refresh_committed_offsets_if_needed(self, timeout_ms=None):
         """Fetch committed offsets for assigned partitions."""
         missing_fetch_positions = set(self._subscription.missing_fetch_positions())
-        offsets = self.fetch_committed_offsets(missing_fetch_positions, timeout_ms=timeout_ms)
+        try:
+            offsets = self.fetch_committed_offsets(missing_fetch_positions, timeout_ms=timeout_ms)
+        except Errors.KafkaTimeoutError:
+            return False
         for partition, offset in six.iteritems(offsets):
             log.debug("Setting offset for partition %s to the committed offset %s", partition, offset.offset);
             self._subscription.seek(partition, offset.offset)
+        return True
 
     def fetch_committed_offsets(self, partitions, timeout_ms=None):
         """Fetch the current committed offsets for specified partitions
@@ -415,16 +424,24 @@
         if not partitions:
             return {}
 
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout in coordinator.fetch_committed_offsets')
+        inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
         while True:
             self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
 
             # contact coordinator to fetch committed offsets
-            future = self._send_offset_fetch_request(partitions)
+            future_key = frozenset(partitions)
+            if future_key in self._offset_fetch_futures:
+                future = self._offset_fetch_futures[future_key]
+            else:
+                future = self._send_offset_fetch_request(partitions)
+                self._offset_fetch_futures[future_key] = future
+
             self._client.poll(future=future, timeout_ms=inner_timeout_ms())
 
             if not future.is_done:
                 raise Errors.KafkaTimeoutError()
+            else:
+                del self._offset_fetch_futures[future_key]
 
             if future.succeeded():
                 return future.value

diff --git a/kafka/util.py b/kafka/util.py
index 470200b1b..bd44d0a47 100644
--- a/kafka/util.py
+++ b/kafka/util.py
@@ -1,4 +1,4 @@
-from __future__ import absolute_import
+from __future__ import absolute_import, division
 
 import binascii
 import re
@@ -28,14 +28,19 @@ def crc32(data):
 def timeout_ms_fn(timeout_ms, error_message):
     elapsed = 0.0  # noqa: F841
     begin = time.time()
+    raise_next = False
     def inner_timeout_ms(fallback=None):
+        nonlocal elapsed, begin, raise_next
         if timeout_ms is None:
             return fallback
         elapsed = (time.time() - begin) * 1000
         if elapsed >= timeout_ms:
-            if error_message is not None:
+            if error_message is None:
+                return 0
+            elif raise_next:
                 raise KafkaTimeoutError(error_message)
             else:
+                raise_next = True
                 return 0
         ret = max(0, timeout_ms - elapsed)
         if fallback is not None:
@@ -44,6 +49,41 @@ def inner_timeout_ms(fallback=None):
     return inner_timeout_ms
 
 
+class Timer:
+    __slots__ = ('_start_at', '_expire_at', '_timeout_ms', '_error_message')
+
+    def __init__(self, timeout_ms, error_message=None, start_at=None):
+        self._timeout_ms = timeout_ms
+        self._start_at = start_at or time.time()
+        if timeout_ms is not None:
+            self._expire_at = self._start_at + timeout_ms / 1000
+        else:
+            self._expire_at = float('inf')
+        self._error_message = error_message
+
+    @property
+    def expired(self):
+        return time.time() >= self._expire_at
+
+    @property
+    def timeout_ms(self):
+        if self._timeout_ms is None:
+            return None
+        elif self._expire_at == float('inf'):
+            return float('inf')
+        remaining = self._expire_at - time.time()
+        if remaining < 0:
+            return 0
+        else:
+            return int(remaining * 1000)
+
+    def maybe_raise(self):
+        if self.expired:
+            raise KafkaTimeoutError(self._error_message)
+
+    def __str__(self):
+        return "Timer(%s ms remaining)" % (self.timeout_ms)
+
+
 # Taken from: https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java#L29
 TOPIC_MAX_LENGTH = 249
 TOPIC_LEGAL_CHARS = re.compile('^[a-zA-Z0-9._-]+$')

From 5f3b96670776a293d36ca42bf045650bb438b980 Mon Sep 17 00:00:00 2001
From: Dana Powers
Date: Sat, 3 May 2025 16:13:35 -0700
Subject: [PATCH 2/3] Use kafka.util.Timer

---
 kafka/client_async.py         | 21 +++++++------
 kafka/consumer/fetcher.py     | 15 ++++++----
 kafka/coordinator/base.py     | 15 +++++++---
 kafka/coordinator/consumer.py | 56 +++++++++++++++++++----------------
 kafka/producer/kafka.py       | 36 ++++++++--------------
 kafka/util.py                 | 28 +++---------------
 6 files changed, 78 insertions(+), 93 deletions(-)

diff --git a/kafka/client_async.py b/kafka/client_async.py
index 448a995ba..79635267b 100644
--- a/kafka/client_async.py
+++ b/kafka/client_async.py
@@ -27,7 +27,7 @@
 from kafka.metrics.stats.rate import TimeUnit
 from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS
 from kafka.protocol.metadata import MetadataRequest
-from kafka.util import Dict, WeakMethod, ensure_valid_topic_name, timeout_ms_fn
+from kafka.util import Dict, Timer, WeakMethod, ensure_valid_topic_name
 # Although this looks unused, it actually monkey-patches socket.socketpair()
 # and should be left in as long as we're using socket.socketpair() in this file
 from kafka.vendor import socketpair  # noqa: F401
@@ -646,11 +646,11 @@ def poll(self, timeout_ms=None, future=None):
         if not isinstance(timeout_ms, (int, float, type(None))):
             raise TypeError('Invalid type for timeout: %s' % type(timeout_ms))
 
-        begin = time.time()
         if timeout_ms is not None:
-            timeout_at = begin + (timeout_ms / 1000)
+            timer = Timer(timeout_ms)
         else:
-            timeout_at = begin + (self.config['request_timeout_ms'] / 1000)
+            timer = Timer(self.config['request_timeout_ms'])
+
         # Loop for futures, break after first loop if None
         responses = []
         while True:
@@ -675,12 +675,11 @@ def poll(self, timeout_ms=None, future=None):
             if future is not None and future.is_done:
                 timeout = 0
             else:
-                user_timeout_ms = 1000 * max(0, timeout_at - time.time())
                 idle_connection_timeout_ms = self._idle_expiry_manager.next_check_ms()
                 request_timeout_ms = self._next_ifr_request_timeout_ms()
-                log.debug("Timeouts: user %f, metadata %f, idle connection %f, request %f", user_timeout_ms, metadata_timeout_ms, idle_connection_timeout_ms, request_timeout_ms)
+                log.debug("Timeouts: user %f, metadata %f, idle connection %f, request %f", timer.timeout_ms, metadata_timeout_ms, idle_connection_timeout_ms, request_timeout_ms)
                 timeout = min(
-                    user_timeout_ms,
+                    timer.timeout_ms,
                     metadata_timeout_ms,
                     idle_connection_timeout_ms,
                     request_timeout_ms)
@@ -698,7 +697,7 @@ def poll(self, timeout_ms=None, future=None):
                 break
             elif future.is_done:
                 break
-            elif timeout_ms is not None and time.time() >= timeout_at:
+            elif timeout_ms is not None and timer.expired:
                 break
 
         return responses
@@ -1175,16 +1174,16 @@ def await_ready(self, node_id, timeout_ms=30000):
         This method is useful for implementing blocking behaviour on top of the
         non-blocking `NetworkClient`, use it with care.
         """
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
+        timer = Timer(timeout_ms)
         self.poll(timeout_ms=0)
         if self.is_ready(node_id):
             return True
 
-        while not self.is_ready(node_id) and inner_timeout_ms() > 0:
+        while not self.is_ready(node_id) and not timer.expired:
             if self.connection_failed(node_id):
                 raise Errors.KafkaConnectionError("Connection to %s failed." % (node_id,))
             self.maybe_connect(node_id)
-            self.poll(timeout_ms=inner_timeout_ms())
+            self.poll(timeout_ms=timer.timeout_ms)
 
         return self.is_ready(node_id)
 
     def send_and_receive(self, node_id, request):

diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py
index e7757e7b3..42e2d660c 100644
--- a/kafka/consumer/fetcher.py
+++ b/kafka/consumer/fetcher.py
@@ -19,7 +19,7 @@
 from kafka.record import MemoryRecords
 from kafka.serializer import Deserializer
 from kafka.structs import TopicPartition, OffsetAndMetadata, OffsetAndTimestamp
-from kafka.util import timeout_ms_fn
+from kafka.util import Timer
 
 log = logging.getLogger(__name__)
 
@@ -230,7 +230,7 @@ def _fetch_offsets_by_times(self, timestamps, timeout_ms=None):
         if not timestamps:
             return {}
 
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout fetching offsets')
+        timer = Timer(timeout_ms, "Failed to get offsets by timestamps in %s ms" % (timeout_ms,))
         timestamps = copy.copy(timestamps)
         fetched_offsets = dict()
         while True:
@@ -238,7 +238,7 @@ def _fetch_offsets_by_times(self, timestamps, timeout_ms=None):
                 return {}
 
             future = self._send_list_offsets_requests(timestamps)
-            self._client.poll(future=future, timeout_ms=inner_timeout_ms())
+            self._client.poll(future=future, timeout_ms=timer.timeout_ms)
 
             # Timeout w/o future completion
             if not future.is_done:
@@ -256,12 +256,17 @@ def _fetch_offsets_by_times(self, timestamps, timeout_ms=None):
 
             if future.exception.invalid_metadata or self._client.cluster.need_update:
                 refresh_future = self._client.cluster.request_update()
-                self._client.poll(future=refresh_future, timeout_ms=inner_timeout_ms())
+                self._client.poll(future=refresh_future, timeout_ms=timer.timeout_ms)
 
                 if not future.is_done:
                     break
             else:
-                time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
+                if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']:
+                    time.sleep(self.config['retry_backoff_ms'] / 1000)
+                else:
+                    time.sleep(timer.timeout_ms / 1000)
+
+            timer.maybe_raise()
 
         raise Errors.KafkaTimeoutError(
             "Failed to get offsets by timestamps in %s ms" % (timeout_ms,))

diff --git a/kafka/coordinator/base.py b/kafka/coordinator/base.py
index e2e4fba95..1592f9154 100644
--- a/kafka/coordinator/base.py
+++ b/kafka/coordinator/base.py
@@ -16,7 +16,7 @@
 from kafka.metrics.stats import Avg, Count, Max, Rate
 from kafka.protocol.find_coordinator import FindCoordinatorRequest
 from kafka.protocol.group import HeartbeatRequest, JoinGroupRequest, LeaveGroupRequest, SyncGroupRequest, DEFAULT_GENERATION_ID, UNKNOWN_MEMBER_ID
-from kafka.util import timeout_ms_fn, Timer
+from kafka.util import Timer
 
 log = logging.getLogger('kafka.coordinator')
 
@@ -293,7 +293,10 @@ def ensure_coordinator_ready(self, timeout_ms=None):
                                 if not metadata_update.is_done:
                                     return False
                             else:
-                                time.sleep(min(timer.timeout_ms, self.config['retry_backoff_ms']) / 1000)
+                                if timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']:
+                                    time.sleep(self.config['retry_backoff_ms'] / 1000)
+                                else:
+                                    time.sleep(timer.timeout_ms / 1000)
                         else:
                             raise future.exception  # pylint: disable-msg=raising-bad-type
             if timer.expired:
@@ -458,7 +461,8 @@ def join_group(self, timeout_ms=None):
             while not self.coordinator_unknown():
                 if not self._client.in_flight_request_count(self.coordinator_id):
                     break
-                self._client.poll(timeout_ms=min(timer.timeout_ms, 200))
+                poll_timeout_ms = 200 if timer.timeout_ms is None or timer.timeout_ms > 200 else timer.timeout_ms
+                self._client.poll(timeout_ms=poll_timeout_ms)
                 if timer.expired:
                     return False
                 else:
@@ -491,7 +495,10 @@ def join_group(self, timeout_ms=None):
                 elif timer.expired:
                     return False
                 else:
-                    time.sleep(min(timer.timeout_ms, self.config['retry_backoff_ms']) / 1000)
+                    if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']:
+                        time.sleep(self.config['retry_backoff_ms'] / 1000)
+                    else:
+                        time.sleep(timer.timeout_ms / 1000)
 
     def _send_join_group_request(self):
         """Join the group and return the assignment for the next generation.

diff --git a/kafka/coordinator/consumer.py b/kafka/coordinator/consumer.py
index 854d5e14b..4361b3dc3 100644
--- a/kafka/coordinator/consumer.py
+++ b/kafka/coordinator/consumer.py
@@ -19,7 +19,7 @@
 from kafka.metrics.stats import Avg, Count, Max, Rate
 from kafka.protocol.commit import OffsetCommitRequest, OffsetFetchRequest
 from kafka.structs import OffsetAndMetadata, TopicPartition
-from kafka.util import timeout_ms_fn, Timer, WeakMethod
+from kafka.util import Timer, WeakMethod
 
 log = logging.getLogger(__name__)
 
@@ -405,7 +405,7 @@ def refresh_committed_offsets_if_needed(self, timeout_ms=None):
         except Errors.KafkaTimeoutError:
             return False
         for partition, offset in six.iteritems(offsets):
-            log.debug("Setting offset for partition %s to the committed offset %s", partition, offset.offset);
+            log.debug("Setting offset for partition %s to the committed offset %s", partition, offset.offset)
             self._subscription.seek(partition, offset.offset)
         return True
 
@@ -424,32 +424,35 @@ def fetch_committed_offsets(self, partitions, timeout_ms=None):
         if not partitions:
             return {}
 
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
+        future_key = frozenset(partitions)
+        timer = Timer(timeout_ms)
         while True:
-            self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
+            self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms)
 
             # contact coordinator to fetch committed offsets
-            future_key = frozenset(partitions)
             if future_key in self._offset_fetch_futures:
                 future = self._offset_fetch_futures[future_key]
             else:
                 future = self._send_offset_fetch_request(partitions)
                 self._offset_fetch_futures[future_key] = future
 
-            self._client.poll(future=future, timeout_ms=inner_timeout_ms())
+            self._client.poll(future=future, timeout_ms=timer.timeout_ms)
 
-            if not future.is_done:
-                raise Errors.KafkaTimeoutError()
-            else:
+            if future.is_done:
                 del self._offset_fetch_futures[future_key]
 
-            if future.succeeded():
-                return future.value
+                if future.succeeded():
+                    return future.value
 
-            if not future.retriable():
-                raise future.exception  # pylint: disable-msg=raising-bad-type
+                elif not future.retriable():
+                    raise future.exception  # pylint: disable-msg=raising-bad-type
 
-            time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
+            # future failed but is retriable, or is not done yet
+            if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']:
+                time.sleep(self.config['retry_backoff_ms'] / 1000)
+            else:
+                time.sleep(timer.timeout_ms / 1000)
+            timer.maybe_raise()
 
     def close(self, autocommit=True, timeout_ms=None):
         """Close the coordinator, leave the current group,
@@ -540,23 +543,26 @@ def commit_offsets_sync(self, offsets, timeout_ms=None):
         if not offsets:
             return
 
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout in coordinator.poll')
+        timer = Timer(timeout_ms)
         while True:
-            self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
+            self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms)
 
             future = self._send_offset_commit_request(offsets)
-            self._client.poll(future=future, timeout_ms=inner_timeout_ms())
-
-            if not future.is_done:
-                raise Errors.KafkaTimeoutError()
+            self._client.poll(future=future, timeout_ms=timer.timeout_ms)
 
-            if future.succeeded():
-                return future.value
+            if future.is_done:
+                if future.succeeded():
+                    return future.value
 
-            if not future.retriable():
-                raise future.exception  # pylint: disable-msg=raising-bad-type
+                elif not future.retriable():
+                    raise future.exception  # pylint: disable-msg=raising-bad-type
 
-            time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
+            # future failed but is retriable, or it is still pending
+            if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']:
+                time.sleep(self.config['retry_backoff_ms'] / 1000)
+            else:
+                time.sleep(timer.timeout_ms / 1000)
+            timer.maybe_raise()
 
     def _maybe_auto_commit_offsets_sync(self, timeout_ms=None):
         if self.config['enable_auto_commit']:

diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py
index 6861ec93a..66208bbe1 100644
--- a/kafka/producer/kafka.py
+++ b/kafka/producer/kafka.py
@@ -5,7 +5,6 @@
 import logging
 import socket
 import threading
-import time
 import warnings
 import weakref
 
@@ -24,7 +23,7 @@
 from kafka.record.legacy_records import LegacyRecordBatchBuilder
 from kafka.serializer import Serializer
 from kafka.structs import TopicPartition
-from kafka.util import ensure_valid_topic_name
+from kafka.util import Timer, ensure_valid_topic_name
 
 log = logging.getLogger(__name__)
 
@@ -664,8 +663,7 @@ def __getattr__(self, name):
 
     def partitions_for(self, topic):
         """Returns set of all known partitions for the topic."""
-        max_wait = self.config['max_block_ms'] / 1000
-        return self._wait_on_metadata(topic, max_wait)
+        return self._wait_on_metadata(topic, self.config['max_block_ms'])
 
     @classmethod
     def max_usable_produce_magic(cls, api_version):
@@ -835,14 +833,11 @@ def send(self, topic, value=None, key=None, headers=None, partition=None, timestamp_ms=None):
         assert not (value is None and key is None), 'Need at least one: key or value'
         ensure_valid_topic_name(topic)
         key_bytes = value_bytes = None
+        timer = Timer(self.config['max_block_ms'], "Failed to assign partition for message in max_block_ms.")
         try:
             assigned_partition = None
-            elapsed = 0.0
-            begin = time.time()
-            timeout = self.config['max_block_ms'] / 1000
-            while assigned_partition is None and elapsed < timeout:
-                elapsed = time.time() - begin
-                self._wait_on_metadata(topic, timeout - elapsed)
+            while assigned_partition is None and not timer.expired:
+                self._wait_on_metadata(topic, timer.timeout_ms)
 
                 key_bytes = self._serialize(
                     self.config['key_serializer'],
@@ -856,7 +851,7 @@ def send(self, topic, value=None, key=None, headers=None, partition=None, timestamp_ms=None):
                 assigned_partition = self._partition(topic, partition, key, value,
                                                      key_bytes, value_bytes)
             if assigned_partition is None:
-                raise Errors.KafkaTimeoutError("Failed to assign partition for message after %s secs." % timeout)
+                raise Errors.KafkaTimeoutError("Failed to assign partition for message after %s secs." % (timer.elapsed_ms / 1000,))
             else:
                 partition = assigned_partition
 
@@ -931,7 +926,7 @@ def _ensure_valid_record_size(self, size):
                 " the maximum request size you have configured with the"
                 " max_request_size configuration" % (size,))
 
-    def _wait_on_metadata(self, topic, max_wait):
+    def _wait_on_metadata(self, topic, max_wait_ms):
         """
         Wait for cluster metadata including partitions for the given topic to
         be available.
@@ -949,36 +944,29 @@ def _wait_on_metadata(self, topic, max_wait):
         """
         # add topic to metadata topic list if it is not there already.
         self._sender.add_topic(topic)
-        begin = time.time()
-        elapsed = 0.0
+        timer = Timer(max_wait_ms, "Failed to update metadata after %.1f secs." % (max_wait_ms / 1000,))
         metadata_event = None
         while True:
             partitions = self._metadata.partitions_for_topic(topic)
             if partitions is not None:
                 return partitions
-
-            if elapsed >= max_wait:
-                raise Errors.KafkaTimeoutError(
-                    "Failed to update metadata after %.1f secs." % (max_wait,))
-
+            timer.maybe_raise()
             if not metadata_event:
                 metadata_event = threading.Event()
 
             log.debug("%s: Requesting metadata update for topic %s", str(self), topic)
 
-            metadata_event.clear()
             future = self._metadata.request_update()
             future.add_both(lambda e, *args: e.set(), metadata_event)
             self._sender.wakeup()
-            metadata_event.wait(max_wait - elapsed)
+            metadata_event.wait(timer.timeout_ms / 1000)
 
             if not metadata_event.is_set():
                 raise Errors.KafkaTimeoutError(
-                    "Failed to update metadata after %.1f secs." % (max_wait,))
+                    "Failed to update metadata after %.1f secs." % (max_wait_ms / 1000,))
             elif topic in self._metadata.unauthorized_topics:
                 raise Errors.TopicAuthorizationFailedError(set([topic]))
             else:
-                elapsed = time.time() - begin
-                log.debug("%s: _wait_on_metadata woke after %s secs.", str(self), elapsed)
+                log.debug("%s: _wait_on_metadata woke after %s secs.", str(self), timer.elapsed_ms / 1000)
 
     def _serialize(self, f, topic, data):
         if not f:

diff --git a/kafka/util.py b/kafka/util.py
index bd44d0a47..bfb9365ad 100644
--- a/kafka/util.py
+++ b/kafka/util.py
@@ -25,30 +25,6 @@ def crc32(data):
     from binascii import crc32  # noqa: F401
 
 
-def timeout_ms_fn(timeout_ms, error_message):
-    elapsed = 0.0  # noqa: F841
-    begin = time.time()
-    raise_next = False
-    def inner_timeout_ms(fallback=None):
-        nonlocal elapsed, begin, raise_next
-        if timeout_ms is None:
-            return fallback
-        elapsed = (time.time() - begin) * 1000
-        if elapsed >= timeout_ms:
-            if error_message is None:
-                return 0
-            elif raise_next:
-                raise KafkaTimeoutError(error_message)
-            else:
-                raise_next = True
-                return 0
-        ret = max(0, timeout_ms - elapsed)
-        if fallback is not None:
-            return min(ret, fallback)
-        return ret
-    return inner_timeout_ms
-
-
 class Timer:
     __slots__ = ('_start_at', '_expire_at', '_timeout_ms', '_error_message')
 
@@ -77,6 +53,10 @@ def timeout_ms(self):
         else:
             return int(remaining * 1000)
 
+    @property
+    def elapsed_ms(self):
+        return int(1000 * (time.time() - self._start_at))
+
     def maybe_raise(self):
         if self.expired:
             raise KafkaTimeoutError(self._error_message)

From 2d98d0d4bc6c8a6b00164f10f1f80de630dbd73d Mon Sep 17 00:00:00 2001
From: Dana Powers
Date: Sat, 3 May 2025 19:02:57 -0700
Subject: [PATCH 3/3] Don't use timer in client.poll() unless user provides timeout_ms

---
 kafka/client_async.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/kafka/client_async.py b/kafka/client_async.py
index 79635267b..7d466574f 100644
--- a/kafka/client_async.py
+++ b/kafka/client_async.py
@@ -645,11 +645,7 @@ def poll(self, timeout_ms=None, future=None):
         """
         if not isinstance(timeout_ms, (int, float, type(None))):
             raise TypeError('Invalid type for timeout: %s' % type(timeout_ms))
-
-        if timeout_ms is not None:
-            timer = Timer(timeout_ms)
-        else:
-            timer = Timer(self.config['request_timeout_ms'])
+        timer = Timer(timeout_ms)
 
         # Loop for futures, break after first loop if None
         responses = []
         while True:
@@ -675,11 +671,12 @@ def poll(self, timeout_ms=None, future=None):
             if future is not None and future.is_done:
                 timeout = 0
             else:
+                user_timeout_ms = timer.timeout_ms if timeout_ms is not None else self.config['request_timeout_ms']
                 idle_connection_timeout_ms = self._idle_expiry_manager.next_check_ms()
                 request_timeout_ms = self._next_ifr_request_timeout_ms()
-                log.debug("Timeouts: user %f, metadata %f, idle connection %f, request %f", timer.timeout_ms, metadata_timeout_ms, idle_connection_timeout_ms, request_timeout_ms)
+                log.debug("Timeouts: user %f, metadata %f, idle connection %f, request %f", user_timeout_ms, metadata_timeout_ms, idle_connection_timeout_ms, request_timeout_ms)
                 timeout = min(
-                    timer.timeout_ms,
+                    user_timeout_ms,
                     metadata_timeout_ms,
                     idle_connection_timeout_ms,
                     request_timeout_ms)
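
Reviewer note (not part of the patches): a minimal sketch of the kafka.util.Timer semantics this series relies on, using only the API added in kafka/util.py above. The sleep and the empty loop body are stand-ins for real network polling:

    import time
    from kafka.errors import KafkaTimeoutError
    from kafka.util import Timer

    # A finite timer counts down: timeout_ms shrinks toward 0 and
    # expired flips once the deadline passes.
    t = Timer(100)
    assert not t.expired and t.timeout_ms <= 100
    time.sleep(0.15)
    assert t.expired and t.timeout_ms == 0

    # Timer(None) never expires and reports timeout_ms of None, so callers
    # fall back to their own defaults (see client_async.poll in PATCH 3/3).
    assert Timer(None).timeout_ms is None and not Timer(None).expired

    # The pattern that lets consumer.poll(timeout_ms=0) make progress:
    # run one full pass of work first, then check expiry.
    timer = Timer(0, error_message='Timed out waiting for records')
    while True:
        # ... one full poll/fetch pass, even though the timer is expired ...
        if timer.expired:
            break
    try:
        timer.maybe_raise()
    except KafkaTimeoutError:
        pass  # expired timers raise their error_message via maybe_raise()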