@@ -299,6 +299,7 @@ impl Client {
299
299
}
300
300
301
301
let stream_description = connection. stream_description ( ) ?;
302
+ let is_sharded = stream_description. initial_server_type == ServerType :: Mongos ;
302
303
let mut cmd = op. build ( stream_description) ?;
303
304
self . inner
304
305
. topology
@@ -337,15 +338,22 @@ impl Client {
337
338
cmd. set_start_transaction ( ) ;
338
339
cmd. set_autocommit ( ) ;
339
340
cmd. set_txn_read_concern ( * session) ?;
340
- if stream_description . initial_server_type == ServerType :: Mongos {
341
+ if is_sharded {
341
342
session. pin_mongos ( connection. address ( ) . clone ( ) ) ;
342
343
}
343
344
session. transaction . state = TransactionState :: InProgress ;
344
345
}
345
- TransactionState :: InProgress
346
- | TransactionState :: Committed { .. }
347
- | TransactionState :: Aborted => {
346
+ TransactionState :: InProgress => cmd. set_autocommit ( ) ,
347
+ TransactionState :: Committed { .. } | TransactionState :: Aborted => {
348
348
cmd. set_autocommit ( ) ;
349
+
350
+ // Append the recovery token to the command if we are committing or aborting
351
+ // on a sharded transaction.
352
+ if is_sharded {
353
+ if let Some ( ref recovery_token) = session. transaction . recovery_token {
354
+ cmd. set_recovery_token ( recovery_token) ;
355
+ }
356
+ }
349
357
}
350
358
_ => { }
351
359
}
@@ -414,6 +422,9 @@ impl Client {
414
422
Ok ( r) => {
415
423
self . update_cluster_time ( & r, session) . await ;
416
424
if r. is_success ( ) {
425
+ // Retrieve recovery token from successful response.
426
+ Client :: update_recovery_token ( is_sharded, & r, session) . await ;
427
+
417
428
Ok ( CommandResult {
418
429
raw : response,
419
430
deserialized : r. into_body ( ) ,
@@ -458,7 +469,15 @@ impl Client {
458
469
} ) )
459
470
}
460
471
// for ok: 1 just return the original deserialization error.
461
- _ => Err ( deserialize_error) ,
472
+ _ => {
473
+ Client :: update_recovery_token (
474
+ is_sharded,
475
+ & error_response,
476
+ session,
477
+ )
478
+ . await ;
479
+ Err ( deserialize_error)
480
+ }
462
481
}
463
482
}
464
483
// We failed to deserialize even that, so just return the original
@@ -635,6 +654,18 @@ impl Client {
635
654
}
636
655
}
637
656
}
657
+
658
+ async fn update_recovery_token < T : Response > (
659
+ is_sharded : bool ,
660
+ response : & T ,
661
+ session : & mut Option < & mut ClientSession > ,
662
+ ) {
663
+ if let Some ( ref mut session) = session {
664
+ if is_sharded && session. in_transaction ( ) {
665
+ session. transaction . recovery_token = response. recovery_token ( ) . cloned ( ) ;
666
+ }
667
+ }
668
+ }
638
669
}
639
670
640
671
impl Error {
0 commit comments