Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 54 additions & 48 deletions rocketmq-client/src/producer/producer_impl/default_mq_producer_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ use crate::producer::request_response_future::RequestResponseFuture;
use crate::producer::send_callback::SendMessageCallback;
use crate::producer::send_result::SendResult;
use crate::producer::send_status::SendStatus;
use crate::producer::transaction_listener::TransactionListener;
use crate::producer::transaction_listener::ArcTransactionListener;
use crate::producer::transaction_send_result::TransactionSendResult;
use tokio::task::JoinHandle;

Expand Down Expand Up @@ -266,7 +266,7 @@ pub struct DefaultMQProducerImpl {
semaphore_async_send_num: Arc<Semaphore>,
semaphore_async_send_size: Arc<Semaphore>,
default_mqproducer_impl_inner: Option<WeakArcMut<DefaultMQProducerImpl>>,
transaction_listener: Option<Arc<Box<dyn TransactionListener>>>,
transaction_listener: Option<ArcTransactionListener>,
}

#[allow(unused_must_use)]
Expand Down Expand Up @@ -2190,7 +2190,7 @@ impl DefaultMQProducerImpl {
self.default_mqproducer_impl_inner = Some(default_mqproducer_impl_inner);
}

pub fn set_transaction_listener(&mut self, transaction_listener: Arc<Box<dyn TransactionListener>>) {
pub fn set_transaction_listener(&mut self, transaction_listener: ArcTransactionListener) {
self.transaction_listener = Some(transaction_listener);
}
}
Expand All @@ -2210,7 +2210,7 @@ impl MQProducerInner for DefaultMQProducerImpl {
true
}

