Skip to content

Commit ab491ae

Browse files
authored
RUST-1274 Fix commitTransaction on check out retries (#651)
This also fixes RUST-1317.
1 parent cb45c29 commit ab491ae

File tree

8 files changed

+207
-27
lines changed

8 files changed

+207
-27
lines changed

src/client/executor.rs

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -351,27 +351,15 @@ impl Client {
351351

352352
let retryability = self.get_retryability(&conn, &op, &session)?;
353353

354-
let txn_number = match session {
355-
Some(ref mut session) => {
356-
if session.transaction.state != TransactionState::None {
357-
Some(session.txn_number())
358-
} else {
359-
match retryability {
360-
Retryability::Write => Some(session.get_and_increment_txn_number()),
361-
_ => None,
362-
}
363-
}
364-
}
365-
None => None,
366-
};
354+
let txn_number = get_txn_number(&mut session, retryability);
367355

368356
match self
369357
.execute_operation_on_connection(
370358
&mut op,
371359
&mut conn,
372360
&mut session,
373361
txn_number,
374-
&retryability,
362+
retryability,
375363
)
376364
.await
377365
{
@@ -424,7 +412,7 @@ impl Client {
424412
&self,
425413
op: &mut T,
426414
session: &mut Option<&mut ClientSession>,
427-
txn_number: Option<i64>,
415+
prior_txn_number: Option<i64>,
428416
first_error: Error,
429417
) -> Result<ExecutionOutput<T>> {
430418
op.update_for_retry();
@@ -446,8 +434,10 @@ impl Client {
446434
return Err(first_error);
447435
}
448436

437+
let txn_number = prior_txn_number.or_else(|| get_txn_number(session, retryability));
438+
449439
match self
450-
.execute_operation_on_connection(op, &mut conn, session, txn_number, &retryability)
440+
.execute_operation_on_connection(op, &mut conn, session, txn_number, retryability)
451441
.await
452442
{
453443
Ok(operation_output) => Ok(ExecutionOutput {
@@ -481,7 +471,7 @@ impl Client {
481471
connection: &mut Connection,
482472
session: &mut Option<&mut ClientSession>,
483473
txn_number: Option<i64>,
484-
retryability: &Retryability,
474+
retryability: Retryability,
485475
) -> Result<T::O> {
486476
if let Some(wc) = op.write_concern() {
487477
wc.validate()?;
@@ -918,6 +908,25 @@ async fn get_connection<T: Operation>(
918908
}
919909
}
920910

911+
fn get_txn_number(
912+
session: &mut Option<&mut ClientSession>,
913+
retryability: Retryability,
914+
) -> Option<i64> {
915+
match session {
916+
Some(ref mut session) => {
917+
if session.transaction.state != TransactionState::None {
918+
Some(session.txn_number())
919+
} else {
920+
match retryability {
921+
Retryability::Write => Some(session.get_and_increment_txn_number()),
922+
_ => None,
923+
}
924+
}
925+
}
926+
None => None,
927+
}
928+
}
929+
921930
impl Error {
922931
/// Adds the necessary labels to this Error, and unpins the session if needed.
923932
///
@@ -936,7 +945,7 @@ impl Error {
936945
&mut self,
937946
conn: Option<&Connection>,
938947
session: &mut Option<&mut ClientSession>,
939-
retryability: Option<&Retryability>,
948+
retryability: Option<Retryability>,
940949
) -> Result<()> {
941950
let transaction_state = session.as_ref().map_or(&TransactionState::None, |session| {
942951
&session.transaction.state
@@ -970,7 +979,7 @@ impl Error {
970979
}
971980
}
972981
TransactionState::None => {
973-
if retryability == Some(&Retryability::Write) {
982+
if retryability == Some(Retryability::Write) {
974983
if let Some(max_wire_version) = max_wire_version {
975984
if self.should_add_retryable_write_label(max_wire_version) {
976985
self.add_label(RETRYABLE_WRITE_ERROR);

src/event/sdam/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ pub struct ServerDescriptionChangedEvent {
3636
pub new_description: ServerDescription,
3737
}
3838

39+
impl ServerDescriptionChangedEvent {
40+
#[cfg(test)]
41+
pub(crate) fn is_marked_unknown_event(&self) -> bool {
42+
self.previous_description
43+
.description
44+
.server_type
45+
.is_available()
46+
&& self.new_description.description.server_type == crate::ServerType::Unknown
47+
}
48+
}
49+
3950
/// Published when a server is initialized.
4051
#[derive(Clone, Debug, Deserialize, PartialEq)]
4152
#[non_exhaustive]

src/operation/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ where
379379
}
380380
}
381381

382-
#[derive(Debug, PartialEq)]
382+
#[derive(Copy, Clone, Debug, PartialEq)]
383383
pub(crate) enum Retryability {
384384
Write,
385385
Read,

src/sdam/description/server.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ impl PartialEq for ServerDescription {
122122

123123
self_response == other_response
124124
}
125+
(Err(self_err), Err(other_err)) => self_err == other_err,
125126
_ => false,
126127
}
127128
}

src/sdam/monitor.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,8 @@ impl HeartbeatMonitor {
127127
let mut topology_check_requests_subscriber =
128128
topology.subscribe_to_topology_check_requests();
129129

130-
if self.check_server(&topology, &server).await {
131-
topology.notify_topology_changed();
132-
}
130+
self.check_server(&topology, &server).await;
131+
topology.notify_topology_changed();
133132

134133
// drop strong reference to topology before going back to sleep in case it drops off
135134
// in between checks.

src/test/client.rs

Lines changed: 147 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{borrow::Cow, collections::HashMap, time::Duration};
1+
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
22

33
use bson::Document;
44
use serde::Deserialize;
@@ -7,11 +7,25 @@ use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
77
use crate::{
88
bson::{doc, Bson},
99
error::{CommandError, Error, ErrorKind},
10+
hello::LEGACY_HELLO_COMMAND_NAME,
1011
options::{AuthMechanism, ClientOptions, Credential, ListDatabasesOptions, ServerAddress},
1112
runtime,
1213
selection_criteria::{ReadPreference, ReadPreferenceOptions, SelectionCriteria},
13-
test::{log_uncaptured, util::TestClient, CLIENT_OPTIONS, LOCK},
14+
test::{
15+
log_uncaptured,
16+
util::TestClient,
17+
CmapEvent,
18+
Event,
19+
EventHandler,
20+
FailCommandOptions,
21+
FailPoint,
22+
FailPointMode,
23+
SdamEvent,
24+
CLIENT_OPTIONS,
25+
LOCK,
26+
},
1427
Client,
28+
ServerType,
1529
};
1630

1731
#[derive(Debug, Deserialize)]
@@ -663,3 +677,134 @@ async fn plain_auth() {
663677
}
664678
);
665679
}
680+
681+
/// Test verifies that retrying a commitTransaction operation after a checkOut
682+
/// failure works.
683+
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
684+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
685+
async fn retry_commit_txn_check_out() {
686+
let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await;
687+
688+
let setup_client = TestClient::new().await;
689+
if !setup_client.is_replica_set() {
690+
log_uncaptured("skipping retry_commit_txn_check_out due to non-replicaset topology");
691+
return;
692+
}
693+
694+
if !setup_client.supports_transactions() {
695+
log_uncaptured("skipping retry_commit_txn_check_out due to lack of transaction support");
696+
return;
697+
}
698+
699+
if !setup_client.supports_fail_command_appname_initial_handshake() {
700+
log_uncaptured(
701+
"skipping retry_commit_txn_check_out due to insufficient failCommand support",
702+
);
703+
return;
704+
}
705+
706+
// ensure namespace exists
707+
setup_client
708+
.database("retry_commit_txn_check_out")
709+
.collection("retry_commit_txn_check_out")
710+
.insert_one(doc! {}, None)
711+
.await
712+
.unwrap();
713+
714+
let mut options = CLIENT_OPTIONS.clone();
715+
let handler = Arc::new(EventHandler::new());
716+
options.cmap_event_handler = Some(handler.clone());
717+
options.sdam_event_handler = Some(handler.clone());
718+
options.heartbeat_freq = Some(Duration::from_secs(120));
719+
options.app_name = Some("retry_commit_txn_check_out".to_string());
720+
let client = Client::with_options(options).unwrap();
721+
722+
let mut session = client.start_session(None).await.unwrap();
723+
session.start_transaction(None).await.unwrap();
724+
// transition transaction to "in progress" so that the commit
725+
// actually executes an operation.
726+
client
727+
.database("retry_commit_txn_check_out")
728+
.collection("retry_commit_txn_check_out")
729+
.insert_one_with_session(doc! {}, None, &mut session)
730+
.await
731+
.unwrap();
732+
733+
// enable a fail point that clears the connection pools so that
734+
// commitTransaction will create a new connection during check out.
735+
let fp = FailPoint::fail_command(
736+
&["ping"],
737+
FailPointMode::Times(1),
738+
FailCommandOptions::builder().error_code(11600).build(),
739+
);
740+
let _guard = setup_client.enable_failpoint(fp, None).await.unwrap();
741+
742+
let mut subscriber = handler.subscribe();
743+
client
744+
.database("foo")
745+
.run_command(doc! { "ping": 1 }, None)
746+
.await
747+
.unwrap_err();
748+
749+
// failing with a state change error will request an immediate check
750+
// wait for the mark unknown and subsequent succeeded heartbeat
751+
let mut primary = None;
752+
subscriber
753+
.wait_for_event(Duration::from_secs(1), |e| {
754+
if let Event::Sdam(SdamEvent::ServerDescriptionChanged(event)) = e {
755+
if event.is_marked_unknown_event() {
756+
primary = Some(event.address.clone());
757+
return true;
758+
}
759+
}
760+
false
761+
})
762+
.await
763+
.expect("should see marked unknown event");
764+
765+
subscriber
766+
.wait_for_event(Duration::from_secs(1), |e| {
767+
if let Event::Sdam(SdamEvent::ServerDescriptionChanged(event)) = e {
768+
if &event.address == primary.as_ref().unwrap()
769+
&& event.previous_description.server_type() == ServerType::Unknown
770+
{
771+
return true;
772+
}
773+
}
774+
false
775+
})
776+
.await
777+
.expect("should see mark available event");
778+
779+
// enable a failpoint on the handshake to cause check_out
780+
// to fail with a retryable error
781+
let fp = FailPoint::fail_command(
782+
&[LEGACY_HELLO_COMMAND_NAME, "hello"],
783+
FailPointMode::Times(1),
784+
FailCommandOptions::builder()
785+
.error_code(11600)
786+
.app_name("retry_commit_txn_check_out".to_string())
787+
.build(),
788+
);
789+
let _guard2 = setup_client.enable_failpoint(fp, None).await.unwrap();
790+
791+
// finally, attempt the commit.
792+
// this should succeed due to retry
793+
session.commit_transaction().await.unwrap();
794+
795+
// ensure the first check out attempt fails
796+
subscriber
797+
.wait_for_event(Duration::from_secs(1), |e| {
798+
matches!(e, Event::Cmap(CmapEvent::ConnectionCheckOutFailed(_)))
799+
})
800+
.await
801+
.expect("should see check out failed event");
802+
803+
// ensure the second one succeeds
804+
subscriber
805+
.wait_for_event(Duration::from_secs(1), |e| {
806+
matches!(e, Event::Cmap(CmapEvent::ConnectionCheckedOut(_)))
807+
})
808+
.await
809+
.expect("should see checked out event");
810+
}

src/test/util/event.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,9 +389,9 @@ pub struct EventSubscriber<'a> {
389389
}
390390

391391
impl EventSubscriber<'_> {
392-
pub async fn wait_for_event<F>(&mut self, timeout: Duration, filter: F) -> Option<Event>
392+
pub async fn wait_for_event<F>(&mut self, timeout: Duration, mut filter: F) -> Option<Event>
393393
where
394-
F: Fn(&Event) -> bool,
394+
F: FnMut(&Event) -> bool,
395395
{
396396
runtime::timeout(timeout, async {
397397
loop {

src/test/util/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,21 @@ impl TestClient {
240240
version.matches(&self.server_version)
241241
}
242242

243+
/// Whether the deployment supports failing the initial handshake
244+
/// only when it uses a specified appName.
245+
///
246+
/// See SERVER-49336 for more info.
247+
pub fn supports_fail_command_appname_initial_handshake(&self) -> bool {
248+
let requirements = [
249+
VersionReq::parse(">= 4.2.15, < 4.3.0").unwrap(),
250+
VersionReq::parse(">= 4.4.7, < 4.5.0").unwrap(),
251+
VersionReq::parse(">= 4.9.0").unwrap(),
252+
];
253+
requirements
254+
.iter()
255+
.any(|req| req.matches(&self.server_version))
256+
}
257+
243258
pub fn supports_transactions(&self) -> bool {
244259
self.is_replica_set() && self.server_version_gte(4, 0)
245260
|| self.is_sharded() && self.server_version_gte(4, 2)

0 commit comments

Comments
 (0)