Skip to content

Commit 45736f2

Browse files
authored
RUST-1846 Test that saslSupportedMechs can contain arbitrary strings (#1144)
1 parent 08e2923 commit 45736f2

File tree

4 files changed

+59
-3
lines changed

4 files changed

+59
-3
lines changed

src/cmap/establish.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,17 @@ pub(crate) struct ConnectionEstablisher {
3030
tls_config: Option<TlsConfig>,
3131

3232
connect_timeout: Duration,
33+
34+
#[cfg(test)]
35+
test_patch_reply: Option<fn(&mut Result<HelloReply>)>,
3336
}
3437

3538
pub(crate) struct EstablisherOptions {
3639
handshake_options: HandshakerOptions,
3740
tls_options: Option<TlsOptions>,
3841
connect_timeout: Option<Duration>,
42+
#[cfg(test)]
43+
pub(crate) test_patch_reply: Option<fn(&mut Result<HelloReply>)>,
3944
}
4045

4146
impl EstablisherOptions {
@@ -55,6 +60,8 @@ impl EstablisherOptions {
5560
},
5661
tls_options: opts.tls_options(),
5762
connect_timeout: opts.connect_timeout,
63+
#[cfg(test)]
64+
test_patch_reply: None,
5865
}
5966
}
6067
}
@@ -80,6 +87,8 @@ impl ConnectionEstablisher {
8087
handshaker,
8188
tls_config,
8289
connect_timeout,
90+
#[cfg(test)]
91+
test_patch_reply: options.test_patch_reply,
8392
})
8493
}
8594

@@ -92,7 +101,7 @@ impl ConnectionEstablisher {
92101
}
93102

94103
/// Establishes a connection.
95-
pub(super) async fn establish_connection(
104+
pub(crate) async fn establish_connection(
96105
&self,
97106
pending_connection: PendingConnection,
98107
credential: Option<&Credential>,
@@ -106,7 +115,13 @@ impl ConnectionEstablisher {
106115
.map_err(|e| EstablishError::pre_hello(e, pool_gen.clone()))?;
107116

108117
let mut connection = Connection::new_pooled(pending_connection, stream);
109-
let handshake_result = self.handshaker.handshake(&mut connection, credential).await;
118+
#[allow(unused_mut)]
119+
let mut handshake_result = self.handshaker.handshake(&mut connection, credential).await;
120+
#[cfg(test)]
121+
if let Some(patch) = self.test_patch_reply {
122+
patch(&mut handshake_result);
123+
}
124+
let handshake_result = handshake_result;
110125

111126
// If the handshake response had a `serviceId` field, this is a connection to a load
112127
// balancer and must derive its generation from the service_generations map.

src/test/spec.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod connection_stepdown;
88
mod crud;
99
mod faas;
1010
mod gridfs;
11+
mod handshake;
1112
mod index_management;
1213
#[cfg(feature = "dns-resolver")]
1314
mod initial_dns_seedlist_discovery;

src/test/spec/handshake.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use std::time::Instant;
2+
3+
use bson::oid::ObjectId;
4+
5+
use crate::{
6+
cmap::{
7+
conn::PendingConnection,
8+
establish::{ConnectionEstablisher, EstablisherOptions},
9+
},
10+
event::cmap::CmapEventEmitter,
11+
test::get_client_options,
12+
};
13+
14+
// Prose test 1: Test that the driver accepts an arbitrary auth mechanism
15+
#[tokio::test]
16+
async fn arbitrary_auth_mechanism() {
17+
let client_options = get_client_options().await;
18+
let mut options = EstablisherOptions::from_client_options(client_options);
19+
options.test_patch_reply = Some(|reply| {
20+
reply
21+
.as_mut()
22+
.unwrap()
23+
.command_response
24+
.sasl_supported_mechs
25+
.get_or_insert_with(Vec::new)
26+
.push("ArBiTrArY!".to_string());
27+
});
28+
let establisher = ConnectionEstablisher::new(options).unwrap();
29+
let pending = PendingConnection {
30+
id: 0,
31+
address: client_options.hosts[0].clone(),
32+
generation: crate::cmap::PoolGeneration::normal(),
33+
event_emitter: CmapEventEmitter::new(None, ObjectId::new()),
34+
time_created: Instant::now(),
35+
};
36+
establisher
37+
.establish_connection(pending, None)
38+
.await
39+
.unwrap();
40+
}

src/test/spec/unified_runner/test_runner.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ impl TestRunner {
746746
fn fill_kms_placeholders(
747747
kms_provider_map: HashMap<mongocrypt::ctx::KmsProvider, Document>,
748748
) -> crate::test::csfle::KmsProviderList {
749-
use crate::{bson::doc, test::csfle::ALL_KMS_PROVIDERS};
749+
use crate::test::csfle::ALL_KMS_PROVIDERS;
750750

751751
let placeholder = doc! { "$$placeholder": 1 };
752752
let all_kms_providers = ALL_KMS_PROVIDERS.clone();

0 commit comments

Comments
 (0)