@@ -29,7 +29,7 @@ mod tests {
2929 upstream_oauth2:: {
3030 UpstreamOAuthLinkFilter , UpstreamOAuthLinkRepository , UpstreamOAuthProviderFilter ,
3131 UpstreamOAuthProviderParams , UpstreamOAuthProviderRepository ,
32- UpstreamOAuthSessionRepository ,
32+ UpstreamOAuthSessionFilter , UpstreamOAuthSessionRepository ,
3333 } ,
3434 user:: UserRepository ,
3535 } ;
@@ -262,6 +262,29 @@ mod tests {
262262 1
263263 ) ;
264264
265+ // Test listing and counting sessions
266+ let session_filter = UpstreamOAuthSessionFilter :: new ( ) . for_provider ( & provider) ;
267+
268+ // Count the sessions for the provider
269+ let session_count = repo
270+ . upstream_oauth_session ( )
271+ . count ( session_filter)
272+ . await
273+ . unwrap ( ) ;
274+ assert_eq ! ( session_count, 1 ) ;
275+
276+ // List the sessions for the provider
277+ let session_page = repo
278+ . upstream_oauth_session ( )
279+ . list ( session_filter, Pagination :: first ( 10 ) )
280+ . await
281+ . unwrap ( ) ;
282+
283+ assert_eq ! ( session_page. edges. len( ) , 1 ) ;
284+ assert_eq ! ( session_page. edges[ 0 ] . id, session. id) ;
285+ assert ! ( !session_page. has_next_page) ;
286+ assert ! ( !session_page. has_previous_page) ;
287+
265288 // Try deleting the provider
266289 repo. upstream_oauth_provider ( )
267290 . delete ( provider)
@@ -423,4 +446,136 @@ mod tests {
423446 . is_empty( )
424447 ) ;
425448 }
449+
450+ /// Test that the pagination works as expected in the upstream OAuth
451+ /// session repository
452+ #[ sqlx:: test( migrator = "crate::MIGRATOR" ) ]
453+ async fn test_session_repository_pagination ( pool : PgPool ) {
454+ let scope = Scope :: from_iter ( [ OPENID ] ) ;
455+
456+ let mut rng = rand_chacha:: ChaChaRng :: seed_from_u64 ( 42 ) ;
457+ let clock = MockClock :: default ( ) ;
458+ let mut repo = PgRepository :: from_pool ( & pool) . await . unwrap ( ) ;
459+
460+ // Create a provider
461+ let provider = repo
462+ . upstream_oauth_provider ( )
463+ . add (
464+ & mut rng,
465+ & clock,
466+ UpstreamOAuthProviderParams {
467+ issuer : Some ( "https://example.com/" . to_owned ( ) ) ,
468+ human_name : None ,
469+ brand_name : None ,
470+ scope,
471+ token_endpoint_auth_method : UpstreamOAuthProviderTokenAuthMethod :: None ,
472+ id_token_signed_response_alg : JsonWebSignatureAlg :: Rs256 ,
473+ fetch_userinfo : false ,
474+ userinfo_signed_response_alg : None ,
475+ token_endpoint_signing_alg : None ,
476+ client_id : "client-id" . to_owned ( ) ,
477+ encrypted_client_secret : None ,
478+ claims_imports : UpstreamOAuthProviderClaimsImports :: default ( ) ,
479+ token_endpoint_override : None ,
480+ authorization_endpoint_override : None ,
481+ userinfo_endpoint_override : None ,
482+ jwks_uri_override : None ,
483+ discovery_mode : mas_data_model:: UpstreamOAuthProviderDiscoveryMode :: Oidc ,
484+ pkce_mode : mas_data_model:: UpstreamOAuthProviderPkceMode :: Auto ,
485+ response_mode : None ,
486+ additional_authorization_parameters : Vec :: new ( ) ,
487+ forward_login_hint : false ,
488+ ui_order : 0 ,
489+ } ,
490+ )
491+ . await
492+ . unwrap ( ) ;
493+
494+ let filter = UpstreamOAuthSessionFilter :: new ( ) . for_provider ( & provider) ;
495+
496+ // Count the number of sessions before we start
497+ assert_eq ! (
498+ repo. upstream_oauth_session( ) . count( filter) . await . unwrap( ) ,
499+ 0
500+ ) ;
501+
502+ let mut ids = Vec :: with_capacity ( 20 ) ;
503+ // Create 20 sessions
504+ for idx in 0 ..20 {
505+ let state = format ! ( "state-{idx}" ) ;
506+ let session = repo
507+ . upstream_oauth_session ( )
508+ . add ( & mut rng, & clock, & provider, state, None , None )
509+ . await
510+ . unwrap ( ) ;
511+ ids. push ( session. id ) ;
512+ clock. advance ( Duration :: microseconds ( 10 * 1000 * 1000 ) ) ;
513+ }
514+
515+ // Now we have 20 sessions
516+ assert_eq ! (
517+ repo. upstream_oauth_session( ) . count( filter) . await . unwrap( ) ,
518+ 20
519+ ) ;
520+
521+ // Lookup the first 10 items
522+ let page = repo
523+ . upstream_oauth_session ( )
524+ . list ( filter, Pagination :: first ( 10 ) )
525+ . await
526+ . unwrap ( ) ;
527+
528+ // It returned the first 10 items
529+ assert ! ( page. has_next_page) ;
530+ let edge_ids: Vec < _ > = page. edges . iter ( ) . map ( |s| s. id ) . collect ( ) ;
531+ assert_eq ! ( & edge_ids, & ids[ ..10 ] ) ;
532+
533+ // Lookup the next 10 items
534+ let page = repo
535+ . upstream_oauth_session ( )
536+ . list ( filter, Pagination :: first ( 10 ) . after ( ids[ 9 ] ) )
537+ . await
538+ . unwrap ( ) ;
539+
540+ // It returned the next 10 items
541+ assert ! ( !page. has_next_page) ;
542+ let edge_ids: Vec < _ > = page. edges . iter ( ) . map ( |s| s. id ) . collect ( ) ;
543+ assert_eq ! ( & edge_ids, & ids[ 10 ..] ) ;
544+
545+ // Lookup the last 10 items
546+ let page = repo
547+ . upstream_oauth_session ( )
548+ . list ( filter, Pagination :: last ( 10 ) )
549+ . await
550+ . unwrap ( ) ;
551+
552+ // It returned the last 10 items
553+ assert ! ( page. has_previous_page) ;
554+ let edge_ids: Vec < _ > = page. edges . iter ( ) . map ( |s| s. id ) . collect ( ) ;
555+ assert_eq ! ( & edge_ids, & ids[ 10 ..] ) ;
556+
557+ // Lookup the previous 10 items
558+ let page = repo
559+ . upstream_oauth_session ( )
560+ . list ( filter, Pagination :: last ( 10 ) . before ( ids[ 10 ] ) )
561+ . await
562+ . unwrap ( ) ;
563+
564+ // It returned the previous 10 items
565+ assert ! ( !page. has_previous_page) ;
566+ let edge_ids: Vec < _ > = page. edges . iter ( ) . map ( |s| s. id ) . collect ( ) ;
567+ assert_eq ! ( & edge_ids, & ids[ ..10 ] ) ;
568+
569+ // Lookup 5 items between two IDs
570+ let page = repo
571+ . upstream_oauth_session ( )
572+ . list ( filter, Pagination :: first ( 10 ) . after ( ids[ 5 ] ) . before ( ids[ 11 ] ) )
573+ . await
574+ . unwrap ( ) ;
575+
576+ // It returned the items in between
577+ assert ! ( !page. has_next_page) ;
578+ let edge_ids: Vec < _ > = page. edges . iter ( ) . map ( |s| s. id ) . collect ( ) ;
579+ assert_eq ! ( & edge_ids, & ids[ 6 ..11 ] ) ;
580+ }
426581}
0 commit comments