Skip to content

Commit 5db877e

Browse files
committed
storage: list and count methods for upstream oauth sessions
1 parent 9e6c9cb commit 5db877e

File tree

5 files changed

+444
-5
lines changed

5 files changed

+444
-5
lines changed

crates/storage-pg/src/iden.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,29 @@ pub enum UpstreamOAuthLinks {
140140
CreatedAt,
141141
}
142142

143+
#[derive(sea_query::Iden)]
144+
#[iden = "upstream_oauth_authorization_sessions"]
145+
pub enum UpstreamOAuthAuthorizationSessions {
146+
Table,
147+
#[iden = "upstream_oauth_authorization_session_id"]
148+
UpstreamOAuthAuthorizationSessionId,
149+
#[iden = "upstream_oauth_provider_id"]
150+
UpstreamOAuthProviderId,
151+
#[iden = "upstream_oauth_link_id"]
152+
UpstreamOAuthLinkId,
153+
State,
154+
CodeChallengeVerifier,
155+
Nonce,
156+
IdToken,
157+
IdTokenClaims,
158+
ExtraCallbackParameters,
159+
Userinfo,
160+
CreatedAt,
161+
CompletedAt,
162+
ConsumedAt,
163+
UnlinkedAt,
164+
}
165+
143166
#[derive(sea_query::Iden)]
144167
pub enum UserRegistrationTokens {
145168
Table,

crates/storage-pg/src/upstream_oauth2/mod.rs

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ mod tests {
2929
upstream_oauth2::{
3030
UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
3131
UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
32-
UpstreamOAuthSessionRepository,
32+
UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository,
3333
},
3434
user::UserRepository,
3535
};
@@ -262,6 +262,29 @@ mod tests {
262262
1
263263
);
264264

265+
// Test listing and counting sessions
266+
let session_filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);
267+
268+
// Count the sessions for the provider
269+
let session_count = repo
270+
.upstream_oauth_session()
271+
.count(session_filter)
272+
.await
273+
.unwrap();
274+
assert_eq!(session_count, 1);
275+
276+
// List the sessions for the provider
277+
let session_page = repo
278+
.upstream_oauth_session()
279+
.list(session_filter, Pagination::first(10))
280+
.await
281+
.unwrap();
282+
283+
assert_eq!(session_page.edges.len(), 1);
284+
assert_eq!(session_page.edges[0].id, session.id);
285+
assert!(!session_page.has_next_page);
286+
assert!(!session_page.has_previous_page);
287+
265288
// Try deleting the provider
266289
repo.upstream_oauth_provider()
267290
.delete(provider)
@@ -423,4 +446,136 @@ mod tests {
423446
.is_empty()
424447
);
425448
}
449+
450+
/// Test that the pagination works as expected in the upstream OAuth
451+
/// session repository
452+
#[sqlx::test(migrator = "crate::MIGRATOR")]
453+
async fn test_session_repository_pagination(pool: PgPool) {
454+
let scope = Scope::from_iter([OPENID]);
455+
456+
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
457+
let clock = MockClock::default();
458+
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
459+
460+
// Create a provider
461+
let provider = repo
462+
.upstream_oauth_provider()
463+
.add(
464+
&mut rng,
465+
&clock,
466+
UpstreamOAuthProviderParams {
467+
issuer: Some("https://example.com/".to_owned()),
468+
human_name: None,
469+
brand_name: None,
470+
scope,
471+
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
472+
id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
473+
fetch_userinfo: false,
474+
userinfo_signed_response_alg: None,
475+
token_endpoint_signing_alg: None,
476+
client_id: "client-id".to_owned(),
477+
encrypted_client_secret: None,
478+
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
479+
token_endpoint_override: None,
480+
authorization_endpoint_override: None,
481+
userinfo_endpoint_override: None,
482+
jwks_uri_override: None,
483+
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
484+
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
485+
response_mode: None,
486+
additional_authorization_parameters: Vec::new(),
487+
forward_login_hint: false,
488+
ui_order: 0,
489+
},
490+
)
491+
.await
492+
.unwrap();
493+
494+
let filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);
495+
496+
// Count the number of sessions before we start
497+
assert_eq!(
498+
repo.upstream_oauth_session().count(filter).await.unwrap(),
499+
0
500+
);
501+
502+
let mut ids = Vec::with_capacity(20);
503+
// Create 20 sessions
504+
for idx in 0..20 {
505+
let state = format!("state-{idx}");
506+
let session = repo
507+
.upstream_oauth_session()
508+
.add(&mut rng, &clock, &provider, state, None, None)
509+
.await
510+
.unwrap();
511+
ids.push(session.id);
512+
clock.advance(Duration::microseconds(10 * 1000 * 1000));
513+
}
514+
515+
// Now we have 20 sessions
516+
assert_eq!(
517+
repo.upstream_oauth_session().count(filter).await.unwrap(),
518+
20
519+
);
520+
521+
// Lookup the first 10 items
522+
let page = repo
523+
.upstream_oauth_session()
524+
.list(filter, Pagination::first(10))
525+
.await
526+
.unwrap();
527+
528+
// It returned the first 10 items
529+
assert!(page.has_next_page);
530+
let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
531+
assert_eq!(&edge_ids, &ids[..10]);
532+
533+
// Lookup the next 10 items
534+
let page = repo
535+
.upstream_oauth_session()
536+
.list(filter, Pagination::first(10).after(ids[9]))
537+
.await
538+
.unwrap();
539+
540+
// It returned the next 10 items
541+
assert!(!page.has_next_page);
542+
let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
543+
assert_eq!(&edge_ids, &ids[10..]);
544+
545+
// Lookup the last 10 items
546+
let page = repo
547+
.upstream_oauth_session()
548+
.list(filter, Pagination::last(10))
549+
.await
550+
.unwrap();
551+
552+
// It returned the last 10 items
553+
assert!(page.has_previous_page);
554+
let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
555+
assert_eq!(&edge_ids, &ids[10..]);
556+
557+
// Lookup the previous 10 items
558+
let page = repo
559+
.upstream_oauth_session()
560+
.list(filter, Pagination::last(10).before(ids[10]))
561+
.await
562+
.unwrap();
563+
564+
// It returned the previous 10 items
565+
assert!(!page.has_previous_page);
566+
let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
567+
assert_eq!(&edge_ids, &ids[..10]);
568+
569+
// Lookup 5 items between two IDs
570+
let page = repo
571+
.upstream_oauth_session()
572+
.list(filter, Pagination::first(10).after(ids[5]).before(ids[11]))
573+
.await
574+
.unwrap();
575+
576+
// It returned the items in between
577+
assert!(!page.has_next_page);
578+
let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
579+
assert_eq!(&edge_ids, &ids[6..11]);
580+
}
426581
}

0 commit comments

Comments
 (0)