Skip to content

Commit 4a685d9

Browse files
committed
Fix StreamConsumer wakeup races
1 parent e69c2aa commit 4a685d9

File tree

4 files changed

+117
-66
lines changed

4 files changed

+117
-66
lines changed

src/client.rs

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,21 @@ impl NativeClient {
198198
}
199199
}
200200

201+
pub(crate) enum EventPollResult<T> {
202+
None,
203+
EventConsumed,
204+
Event(T),
205+
}
206+
207+
impl<T> Into<Option<T>> for EventPollResult<T> {
208+
fn into(self) -> Option<T> {
209+
match self {
210+
EventPollResult::None | EventPollResult::EventConsumed => None,
211+
EventPollResult::Event(evt) => Some(evt),
212+
}
213+
}
214+
}
215+
201216
/// A low-level rdkafka client.
202217
///
203218
/// This type is the basis of the consumers and producers in the [`consumer`]
@@ -278,31 +293,42 @@ impl<C: ClientContext> Client<C> {
278293
&self.context
279294
}
280295

281-
pub(crate) fn poll_event(&self, queue: &NativeQueue, timeout: Timeout) -> Option<NativeEvent> {
296+
pub(crate) fn poll_event(
297+
&self,
298+
queue: &NativeQueue,
299+
timeout: Timeout,
300+
) -> EventPollResult<NativeEvent> {
282301
let event = unsafe { NativeEvent::from_ptr(queue.poll(timeout)) };
283302
if let Some(ev) = event {
284303
let evtype = unsafe { rdsys::rd_kafka_event_type(ev.ptr()) };
285304
match evtype {
286-
rdsys::RD_KAFKA_EVENT_LOG => self.handle_log_event(ev.ptr()),
287-
rdsys::RD_KAFKA_EVENT_STATS => self.handle_stats_event(ev.ptr()),
305+
rdsys::RD_KAFKA_EVENT_LOG => {
306+
self.handle_log_event(ev.ptr());
307+
return EventPollResult::EventConsumed;
308+
}
309+
rdsys::RD_KAFKA_EVENT_STATS => {
310+
self.handle_stats_event(ev.ptr());
311+
return EventPollResult::EventConsumed;
312+
}
288313
rdsys::RD_KAFKA_EVENT_ERROR => {
289314
// rdkafka reports consumer errors via RD_KAFKA_EVENT_ERROR but producer errors get
290315
// embedded on the ack returned via RD_KAFKA_EVENT_DR. Hence we need to return this event
291316
// for the consumer case in order to return the error to the user.
292317
self.handle_error_event(ev.ptr());
293-
return Some(ev);
318+
return EventPollResult::Event(ev);
294319
}
295320
rdsys::RD_KAFKA_EVENT_OAUTHBEARER_TOKEN_REFRESH => {
296321
if C::ENABLE_REFRESH_OAUTH_TOKEN {
297322
self.handle_oauth_refresh_event(ev.ptr());
298323
}
324+
return EventPollResult::EventConsumed;
299325
}
300326
_ => {
301-
return Some(ev);
327+
return EventPollResult::Event(ev);
302328
}
303329
}
304330
}
305-
None
331+
EventPollResult::None
306332
}
307333

308334
fn handle_log_event(&self, event: *mut RDKafkaEvent) {

src/consumer/base_consumer.rs

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use log::{error, warn};
1111
use rdkafka_sys as rdsys;
1212
use rdkafka_sys::types::*;
1313

14-
use crate::client::{Client, NativeClient, NativeQueue};
14+
use crate::client::{Client, EventPollResult, NativeClient, NativeQueue};
1515
use crate::config::{
1616
ClientConfig, FromClientConfig, FromClientConfigAndContext, NativeClientConfig,
1717
};
@@ -115,59 +115,70 @@ where
115115
///
116116
/// The returned message lives in the memory of the consumer and cannot outlive it.
117117
pub fn poll<T: Into<Timeout>>(&self, timeout: T) -> Option<KafkaResult<BorrowedMessage<'_>>> {
118-
self.poll_queue(self.get_queue(), timeout)
118+
self.poll_queue(self.get_queue(), timeout).into()
119119
}
120120

121121
pub(crate) fn poll_queue<T: Into<Timeout>>(
122122
&self,
123123
queue: &NativeQueue,
124124
timeout: T,
125-
) -> Option<KafkaResult<BorrowedMessage<'_>>> {
125+
) -> EventPollResult<KafkaResult<BorrowedMessage<'_>>> {
126126
let now = Instant::now();
127-
let mut timeout = timeout.into();
127+
let initial_timeout = timeout.into();
128+
let mut timeout = initial_timeout;
128129
let min_poll_interval = self.context().main_queue_min_poll_interval();
129130
loop {
130131
let op_timeout = std::cmp::min(timeout, min_poll_interval);
131132
let maybe_event = self.client().poll_event(queue, op_timeout);
132-
if let Some(event) = maybe_event {
133-
let evtype = unsafe { rdsys::rd_kafka_event_type(event.ptr()) };
134-
match evtype {
135-
rdsys::RD_KAFKA_EVENT_FETCH => {
136-
if let Some(result) = self.handle_fetch_event(event) {
137-
return Some(result);
133+
match maybe_event {
134+
EventPollResult::Event(event) => {
135+
let evtype = unsafe { rdsys::rd_kafka_event_type(event.ptr()) };
136+
match evtype {
137+
rdsys::RD_KAFKA_EVENT_FETCH => {
138+
if let Some(result) = self.handle_fetch_event(event) {
139+
return EventPollResult::Event(result);
140+
}
138141
}
139-
}
140-
rdsys::RD_KAFKA_EVENT_ERROR => {
141-
if let Some(err) = self.handle_error_event(event) {
142-
return Some(Err(err));
142+
rdsys::RD_KAFKA_EVENT_ERROR => {
143+
if let Some(err) = self.handle_error_event(event) {
144+
return EventPollResult::Event(Err(err));
145+
}
143146
}
144-
}
145-
rdsys::RD_KAFKA_EVENT_REBALANCE => {
146-
self.handle_rebalance_event(event);
147-
if timeout != Timeout::Never {
148-
return None;
147+
rdsys::RD_KAFKA_EVENT_REBALANCE => {
148+
self.handle_rebalance_event(event);
149+
if timeout != Timeout::Never {
150+
return EventPollResult::EventConsumed;
151+
}
149152
}
150-
}
151-
rdsys::RD_KAFKA_EVENT_OFFSET_COMMIT => {
152-
self.handle_offset_commit_event(event);
153-
if timeout != Timeout::Never {
154-
return None;
153+
rdsys::RD_KAFKA_EVENT_OFFSET_COMMIT => {
154+
self.handle_offset_commit_event(event);
155+
if timeout != Timeout::Never {
156+
return EventPollResult::EventConsumed;
157+
}
158+
}
159+
_ => {
160+
let buf = unsafe {
161+
let evname = rdsys::rd_kafka_event_name(event.ptr());
162+
CStr::from_ptr(evname).to_bytes()
163+
};
164+
let evname = String::from_utf8(buf.to_vec()).unwrap();
165+
warn!("Ignored event '{}' on consumer poll", evname);
155166
}
156167
}
157-
_ => {
158-
let evname = unsafe {
159-
let evname = rdsys::rd_kafka_event_name(event.ptr());
160-
CStr::from_ptr(evname).to_string_lossy()
161-
};
162-
warn!("Ignored event '{evname}' on consumer poll");
168+
}
169+
EventPollResult::None => {
170+
timeout = initial_timeout.saturating_sub(now.elapsed());
171+
if timeout.is_zero() {
172+
return EventPollResult::None;
163173
}
164174
}
165-
}
166-
167-
timeout = timeout.saturating_sub(now.elapsed());
168-
if timeout.is_zero() {
169-
return None;
170-
}
175+
EventPollResult::EventConsumed => {
176+
timeout = initial_timeout.saturating_sub(now.elapsed());
177+
if timeout.is_zero() {
178+
return EventPollResult::EventConsumed;
179+
}
180+
}
181+
};
171182
}
172183
}
173184

@@ -802,7 +813,7 @@ where
802813
/// associated consumer regularly, even if no messages are expected, to
803814
/// serve events.
804815
pub fn poll<T: Into<Timeout>>(&self, timeout: T) -> Option<KafkaResult<BorrowedMessage<'_>>> {
805-
self.consumer.poll_queue(&self.queue, timeout)
816+
self.consumer.poll_queue(&self.queue, timeout).into()
806817
}
807818

808819
/// Sets a callback that will be invoked whenever the queue becomes

src/consumer/stream_consumer.rs

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use slab::Slab;
1818
use rdkafka_sys as rdsys;
1919
use rdkafka_sys::types::*;
2020

21-
use crate::client::{Client, NativeQueue};
21+
use crate::client::{Client, EventPollResult, NativeQueue};
2222
use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext};
2323
use crate::consumer::base_consumer::{BaseConsumer, PartitionQueue};
2424
use crate::consumer::{
@@ -122,11 +122,12 @@ impl<'a, C: ConsumerContext> MessageStream<'a, C> {
122122
}
123123
}
124124

125-
fn poll(&self) -> Option<KafkaResult<BorrowedMessage<'a>>> {
125+
fn poll(&self) -> EventPollResult<KafkaResult<BorrowedMessage<'a>>> {
126126
if let Some(queue) = self.partition_queue {
127127
self.consumer.poll_queue(queue, Duration::ZERO)
128128
} else {
129-
self.consumer.poll(Duration::ZERO)
129+
self.consumer
130+
.poll_queue(self.consumer.get_queue(), Duration::ZERO)
130131
}
131132
}
132133
}
@@ -135,25 +136,38 @@ impl<'a, C: ConsumerContext> Stream for MessageStream<'a, C> {
135136
type Item = KafkaResult<BorrowedMessage<'a>>;
136137

137138
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138-
// If there is a message ready, yield it immediately to avoid
139-
// taking the lock in `self.set_waker`.
140-
if let Some(message) = self.poll() {
141-
return Poll::Ready(Some(message));
142-
}
143-
144-
// Otherwise, we need to wait for a message to become available. Store
145-
// the waker so that we are woken up if the queue flips from empty
146-
// to non-empty. We have to store the waker repeatedly in case this future
147-
// migrates between tasks.
148-
self.wakers.set_waker(self.slot, cx.waker().clone());
149-
150-
// Check whether a new message became available after we installed the
151-
// waker. This avoids a race where `poll` returns None to indicate that
152-
// the queue is empty, but the queue becomes non-empty before we've
153-
// installed the waker.
154139
match self.poll() {
155-
None => Poll::Pending,
156-
Some(message) => Poll::Ready(Some(message)),
140+
EventPollResult::Event(message) => {
141+
// If there is a message ready, yield it immediately to avoid
142+
// taking the lock in `self.set_waker`.
143+
Poll::Ready(Some(message))
144+
}
145+
EventPollResult::EventConsumed => {
146+
// Event was consumed, yield to runtime
147+
cx.waker().wake_by_ref();
148+
Poll::Pending
149+
}
150+
EventPollResult::None => {
151+
// Otherwise, we need to wait for a message to become available. Store
152+
// the waker so that we are woken up if the queue flips from empty
153+
// to non-empty. We have to store the waker repeatedly in case this future
154+
// migrates between tasks.
155+
self.wakers.set_waker(self.slot, cx.waker().clone());
156+
157+
// Check whether a new message became available after we installed the
158+
// waker. This avoids a race where `poll` returns None to indicate that
159+
// the queue is empty, but the queue becomes non-empty before we've
160+
// installed the waker.
161+
match self.poll() {
162+
EventPollResult::Event(message) => Poll::Ready(Some(message)),
163+
EventPollResult::EventConsumed => {
164+
// Event was consumed, yield to runtime
165+
cx.waker().wake_by_ref();
166+
Poll::Pending
167+
}
168+
EventPollResult::None => Poll::Pending,
169+
}
170+
}
157171
}
158172
}
159173
}

src/producer/base_producer.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ use rdkafka_sys as rdsys;
5757
use rdkafka_sys::rd_kafka_vtype_t::*;
5858
use rdkafka_sys::types::*;
5959

60-
use crate::client::{Client, NativeQueue};
60+
use crate::client::{Client, EventPollResult, NativeQueue};
6161
use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext};
6262
use crate::consumer::ConsumerGroupMetadata;
6363
use crate::error::{IsError, KafkaError, KafkaResult, RDKafkaError};
@@ -363,7 +363,7 @@ where
363363
/// the message delivery callbacks.
364364
pub fn poll<T: Into<Timeout>>(&self, timeout: T) {
365365
let event = self.client().poll_event(&self.queue, timeout.into());
366-
if let Some(ev) = event {
366+
if let EventPollResult::Event(ev) = event {
367367
let evtype = unsafe { rdsys::rd_kafka_event_type(ev.ptr()) };
368368
match evtype {
369369
rdsys::RD_KAFKA_EVENT_DR => self.handle_delivery_report_event(ev),

0 commit comments

Comments
 (0)