Skip to content

Commit c2eb321

Browse files
authored
RUST-107 Convenient transactions (#849)
1 parent 204cbb6 commit c2eb321

27 files changed

+6780
-247
lines changed

.evergreen/MSRV-Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/client/session/mod.rs

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ pub struct ClientSession {
110110
pub(crate) transaction: Transaction,
111111
pub(crate) snapshot_time: Option<Timestamp>,
112112
pub(crate) operation_time: Option<Timestamp>,
113+
#[cfg(test)]
114+
pub(crate) convenient_transaction_timeout: Option<Duration>,
113115
}
114116

115117
#[derive(Debug)]
@@ -216,6 +218,8 @@ impl ClientSession {
216218
transaction: Default::default(),
217219
snapshot_time: None,
218220
operation_time: None,
221+
#[cfg(test)]
222+
convenient_transaction_timeout: None,
219223
}
220224
}
221225

@@ -561,13 +565,117 @@ impl ClientSession {
561565
}
562566
}
563567

568+
/// Starts a transaction, runs the given callback, and commits or aborts the transaction.
569+
/// Transient transaction errors will cause the callback or the commit to be retried;
570+
/// other errors will cause the transaction to be aborted and the error returned to the
571+
/// caller. If the callback needs to provide its own error information, the
572+
/// [`Error::custom`](crate::error::Error::custom) method can accept an arbitrary payload that
573+
/// can be retrieved via [`Error::get_custom`](crate::error::Error::get_custom).
574+
///
575+
/// Because the callback can be repeatedly executed and because it returns a future, the rust
576+
/// closure borrowing rules for captured values can be overly restrictive. As a
577+
/// convenience, `with_transaction` accepts a context argument that will be passed to the
578+
/// callback along with the session:
579+
///
580+
/// ```no_run
581+
/// # use mongodb::{bson::{doc, Document}, error::Result, Client};
582+
/// # use futures::FutureExt;
583+
/// # async fn wrapper() -> Result<()> {
584+
/// # let client = Client::with_uri_str("mongodb://example.com").await?;
585+
/// # let mut session = client.start_session(None).await?;
586+
/// let coll = client.database("mydb").collection::<Document>("mycoll");
587+
/// let my_data = "my data".to_string();
588+
/// // This works:
589+
/// session.with_transaction(
590+
/// (&coll, &my_data),
591+
/// |session, (coll, my_data)| async move {
592+
/// coll.insert_one_with_session(doc! { "data": *my_data }, None, session).await
593+
/// }.boxed(),
594+
/// None,
595+
/// ).await?;
596+
/// /* This will not compile with a "variable moved due to use in generator" error:
597+
/// session.with_transaction(
598+
/// (),
599+
/// |session, _| async move {
600+
/// coll.insert_one_with_session(doc! { "data": my_data }, None, session).await
601+
/// }.boxed(),
602+
/// None,
603+
/// ).await?;
604+
/// */
605+
/// # Ok(())
606+
/// # }
607+
/// ```
608+
pub async fn with_transaction<R, C, F>(
609+
&mut self,
610+
mut context: C,
611+
mut callback: F,
612+
options: impl Into<Option<TransactionOptions>>,
613+
) -> Result<R>
614+
where
615+
F: for<'a> FnMut(&'a mut ClientSession, &'a mut C) -> BoxFuture<'a, Result<R>>,
616+
{
617+
let options = options.into();
618+
let timeout = Duration::from_secs(120);
619+
#[cfg(test)]
620+
let timeout = self.convenient_transaction_timeout.unwrap_or(timeout);
621+
let start = Instant::now();
622+
623+
use crate::error::{TRANSIENT_TRANSACTION_ERROR, UNKNOWN_TRANSACTION_COMMIT_RESULT};
624+
625+
'transaction: loop {
626+
self.start_transaction(options.clone()).await?;
627+
let ret = match callback(self, &mut context).await {
628+
Ok(v) => v,
629+
Err(e) => {
630+
if matches!(
631+
self.transaction.state,
632+
TransactionState::Starting | TransactionState::InProgress
633+
) {
634+
self.abort_transaction().await?;
635+
}
636+
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) && start.elapsed() < timeout {
637+
continue 'transaction;
638+
}
639+
return Err(e);
640+
}
641+
};
642+
if matches!(
643+
self.transaction.state,
644+
TransactionState::None
645+
| TransactionState::Aborted
646+
| TransactionState::Committed { .. }
647+
) {
648+
return Ok(ret);
649+
}
650+
'commit: loop {
651+
match self.commit_transaction().await {
652+
Ok(()) => return Ok(ret),
653+
Err(e) => {
654+
if e.is_max_time_ms_expired_error() || start.elapsed() >= timeout {
655+
return Err(e);
656+
}
657+
if e.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) {
658+
continue 'commit;
659+
}
660+
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) {
661+
continue 'transaction;
662+
}
663+
return Err(e);
664+
}
665+
}
666+
}
667+
}
668+
}
669+
564670
fn default_transaction_options(&self) -> Option<&TransactionOptions> {
565671
self.options
566672
.as_ref()
567673
.and_then(|options| options.default_transaction_options.as_ref())
568674
}
569675
}
570676

677+
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
678+
571679
struct DroppedClientSession {
572680
cluster_time: Option<ClusterTime>,
573681
server_session: ServerSession,
@@ -590,6 +698,8 @@ impl From<DroppedClientSession> for ClientSession {
590698
transaction: dropped_session.transaction,
591699
snapshot_time: dropped_session.snapshot_time,
592700
operation_time: dropped_session.operation_time,
701+
#[cfg(test)]
702+
convenient_transaction_timeout: None,
593703
}
594704
}
595705
}

