Skip to content

Commit 3315a74

Browse files
RUST-2020 Ignore speculative authentication on reauthentication (#1320)
1 parent 87457fc commit 3315a74

File tree

1 file changed

+131
-1
lines changed

1 file changed

+131
-1
lines changed

src/test/spec/oidc.rs

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ mod basic {
353353
}
354354

355355
#[tokio::test(flavor = "multi_thread")]
356-
async fn machine_4_reauthentication() -> anyhow::Result<()> {
356+
async fn machine_4_1_reauthentication() -> anyhow::Result<()> {
357357
let admin_client = Client::with_uri_str(&*MONGODB_URI).await?;
358358

359359
// Now set a failpoint for find with 391 error code
@@ -393,6 +393,136 @@ mod basic {
393393
Ok(())
394394
}
395395

396+
#[tokio::test(flavor = "multi_thread")]
397+
async fn machine_4_2_read_command_fails_if_reauth_fails() -> anyhow::Result<()> {
398+
let call_count = Arc::new(Mutex::new(0));
399+
let cb_call_count = call_count.clone();
400+
401+
let mut options = ClientOptions::parse(&*MONGODB_URI_SINGLE).await?;
402+
let credential = Credential::builder()
403+
.mechanism(AuthMechanism::MongoDbOidc)
404+
.oidc_callback(oidc::Callback::machine(move |_| {
405+
let call_count = cb_call_count.clone();
406+
async move {
407+
*call_count.lock().await += 1;
408+
let access_token = if *call_count.lock().await == 1 {
409+
get_access_token_test_user_1().await
410+
} else {
411+
"bad token".to_string()
412+
};
413+
Ok(oidc::IdpServerResponse::builder()
414+
.access_token(access_token)
415+
.build())
416+
}
417+
.boxed()
418+
}))
419+
.build();
420+
options.credential = Some(credential);
421+
let client = Client::with_options(options)?;
422+
let collection = client.database("test").collection::<Document>("test");
423+
424+
collection.find_one(doc! {}).await?;
425+
426+
let fail_point =
427+
FailPoint::fail_command(&["find"], FailPointMode::Times(1)).error_code(391);
428+
let _guard = client.enable_fail_point(fail_point).await?;
429+
430+
collection.find_one(doc! {}).await.unwrap_err();
431+
432+
assert_eq!(*call_count.lock().await, 2);
433+
434+
Ok(())
435+
}
436+
437+
#[tokio::test(flavor = "multi_thread")]
438+
async fn machine_4_3_write_command_fails_if_reauth_fails() -> anyhow::Result<()> {
439+
let call_count = Arc::new(Mutex::new(0));
440+
let cb_call_count = call_count.clone();
441+
442+
let mut options = ClientOptions::parse(&*MONGODB_URI_SINGLE).await?;
443+
let credential = Credential::builder()
444+
.mechanism(AuthMechanism::MongoDbOidc)
445+
.oidc_callback(oidc::Callback::machine(move |_| {
446+
let call_count = cb_call_count.clone();
447+
async move {
448+
*call_count.lock().await += 1;
449+
let access_token = if *call_count.lock().await == 1 {
450+
get_access_token_test_user_1().await
451+
} else {
452+
"bad token".to_string()
453+
};
454+
Ok(oidc::IdpServerResponse::builder()
455+
.access_token(access_token)
456+
.build())
457+
}
458+
.boxed()
459+
}))
460+
.build();
461+
options.credential = Some(credential);
462+
let client = Client::with_options(options)?;
463+
let collection = client.database("test").collection::<Document>("test");
464+
465+
collection.insert_one(doc! { "x": 1 }).await?;
466+
467+
let fail_point =
468+
FailPoint::fail_command(&["insert"], FailPointMode::Times(1)).error_code(391);
469+
let _guard = client.enable_fail_point(fail_point).await?;
470+
471+
collection.insert_one(doc! { "y": 2 }).await.unwrap_err();
472+
473+
assert_eq!(*call_count.lock().await, 2);
474+
475+
Ok(())
476+
}
477+
478+
#[tokio::test(flavor = "multi_thread")]
479+
async fn machine_4_4_speculative_auth_ignored_on_reauth() -> anyhow::Result<()> {
480+
let call_count = Arc::new(Mutex::new(0));
481+
let cb_call_count = call_count.clone();
482+
483+
let mut options = ClientOptions::parse(&*MONGODB_URI_SINGLE).await?;
484+
let credential = Credential::builder()
485+
.mechanism(AuthMechanism::MongoDbOidc)
486+
.oidc_callback(oidc::Callback::machine(move |_| {
487+
let call_count = cb_call_count.clone();
488+
async move {
489+
*call_count.lock().await += 1;
490+
Ok(oidc::IdpServerResponse::builder()
491+
.access_token(get_access_token_test_user_1().await)
492+
.build())
493+
}
494+
.boxed()
495+
}))
496+
.build();
497+
credential
498+
.oidc_callback
499+
.set_access_token(Some(get_access_token_test_user_1().await))
500+
.await;
501+
options.credential = Some(credential);
502+
let client = Client::for_test().options(options).monitor_events().await;
503+
let event_buffer = &client.events;
504+
let collection = client.database("test").collection::<Document>("test");
505+
506+
collection.insert_one(doc! { "x": 1 }).await?;
507+
508+
assert_eq!(*call_count.lock().await, 0);
509+
let sasl_start_events = event_buffer.get_command_started_events(&["saslStart"]);
510+
assert!(sasl_start_events.is_empty());
511+
512+
let fail_point =
513+
FailPoint::fail_command(&["insert"], FailPointMode::Times(1)).error_code(391);
514+
let _guard = client.enable_fail_point(fail_point).await?;
515+
516+
collection.insert_one(doc! { "y": 2 }).await?;
517+
518+
assert_eq!(*call_count.lock().await, 1);
519+
let _sasl_start_events = event_buffer.get_command_started_events(&["saslStart"]);
520+
// TODO RUST-2176: unskip this assertion when saslStart events are emitted
521+
// assert!(!sasl_start_events.is_empty());
522+
523+
Ok(())
524+
}
525+
396526
// Human Callback tests
397527
#[tokio::test]
398528
async fn human_1_1_single_principal_implicit_username() -> anyhow::Result<()> {

0 commit comments

Comments
 (0)