fn get_check_listener(&self) -> Option<Arc<Box<dyn TransactionListener>>> {
fn get_check_listener(&self) -> Option<ArcTransactionListener> {
self.transaction_listener.clone()
}

Expand All @@ -2235,16 +2235,17 @@ impl MQProducerInner for DefaultMQProducerImpl {
let broker_addr = broker_addr.clone();
let group = self.producer_config.producer_group().clone();

// Spawn independent task without storing handle (matches Java's executor.submit behavior)
tokio::spawn(async move {
// Use spawn_blocking to avoid blocking Tokio worker threads (matches Java's ExecutorService
// behavior)
tokio::task::spawn_blocking(move || {
let mut unique_key = msg.property(&CheetahString::from_static_str(
MessageConst::PROPERTY_UNIQ_CLIENT_MESSAGE_ID_KEYIDX,
));
if unique_key.is_none() {
unique_key = Some(msg.msg_id.clone());
}

// Check local transaction state with exception handling
// Check local transaction state with exception handling (synchronous execution)
let transaction_state = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
transaction_listener.check_local_transaction(&msg)
})) {
Expand All @@ -2258,48 +2259,53 @@ impl MQProducerInner for DefaultMQProducerImpl {
LocalTransactionState::Unknown
}
};
let request_header = EndTransactionRequestHeader {
topic: check_request_header.topic.clone().unwrap_or_default(),
producer_group: CheetahString::from_string(
producer_impl_inner.producer_config.producer_group().to_string(),
),
tran_state_table_offset: check_request_header.commit_log_offset as u64,
commit_log_offset: check_request_header.commit_log_offset as u64,
commit_or_rollback: match transaction_state {
LocalTransactionState::CommitMessage => MessageSysFlag::TRANSACTION_COMMIT_TYPE,
LocalTransactionState::RollbackMessage => MessageSysFlag::TRANSACTION_ROLLBACK_TYPE,
LocalTransactionState::Unknown => MessageSysFlag::TRANSACTION_NOT_TYPE,
},
from_transaction_check: true,
msg_id: unique_key.clone().unwrap_or_default(),
transaction_id: check_request_header.transaction_id.clone(),
rpc_request_header: RpcRequestHeader {
broker_name: check_request_header.rpc_request_header.unwrap_or_default().broker_name,
..Default::default()
},
};
// Execute end transaction hook
producer_impl_inner.do_execute_end_transaction_hook(
&msg.message,
unique_key.as_ref().unwrap(),
&broker_addr,
transaction_state,
true,
);

// Send end transaction request with error handling
if let Err(e) = producer_impl_inner
.client_instance
.as_mut()
.unwrap()
.mq_client_api_impl
.as_mut()
.unwrap()
.end_transaction_oneway(&broker_addr, request_header, CheetahString::from_static_str(""), 3000)
.await
{
tracing::error!("endTransactionOneway exception: {:?}", e);
}
// Switch back to async context for network I/O
let handle = tokio::runtime::Handle::current();
handle.spawn(async move {
let request_header = EndTransactionRequestHeader {
topic: check_request_header.topic.clone().unwrap_or_default(),
producer_group: CheetahString::from_string(
producer_impl_inner.producer_config.producer_group().to_string(),
),
tran_state_table_offset: check_request_header.commit_log_offset as u64,
commit_log_offset: check_request_header.commit_log_offset as u64,
commit_or_rollback: match transaction_state {
LocalTransactionState::CommitMessage => MessageSysFlag::TRANSACTION_COMMIT_TYPE,
LocalTransactionState::RollbackMessage => MessageSysFlag::TRANSACTION_ROLLBACK_TYPE,
LocalTransactionState::Unknown => MessageSysFlag::TRANSACTION_NOT_TYPE,
},
from_transaction_check: true,
msg_id: unique_key.clone().unwrap_or_default(),
transaction_id: check_request_header.transaction_id.clone(),
rpc_request_header: RpcRequestHeader {
broker_name: check_request_header.rpc_request_header.unwrap_or_default().broker_name,
..Default::default()
},
};
// Execute end transaction hook
producer_impl_inner.do_execute_end_transaction_hook(
&msg.message,
unique_key.as_ref().unwrap(),
&broker_addr,
transaction_state,
true,
);

// Send end transaction request with error handling
if let Err(e) = producer_impl_inner
.client_instance
.as_mut()
.unwrap()
.mq_client_api_impl
.as_mut()
.unwrap()
.end_transaction_oneway(&broker_addr, request_header, CheetahString::from_static_str(""), 3000)
.await
{
tracing::error!("endTransactionOneway exception: {:?}", e);
}
});
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

use std::collections::HashSet;
use std::sync::Arc;

use cheetah_string::CheetahString;
use rocketmq_common::common::message::message_ext::MessageExt;
Expand All @@ -22,14 +21,14 @@ use rocketmq_rust::ArcMut;

use crate::producer::producer_impl::default_mq_producer_impl::DefaultMQProducerImpl;
use crate::producer::producer_impl::topic_publish_info::TopicPublishInfo;
use crate::producer::transaction_listener::TransactionListener;
use crate::producer::transaction_listener::ArcTransactionListener;

pub trait MQProducerInner: Send + Sync + 'static {
fn get_publish_topic_list(&self) -> HashSet<CheetahString>;

fn is_publish_topic_need_update(&self, topic: &CheetahString) -> bool;

fn get_check_listener(&self) -> Option<Arc<Box<dyn TransactionListener>>>;
fn get_check_listener(&self) -> Option<ArcTransactionListener>;

fn check_transaction_state(
&self,
Expand Down Expand Up @@ -63,7 +62,7 @@ impl MQProducerInnerImpl {
false
}

pub fn get_check_listener(&self) -> Option<Arc<Box<dyn TransactionListener>>> {
pub fn get_check_listener(&self) -> Option<ArcTransactionListener> {
if let Some(default_mqproducer_impl_inner) = &self.default_mqproducer_impl_inner {
return default_mqproducer_impl_inner.get_check_listener();
}
Expand Down
112 changes: 112 additions & 0 deletions rocketmq-client/src/producer/transaction_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,130 @@
// limitations under the License.

use std::any::Any;
use std::sync::Arc;

use rocketmq_common::common::message::message_ext::MessageExt;
use rocketmq_common::common::message::MessageTrait;

use crate::producer::local_transaction_state::LocalTransactionState;

/// Listener for handling transactional message operations.
///
/// This trait defines the callback interface for transactional message processing,
/// allowing applications to integrate local transaction execution with distributed
/// message transactions.
///
/// Implementations must be thread-safe as they may be invoked concurrently from
/// different threads or async tasks.
///
/// # Examples
///
/// ```ignore
/// use rocketmq_client_rust::producer::transaction_listener::TransactionListener;
/// use rocketmq_client_rust::producer::local_transaction_state::LocalTransactionState;
/// use rocketmq_common::common::message::MessageTrait;
/// use rocketmq_common::common::message::message_ext::MessageExt;
/// use std::any::Any;
///
/// struct OrderTransactionListener;
///
/// impl TransactionListener for OrderTransactionListener {
/// fn execute_local_transaction(
/// &self,
/// msg: &dyn MessageTrait,
/// arg: Option<&(dyn Any + Send + Sync)>,
/// ) -> LocalTransactionState {
/// // Execute local database transaction
/// if insert_order_into_db(msg).is_ok() {
/// LocalTransactionState::CommitMessage
/// } else {
/// LocalTransactionState::RollbackMessage
/// }
/// }
///
/// fn check_local_transaction(&self, msg: &MessageExt) -> LocalTransactionState {
/// // Check transaction status from database
/// if order_exists_in_db(msg) {
/// LocalTransactionState::CommitMessage
/// } else {
/// LocalTransactionState::Unknown
/// }
/// }
/// }
/// ```
pub trait TransactionListener: Send + Sync + 'static {
/// Executes the local transaction when sending a transactional message.
///
/// This method is invoked after the half message is successfully sent to the broker.
/// The implementation should execute the local business transaction and return
/// the transaction state to determine whether the message should be committed or rolled back.
///
/// # Parameters
///
/// * `msg` - The message being sent
/// * `arg` - Optional user-defined argument passed from the send operation
///
/// # Returns
///
/// The local transaction state indicating whether to commit, rollback, or defer the decision:
/// - `CommitMessage` - Commit the transaction and make the message visible to consumers
/// - `RollbackMessage` - Roll back the transaction and discard the message
/// - `Unknown` - Transaction state is uncertain, broker will check later
fn execute_local_transaction(
&self,
msg: &dyn MessageTrait,
arg: Option<&(dyn Any + Send + Sync)>,
) -> LocalTransactionState;

/// Checks the status of a previously executed local transaction.
///
/// This method is invoked by the broker when it needs to verify the state of a
/// transaction whose initial state was `Unknown` or when the transaction check
/// timeout is reached.
///
/// The implementation should query the local transaction state (e.g., from a database)
/// and return the current status.
///
/// # Parameters
///
/// * `msg` - The message whose transaction status needs to be checked
///
/// # Returns
///
/// The current state of the local transaction:
/// - `CommitMessage` - The local transaction was committed successfully
/// - `RollbackMessage` - The local transaction failed or was rolled back
/// - `Unknown` - Transaction state cannot be determined at this time
fn check_local_transaction(&self, msg: &MessageExt) -> LocalTransactionState;
}

/// Thread-safe shared reference to a [`TransactionListener`].
///
/// This type alias provides a convenient way to share transaction listeners across
/// threads using atomic reference counting. It uses `Arc<dyn TransactionListener>`
/// instead of `Arc<Box<dyn TransactionListener>>` to avoid double heap allocation
/// and minimize pointer indirection overhead.
///
/// # Performance
///
/// Using this alias instead of `Arc<Box<dyn TransactionListener>>` provides:
/// - One fewer heap allocation per instance
/// - One fewer pointer dereference per method call
/// - Reduced memory overhead (saves approximately 16 bytes per instance)
///
/// # Examples
///
/// ```ignore
/// use rocketmq_client_rust::producer::transaction_listener::{TransactionListener, ArcTransactionListener};
/// use std::sync::Arc;
///
/// struct MyListener;
/// impl TransactionListener for MyListener { /* ... */ }
///
/// // Create a shared reference
/// let listener: ArcTransactionListener = Arc::new(MyListener);
///
/// // Clone for sharing across threads
/// let listener_clone = listener.clone();
/// ```
pub type ArcTransactionListener = Arc<dyn TransactionListener>;
43 changes: 38 additions & 5 deletions rocketmq-client/src/producer/transaction_mq_produce_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::base::client_config::ClientConfig;
use crate::producer::default_mq_producer::DefaultMQProducer;
use crate::producer::produce_accumulator::ProduceAccumulator;
use crate::producer::producer_impl::default_mq_producer_impl::DefaultMQProducerImpl;
use crate::producer::transaction_listener::ArcTransactionListener;
use crate::producer::transaction_listener::TransactionListener;
use crate::producer::transaction_mq_producer::TransactionMQProducer;
use crate::producer::transaction_mq_producer::TransactionProducerConfig;
Expand Down Expand Up @@ -55,7 +56,10 @@ pub struct TransactionMQProducerBuilder {
compress_level: Option<i32>,
compress_type: Option<CompressionType>,
compressor: Option<&'static (dyn Compressor + Send + Sync)>,
transaction_listener: Option<Arc<Box<dyn TransactionListener>>>,
transaction_listener: Option<ArcTransactionListener>,
check_thread_pool_min_size: Option<u32>,
check_thread_pool_max_size: Option<u32>,
check_request_hold_max: Option<u32>,
check_runtime: Option<Arc<RocketMQRuntime>>,
}

Expand Down Expand Up @@ -86,6 +90,9 @@ impl TransactionMQProducerBuilder {
compress_type: None,
compressor: None,
transaction_listener: None,
check_thread_pool_min_size: None,
check_thread_pool_max_size: None,
check_request_hold_max: None,
check_runtime: None,
}
}
Expand Down Expand Up @@ -216,7 +223,33 @@ impl TransactionMQProducerBuilder {
}

pub fn transaction_listener(mut self, transaction_listener: impl TransactionListener) -> Self {
self.transaction_listener = Some(Arc::new(Box::new(transaction_listener)));
self.transaction_listener = Some(Arc::new(transaction_listener));
self
}

/// Set maximum size of transaction check thread pool
///
/// Note: When using default Tokio Runtime with spawn_blocking, this serves as a reference.
/// To control actual thread count, configure the Runtime:
/// ```ignore
/// tokio::runtime::Builder::new_multi_thread()
/// .max_blocking_threads(100)
/// .build()
/// ```
pub fn check_thread_pool_max_size(mut self, size: u32) -> Self {
self.check_thread_pool_max_size = Some(size);
self
}

/// Set minimum size of transaction check thread pool
pub fn check_thread_pool_min_size(mut self, size: u32) -> Self {
self.check_thread_pool_min_size = Some(size);
self
}

/// Set maximum capacity of transaction check request queue
pub fn check_request_hold_max(mut self, size: u32) -> Self {
self.check_request_hold_max = Some(size);
self
}

Expand Down Expand Up @@ -304,9 +337,9 @@ impl TransactionMQProducerBuilder {
}
let transaction_producer_config = TransactionProducerConfig {
transaction_listener: self.transaction_listener,
check_thread_pool_min_size: 0,
check_thread_pool_max_size: 0,
check_request_hold_max: 0,
check_thread_pool_min_size: self.check_thread_pool_min_size.unwrap_or(1),
check_thread_pool_max_size: self.check_thread_pool_max_size.unwrap_or(1),
check_request_hold_max: self.check_request_hold_max.unwrap_or(2000),
};
TransactionMQProducer::new(transaction_producer_config, mq_producer)
}
Expand Down
Loading
Loading