src/error.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Contains the `Error` and `Result` types that `mongodb` uses.
22
33
use std::{
4+
any::Any,
45
collections::{HashMap, HashSet},
56
fmt::{self, Debug},
67
sync::Arc,
@@ -52,6 +53,22 @@ pub struct Error {
5253
}
5354

5455
impl Error {
56+
/// Create a new `Error` wrapping an arbitrary value. Can be used to abort transactions in
57+
/// callbacks for [`ClientSession::with_transaction`](crate::ClientSession::with_transaction).
58+
pub fn custom(e: impl Any + Send + Sync) -> Self {
59+
Self::new(ErrorKind::Custom(Arc::new(e)), None::<Option<String>>)
60+
}
61+
62+
/// Retrieve a reference to a value provided to `Error::custom`. Returns `None` if this is not
63+
/// a custom error or if the payload types mismatch.
64+
pub fn get_custom<E: Any>(&self) -> Option<&E> {
65+
if let ErrorKind::Custom(c) = &*self.kind {
66+
c.downcast_ref()
67+
} else {
68+
None
69+
}
70+
}
71+
5572
pub(crate) fn new(kind: ErrorKind, labels: Option<impl IntoIterator<Item = String>>) -> Self {
5673
let mut labels: HashSet<String> = labels
5774
.map(|labels| labels.into_iter().collect())
@@ -140,6 +157,10 @@ impl Error {
140157
matches!(self.kind.as_ref(), ErrorKind::ServerSelection { .. })
141158
}
142159

160+
pub(crate) fn is_max_time_ms_expired_error(&self) -> bool {
161+
self.code() == Some(50)
162+
}
163+
143164
/// Whether a read operation should be retried if this error occurs.
144165
pub(crate) fn is_read_retryable(&self) -> bool {
145166
if self.is_network_error() {
@@ -423,6 +444,7 @@ impl Error {
423444
| ErrorKind::IncompatibleServer { .. }
424445
| ErrorKind::MissingResumeToken
425446
| ErrorKind::Authentication { .. }
447+
| ErrorKind::Custom(_)
426448
| ErrorKind::GridFs(_) => {}
427449
#[cfg(feature = "in-use-encryption-unstable")]
428450
ErrorKind::Encryption(_) => {}
@@ -578,6 +600,10 @@ pub enum ErrorKind {
578600
#[cfg(feature = "in-use-encryption-unstable")]
579601
#[error("An error occurred during client-side encryption: {0}")]
580602
Encryption(mongocrypt::error::Error),
603+
604+
/// A custom value produced by user code.
605+
#[error("Custom user error")]
606+
Custom(Arc<dyn Any + Send + Sync>),
581607
}
582608

583609
impl ErrorKind {

src/sync/client/session.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,72 @@ impl ClientSession {
135135
pub fn abort_transaction(&mut self) -> Result<()> {
136136
runtime::block_on(self.async_client_session.abort_transaction())
137137
}
138+
139+
/// Starts a transaction, runs the given callback, and commits or aborts the transaction.
140+
/// Transient transaction errors will cause the callback or the commit to be retried;
141+
/// other errors will cause the transaction to be aborted and the error returned to the
142+
/// caller. If the callback needs to provide its own error information, the
143+
/// [`Error::custom`](crate::error::Error::custom) method can accept an arbitrary payload that
144+
/// can be retrieved via [`Error::get_custom`](crate::error::Error::get_custom).
145+
pub fn with_transaction<R, F>(
146+
&mut self,
147+
mut callback: F,
148+
options: impl Into<Option<TransactionOptions>>,
149+
) -> Result<R>
150+
where
151+
F: for<'a> FnMut(&'a mut ClientSession) -> Result<R>,
152+
{
153+
let options = options.into();
154+
let timeout = std::time::Duration::from_secs(120);
155+
let start = std::time::Instant::now();
156+
157+
use crate::{
158+
client::session::TransactionState,
159+
error::{TRANSIENT_TRANSACTION_ERROR, UNKNOWN_TRANSACTION_COMMIT_RESULT},
160+
};
161+
162+
'transaction: loop {
163+
self.start_transaction(options.clone())?;
164+
let ret = match callback(self) {
165+
Ok(v) => v,
166+
Err(e) => {
167+
if matches!(
168+
self.async_client_session.transaction.state,
169+
TransactionState::Starting | TransactionState::InProgress
170+
) {
171+
self.abort_transaction()?;
172+
}
173+
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) && start.elapsed() < timeout {
174+
continue 'transaction;
175+
}
176+
return Err(e);
177+
}
178+
};
179+
if matches!(
180+
self.async_client_session.transaction.state,
181+
TransactionState::None
182+
| TransactionState::Aborted
183+
| TransactionState::Committed { .. }
184+
) {
185+
return Ok(ret);
186+
}
187+
'commit: loop {
188+
match self.commit_transaction() {
189+
Ok(()) => return Ok(ret),
190+
Err(e) => {
191+
if e.is_max_time_ms_expired_error() || start.elapsed() >= timeout {
192+
return Err(e);
193+
}
194+
if e.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) {
195+
continue 'commit;
196+
}
197+
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) {
198+
continue 'transaction;
199+
}
200+
return Err(e);
201+
}
202+
}
203+
}
204+
}
205+
}
138206
}

0 commit comments

Comments
 (0)