Skip to content

Commit 9fa1655

Browse files
authored
RUST-384 Close connection which was dropped during command execution (#214)
1 parent dd019cc commit 9fa1655

File tree

3 files changed

+81
-18
lines changed

3 files changed

+81
-18
lines changed

src/cmap/conn/mod.rs

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ pub(crate) struct Connection {
6161
/// currently checked into the pool, this will be None.
6262
pub(super) pool: Option<Weak<ConnectionPoolInner>>,
6363

64+
/// Whether or not a command is currently being run on this connection. This is set to `true`
65+
/// right before sending bytes to the server and set back to `false` once a full response has
66+
/// been read.
67+
command_executing: bool,
68+
6469
stream: AsyncStream,
6570

6671
#[derivative(Debug = "ignore")]
@@ -85,6 +90,7 @@ impl Connection {
8590
id,
8691
generation,
8792
pool: None,
93+
command_executing: false,
8894
ready_and_available_time: None,
8995
stream: AsyncStream::connect(stream_options).await?,
9096
address,
@@ -207,9 +213,13 @@ impl Connection {
207213
request_id: impl Into<Option<i32>>,
208214
) -> Result<CommandResponse> {
209215
let message = Message::with_command(command, request_id.into());
216+
217+
self.command_executing = true;
210218
message.write_to(&mut self.stream).await?;
211219

212220
let response_message = Message::read_from(&mut self.stream).await?;
221+
self.command_executing = false;
222+
213223
CommandResponse::new(self.address.clone(), response_message)
214224
}
215225

@@ -253,24 +263,30 @@ impl Connection {
253263

254264
impl Drop for Connection {
255265
fn drop(&mut self) {
256-
// If the connection has a weak reference to a pool, that means that the connection is being
257-
// dropped when it's checked out. If the pool is still alive, it should check itself back
258-
// in. Otherwise, the connection should close itself and emit a ConnectionClosed event
259-
// (because the `close_and_drop` helper was not called explicitly).
260-
//
261-
// If the connection does not have a weak reference to a pool, then the connection is being
262-
// dropped while it's not checked out. This means that the pool called the `close_and_drop`
263-
// helper explicitly, so we don't add it back to the pool or emit any events.
264-
if let Some(ref weak_pool_ref) = self.pool {
265-
if let Some(strong_pool_ref) = weak_pool_ref.upgrade() {
266-
let dropped_connection_state = self.take();
267-
RUNTIME.execute(async move {
268-
strong_pool_ref
269-
.check_in(dropped_connection_state.into())
270-
.await;
271-
});
272-
} else {
273-
self.close(ConnectionClosedReason::PoolClosed);
266+
if self.command_executing {
267+
self.close(ConnectionClosedReason::Dropped);
268+
} else {
269+
// If the connection has a weak reference to a pool, that means that the connection is
270+
// being dropped when it's checked out. If the pool is still alive, it
271+
// should check itself back in. Otherwise, the connection should close
272+
// itself and emit a ConnectionClosed event (because the `close_and_drop`
273+
// helper was not called explicitly).
274+
//
275+
// If the connection does not have a weak reference to a pool, then the connection is
276+
// being dropped while it's not checked out. This means that the pool called
277+
// the `close_and_drop` helper explicitly, so we don't add it back to the
278+
// pool or emit any events.
279+
if let Some(ref weak_pool_ref) = self.pool {
280+
if let Some(strong_pool_ref) = weak_pool_ref.upgrade() {
281+
let dropped_connection_state = self.take();
282+
RUNTIME.execute(async move {
283+
strong_pool_ref
284+
.check_in(dropped_connection_state.into())
285+
.await;
286+
});
287+
} else {
288+
self.close(ConnectionClosedReason::PoolClosed);
289+
}
274290
}
275291
}
276292
}
@@ -318,6 +334,7 @@ impl From<DroppedConnectionState> for Connection {
318334
id: state.id,
319335
address: state.address.clone(),
320336
generation: state.generation,
337+
command_executing: false,
321338
stream: std::mem::replace(&mut state.stream, AsyncStream::Null),
322339
handler: state.handler.take(),
323340
stream_description: state.stream_description.take(),

src/event/cmap.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ pub enum ConnectionClosedReason {
117117
/// An error occurred while using the connection.
118118
Error,
119119

120+
/// The connection was dropped during read or write.
121+
Dropped,
122+
120123
/// The pool that the connection belongs to has been closed.
121124
PoolClosed,
122125
}

src/test/client.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::{
99
selection_criteria::{ReadPreference, ReadPreferenceOptions, SelectionCriteria},
1010
test::{util::TestClient, CLIENT_OPTIONS, LOCK},
1111
Client,
12+
RUNTIME,
1213
};
1314

1415
#[derive(Debug, Deserialize)]
@@ -57,6 +58,48 @@ async fn metadata_sent_in_handshake() {
5758
assert_eq!(metadata.client.driver.name, "mrd");
5859
}
5960

61+
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
62+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
63+
#[function_name::named]
64+
async fn connection_drop_during_read() {
65+
let _guard = LOCK.run_concurrently().await;
66+
67+
let options = CLIENT_OPTIONS.clone();
68+
69+
let client = Client::with_options(options.clone()).unwrap();
70+
let db = client.database("test");
71+
72+
db.collection(function_name!())
73+
.insert_one(doc! { "x": 1 }, None)
74+
.await
75+
.unwrap();
76+
77+
let _: Result<_, _> = RUNTIME
78+
.timeout(
79+
Duration::from_millis(50),
80+
db.run_command(
81+
doc! {
82+
"count": function_name!(),
83+
"query": {
84+
"$where": "sleep(100) && true"
85+
}
86+
},
87+
None,
88+
),
89+
)
90+
.await;
91+
92+
RUNTIME.delay_for(Duration::from_millis(200)).await;
93+
94+
let is_master_response = db.run_command(doc! { "isMaster": 1 }, None).await;
95+
96+
// Ensure that the response to `isMaster` is read, not the response to `count`.
97+
assert!(is_master_response
98+
.ok()
99+
.and_then(|value| value.get("ismaster").and_then(|value| value.as_bool()))
100+
.is_some());
101+
}
102+
60103
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
61104
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
62105
async fn server_selection_timeout_message() {

0 commit comments

Comments
 (0)