@@ -110,6 +110,8 @@ pub struct ClientSession {
110
110
pub ( crate ) transaction : Transaction ,
111
111
pub ( crate ) snapshot_time : Option < Timestamp > ,
112
112
pub ( crate ) operation_time : Option < Timestamp > ,
113
+ #[ cfg( test) ]
114
+ pub ( crate ) convenient_transaction_timeout : Option < Duration > ,
113
115
}
114
116
115
117
#[ derive( Debug ) ]
@@ -216,6 +218,8 @@ impl ClientSession {
216
218
transaction : Default :: default ( ) ,
217
219
snapshot_time : None ,
218
220
operation_time : None ,
221
+ #[ cfg( test) ]
222
+ convenient_transaction_timeout : None ,
219
223
}
220
224
}
221
225
@@ -561,13 +565,117 @@ impl ClientSession {
561
565
}
562
566
}
563
567
568
+ /// Starts a transaction, runs the given callback, and commits or aborts the transaction.
569
+ /// Transient transaction errors will cause the callback or the commit to be retried;
570
+ /// other errors will cause the transaction to be aborted and the error returned to the
571
+ /// caller. If the callback needs to provide its own error information, the
572
+ /// [`Error::custom`](crate::error::Error::custom) method can accept an arbitrary payload that
573
+ /// can be retrieved via [`Error::get_custom`](crate::error::Error::get_custom).
574
+ ///
575
+ /// Because the callback can be repeatedly executed and because it returns a future, the rust
576
+ /// closure borrowing rules for captured values can be overly restrictive. As a
577
+ /// convenience, `with_transaction` accepts a context argument that will be passed to the
578
+ /// callback along with the session:
579
+ ///
580
+ /// ```no_run
581
+ /// # use mongodb::{bson::{doc, Document}, error::Result, Client};
582
+ /// # use futures::FutureExt;
583
+ /// # async fn wrapper() -> Result<()> {
584
+ /// # let client = Client::with_uri_str("mongodb://example.com").await?;
585
+ /// # let mut session = client.start_session(None).await?;
586
+ /// let coll = client.database("mydb").collection::<Document>("mycoll");
587
+ /// let my_data = "my data".to_string();
588
+ /// // This works:
589
+ /// session.with_transaction(
590
+ /// (&coll, &my_data),
591
+ /// |session, (coll, my_data)| async move {
592
+ /// coll.insert_one_with_session(doc! { "data": *my_data }, None, session).await
593
+ /// }.boxed(),
594
+ /// None,
595
+ /// ).await?;
596
+ /// /* This will not compile with a "variable moved due to use in generator" error:
597
+ /// session.with_transaction(
598
+ /// (),
599
+ /// |session, _| async move {
600
+ /// coll.insert_one_with_session(doc! { "data": my_data }, None, session).await
601
+ /// }.boxed(),
602
+ /// None,
603
+ /// ).await?;
604
+ /// */
605
+ /// # Ok(())
606
+ /// # }
607
+ /// ```
608
+ pub async fn with_transaction < R , C , F > (
609
+ & mut self ,
610
+ mut context : C ,
611
+ mut callback : F ,
612
+ options : impl Into < Option < TransactionOptions > > ,
613
+ ) -> Result < R >
614
+ where
615
+ F : for < ' a > FnMut ( & ' a mut ClientSession , & ' a mut C ) -> BoxFuture < ' a , Result < R > > ,
616
+ {
617
+ let options = options. into ( ) ;
618
+ let timeout = Duration :: from_secs ( 120 ) ;
619
+ #[ cfg( test) ]
620
+ let timeout = self . convenient_transaction_timeout . unwrap_or ( timeout) ;
621
+ let start = Instant :: now ( ) ;
622
+
623
+ use crate :: error:: { TRANSIENT_TRANSACTION_ERROR , UNKNOWN_TRANSACTION_COMMIT_RESULT } ;
624
+
625
+ ' transaction: loop {
626
+ self . start_transaction ( options. clone ( ) ) . await ?;
627
+ let ret = match callback ( self , & mut context) . await {
628
+ Ok ( v) => v,
629
+ Err ( e) => {
630
+ if matches ! (
631
+ self . transaction. state,
632
+ TransactionState :: Starting | TransactionState :: InProgress
633
+ ) {
634
+ self . abort_transaction ( ) . await ?;
635
+ }
636
+ if e. contains_label ( TRANSIENT_TRANSACTION_ERROR ) && start. elapsed ( ) < timeout {
637
+ continue ' transaction;
638
+ }
639
+ return Err ( e) ;
640
+ }
641
+ } ;
642
+ if matches ! (
643
+ self . transaction. state,
644
+ TransactionState :: None
645
+ | TransactionState :: Aborted
646
+ | TransactionState :: Committed { .. }
647
+ ) {
648
+ return Ok ( ret) ;
649
+ }
650
+ ' commit: loop {
651
+ match self . commit_transaction ( ) . await {
652
+ Ok ( ( ) ) => return Ok ( ret) ,
653
+ Err ( e) => {
654
+ if e. is_max_time_ms_expired_error ( ) || start. elapsed ( ) >= timeout {
655
+ return Err ( e) ;
656
+ }
657
+ if e. contains_label ( UNKNOWN_TRANSACTION_COMMIT_RESULT ) {
658
+ continue ' commit;
659
+ }
660
+ if e. contains_label ( TRANSIENT_TRANSACTION_ERROR ) {
661
+ continue ' transaction;
662
+ }
663
+ return Err ( e) ;
664
+ }
665
+ }
666
+ }
667
+ }
668
+ }
669
+
564
670
fn default_transaction_options ( & self ) -> Option < & TransactionOptions > {
565
671
self . options
566
672
. as_ref ( )
567
673
. and_then ( |options| options. default_transaction_options . as_ref ( ) )
568
674
}
569
675
}
570
676
677
+ pub type BoxFuture < ' a , T > = std:: pin:: Pin < Box < dyn std:: future:: Future < Output = T > + Send + ' a > > ;
678
+
571
679
struct DroppedClientSession {
572
680
cluster_time : Option < ClusterTime > ,
573
681
server_session : ServerSession ,
@@ -590,6 +698,8 @@ impl From<DroppedClientSession> for ClientSession {
590
698
transaction : dropped_session. transaction ,
591
699
snapshot_time : dropped_session. snapshot_time ,
592
700
operation_time : dropped_session. operation_time ,
701
+ #[ cfg( test) ]
702
+ convenient_transaction_timeout : None ,
593
703
}
594
704
}
595
705
}
0 commit comments