Skip to content

Commit 77b6f01

Browse files
RUST-945 Check that explicit sessions were created on the correct client (#405)
1 parent 295f3be commit 77b6f01

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

src/client/executor.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,15 @@ impl Client {
7272
}
7373
match session.into() {
7474
Some(session) => {
75+
if !Arc::ptr_eq(&self.inner, &session.client().inner) {
76+
return Err(ErrorKind::InvalidArgument {
77+
message: "the session provided to an operation must be created from the \
78+
same client as the collection/database"
79+
.into(),
80+
}
81+
.into());
82+
}
83+
7584
if let Some(SelectionCriteria::ReadPreference(read_preference)) =
7685
op.selection_criteria()
7786
{

src/test/spec/sessions.rs

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
use tokio::sync::RwLockWriteGuard;
1+
use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
22

3-
use crate::test::{run_spec_test, LOCK};
3+
use crate::{
4+
bson::doc,
5+
error::ErrorKind,
6+
test::{run_spec_test, TestClient, LOCK},
7+
};
48

59
use super::{run_unified_format_test, run_v2_test};
610

@@ -17,3 +21,38 @@ async fn run_legacy() {
1721
let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await;
1822
run_spec_test(&["sessions", "legacy"], run_v2_test).await;
1923
}
24+
25+
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
26+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
27+
#[function_name::named]
28+
async fn explicit_session_created_on_same_client() {
29+
let _guard: RwLockReadGuard<_> = LOCK.run_concurrently().await;
30+
31+
let client0 = TestClient::new().await;
32+
let client1 = TestClient::new().await;
33+
34+
let mut session0 = client0.start_session(None).await.unwrap();
35+
let mut session1 = client1.start_session(None).await.unwrap();
36+
37+
let db = client0.database(function_name!());
38+
let err = db
39+
.list_collections_with_session(None, None, &mut session1)
40+
.await
41+
.unwrap_err();
42+
match *err.kind {
43+
ErrorKind::InvalidArgument { message } => assert!(message.contains("session provided")),
44+
other => panic!("expected InvalidArgument error, got {:?}", other),
45+
}
46+
47+
let coll = client1
48+
.database(function_name!())
49+
.collection(function_name!());
50+
let err = coll
51+
.insert_one_with_session(doc! {}, None, &mut session0)
52+
.await
53+
.unwrap_err();
54+
match *err.kind {
55+
ErrorKind::InvalidArgument { message } => assert!(message.contains("session provided")),
56+
other => panic!("expected InvalidArgument error, got {:?}", other),
57+
}
58+
}

0 commit comments

Comments
 (0)