diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py index ceca1d9b6..16b3fbb68 100644 --- a/kafka/consumer/fetcher.py +++ b/kafka/consumer/fetcher.py @@ -153,6 +153,7 @@ def send_fetches(self): future = self._client.send(node_id, request, wakeup=False) future.add_callback(self._handle_fetch_response, node_id, fetch_offsets, time.time()) future.add_errback(self._handle_fetch_error, node_id) + future.add_both(self._clear_pending_fetch_request, node_id) futures.append(future) self._fetch_futures.extend(futures) self._clean_done_fetch_futures() @@ -643,36 +644,42 @@ def _create_fetch_requests(self): log.debug("Skipping fetch for partition %s because node %s is throttled", partition, node_id) + elif not self._client.ready(node_id): + # Until we support send request queues, any attempt to send to a not-ready node will be + # immediately failed with NodeNotReadyError. + log.debug("Skipping fetch for partition %s because connection to leader node is not ready yet", partition) + elif node_id in self._nodes_with_pending_fetch_requests: log.debug("Skipping fetch for partition %s because there is a pending fetch request to node %s", partition, node_id) - continue - if version < 5: - partition_info = ( - partition.partition, - position.offset, - self.config['max_partition_fetch_bytes'] - ) - elif version <= 8: - partition_info = ( - partition.partition, - position.offset, - -1, # log_start_offset is used internally by brokers / replicas only - self.config['max_partition_fetch_bytes'], - ) else: - partition_info = ( - partition.partition, - position.leader_epoch, - position.offset, - -1, # log_start_offset is used internally by brokers / replicas only - self.config['max_partition_fetch_bytes'], - ) - - fetchable[node_id][partition] = partition_info - log.debug("Adding fetch request for partition %s at offset %d", - partition, position.offset) + # Leader is connected and does not have a pending fetch request + if version < 5: + partition_info = ( + partition.partition, + 
position.offset, + self.config['max_partition_fetch_bytes'] + ) + elif version <= 8: + partition_info = ( + partition.partition, + position.offset, + -1, # log_start_offset is used internally by brokers / replicas only + self.config['max_partition_fetch_bytes'], + ) + else: + partition_info = ( + partition.partition, + position.leader_epoch, + position.offset, + -1, # log_start_offset is used internally by brokers / replicas only + self.config['max_partition_fetch_bytes'], + ) + + fetchable[node_id][partition] = partition_info + log.debug("Adding fetch request for partition %s at offset %d", + partition, position.offset) requests = {} for node_id, next_partitions in six.iteritems(fetchable): @@ -761,14 +768,18 @@ def _handle_fetch_response(self, node_id, fetch_offsets, send_time, response): if self._sensors: self._sensors.fetch_latency.record((time.time() - send_time) * 1000) - self._nodes_with_pending_fetch_requests.remove(node_id) def _handle_fetch_error(self, node_id, exception): level = logging.INFO if isinstance(exception, Errors.Cancelled) else logging.ERROR log.log(level, 'Fetch to node %s failed: %s', node_id, exception) if node_id in self._session_handlers: self._session_handlers[node_id].handle_error(exception) - self._nodes_with_pending_fetch_requests.remove(node_id) + + def _clear_pending_fetch_request(self, node_id, _): + try: + self._nodes_with_pending_fetch_requests.remove(node_id) + except KeyError: + pass def _parse_fetched_data(self, completed_fetch): tp = completed_fetch.topic_partition diff --git a/test/test_fetcher.py b/test/test_fetcher.py index 740fa1bab..f4e1f3f73 100644 --- a/test/test_fetcher.py +++ b/test/test_fetcher.py @@ -103,6 +103,7 @@ def test_create_fetch_requests(fetcher, mocker, api_version, fetch_version): fetcher._client._api_versions = BROKER_API_VERSIONS[api_version] mocker.patch.object(fetcher._client.cluster, "leader_for_partition", return_value=0) mocker.patch.object(fetcher._client.cluster, "leader_epoch_for_partition", 
return_value=0) + mocker.patch.object(fetcher._client, "ready", return_value=True) by_node = fetcher._create_fetch_requests() requests_and_offsets = by_node.values() assert set([r.API_VERSION for (r, _offsets) in requests_and_offsets]) == set([fetch_version])