diff --git a/src/producer/future_producer.rs b/src/producer/future_producer.rs index fc7b90390..ba1ae93c3 100644 --- a/src/producer/future_producer.rs +++ b/src/producer/future_producer.rs @@ -33,6 +33,22 @@ use super::Partitioner; // ********** FUTURE PRODUCER ********** // +///Default implementation for FutureProducerContext::delivery +pub fn future_delivery_impl( + delivery_result: &DeliveryResult<'_>, + tx: oneshot::Sender, +) { + let owned_delivery_result = match *delivery_result { + Ok(ref message) => Ok(Delivery { + partition: message.partition(), + offset: message.offset(), + timestamp: message.timestamp(), + }), + Err((ref error, ref message)) => Err((error.clone(), message.detach())), + }; + let _ = tx.send(owned_delivery_result); // TODO: handle error +} + /// A record for the future producer. /// /// Like [`BaseRecord`], but specific to the [`FutureProducer`]. The only @@ -186,20 +202,9 @@ where { type DeliveryOpaque = Box>; - fn delivery( - &self, - delivery_result: &DeliveryResult<'_>, - tx: Box>, - ) { - let owned_delivery_result = match *delivery_result { - Ok(ref message) => Ok(Delivery { - partition: message.partition(), - offset: message.offset(), - timestamp: message.timestamp(), - }), - Err((ref error, ref message)) => Err((error.clone(), message.detach())), - }; - let _ = tx.send(owned_delivery_result); // TODO: handle error + #[inline] + fn delivery(&self, delivery_result: &DeliveryResult<'_>, tx: Self::DeliveryOpaque) { + future_delivery_impl(delivery_result, *tx) } } @@ -215,18 +220,21 @@ where /// underlying producer. The internal polling thread will be terminated when the /// `FutureProducer` goes out of scope. #[must_use = "Producer polling thread will stop immediately if unused"] -pub struct FutureProducer -where +pub struct FutureProducer< + C = FutureProducerContext, + R = DefaultRuntime, + Part = NoCustomPartitioner, +> where Part: Partitioner, - C: ClientContext + 'static, + C: ProducerContext + 'static, { - producer: Arc, Part>>, + producer: Arc>, _runtime: PhantomData, } impl Clone for FutureProducer where - C: ClientContext + 'static, + C: ProducerContext + 'static, { fn clone(&self) -> FutureProducer { FutureProducer { @@ -236,28 +244,30 @@ where } } -impl FromClientConfig for FutureProducer +impl FromClientConfig for FutureProducer, R> where R: AsyncRuntime, { - fn from_config(config: &ClientConfig) -> KafkaResult> { - FutureProducer::from_config_and_context(config, DefaultClientContext) + fn from_config( + config: &ClientConfig, + ) -> KafkaResult, R>> { + let context = FutureProducerContext { + wrapped_context: DefaultClientContext, + }; + FutureProducer::from_config_and_context(config, context) } } impl FromClientConfigAndContext for FutureProducer where - C: ClientContext + 'static, + C: ProducerContext + 'static, R: AsyncRuntime, { fn from_config_and_context( config: &ClientConfig, context: C, ) -> KafkaResult> { - let future_context = FutureProducerContext { - wrapped_context: context, - }; - let threaded_producer = ThreadedProducer::from_config_and_context(config, future_context)?; + let threaded_producer = ThreadedProducer::from_config_and_context(config, context)?; Ok(FutureProducer { producer: Arc::new(threaded_producer), _runtime: PhantomData, @@ -283,9 +293,26 @@ impl Future for DeliveryFuture { } } -impl FutureProducer +/// Creates `FutureProducer` with customized `ProducerContext` +pub fn custom_future_producer< + P: Partitioner + Send + Sync + 'static, + C: ProducerContext>> + 'static, + R: AsyncRuntime, +>( + config: &ClientConfig, + context: C, +) -> KafkaResult> { + let threaded_producer = ThreadedProducer::::from_config_and_context(config, context)?; + Ok(FutureProducer { + producer: Arc::new(threaded_producer), + _runtime: PhantomData, + }) +} + +impl FutureProducer where - C: ClientContext + 'static, + Part: Partitioner, + C: ProducerContext>> + 'static, R: AsyncRuntime, { /// Sends a message to Kafka, returning the result of the send. @@ -385,13 +412,13 @@ where } } -impl Producer, Part> for FutureProducer +impl Producer for FutureProducer where - C: ClientContext + 'static, - R: AsyncRuntime, Part: Partitioner, + C: ProducerContext + 'static, + R: AsyncRuntime, { - fn client(&self) -> &Client> { + fn client(&self) -> &Client { self.producer.client() } diff --git a/tests/test_high_producers.rs b/tests/test_high_producers.rs index 9a71c9981..7f553d514 100644 --- a/tests/test_high_producers.rs +++ b/tests/test_high_producers.rs @@ -5,7 +5,6 @@ use std::time::{Duration, Instant}; use futures::stream::{FuturesUnordered, StreamExt}; -use rdkafka::client::DefaultClientContext; use rdkafka::config::ClientConfig; use rdkafka::error::{KafkaError, RDKafkaErrorCode}; use rdkafka::message::{Header, Headers, Message, OwnedHeaders}; @@ -17,7 +16,7 @@ use crate::utils::*; mod utils; -fn future_producer(config_overrides: HashMap<&str, &str>) -> FutureProducer { +fn future_producer(config_overrides: HashMap<&str, &str>) -> FutureProducer { let mut config = ClientConfig::new(); config .set("bootstrap.servers", "localhost") diff --git a/tests/utils.rs b/tests/utils.rs index 10ab34cf5..e84ea5a1f 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -14,7 +14,8 @@ use rdkafka::config::ClientConfig; use rdkafka::consumer::ConsumerContext; use rdkafka::error::KafkaResult; use rdkafka::message::ToBytes; -use rdkafka::producer::{FutureProducer, FutureRecord}; +use rdkafka::producer::future_producer::{future_delivery_impl, OwnedDeliveryResult}; +use rdkafka::producer::{FutureProducer, FutureRecord, ProducerContext}; use rdkafka::statistics::Statistics; use rdkafka::TopicPartitionList; @@ -74,6 +75,18 @@ impl ClientContext for ProducerTestContext { fn stats(&self, _: Statistics) {} // Don't print stats } +impl ProducerContext for ProducerTestContext { + type DeliveryOpaque = Box>; + + fn delivery( + &self, + delivery_result: &rdkafka::message::DeliveryResult<'_>, + delivery_opaque: Self::DeliveryOpaque, + ) { + future_delivery_impl(delivery_result, *delivery_opaque); + } +} + pub async fn create_topic(name: &str, partitions: i32) { let client: AdminClient<_> = consumer_config("create_topic", None).create().unwrap(); client