Skip to content

Commit 35497cb

Browse files
committed
storage: allow filtering browser sessions by which upstream session
authd them
1 parent 92c4cb7 commit 35497cb

File tree

4 files changed

+153
-2
lines changed

4 files changed

+153
-2
lines changed

crates/storage-pg/src/iden.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,18 @@ pub enum UserSessions {
1818
LastActiveIp,
1919
}
2020

21+
#[derive(sea_query::Iden)]
22+
#[expect(dead_code)]
23+
pub enum UserSessionAuthentications {
24+
Table,
25+
UserSessionAuthenticationId,
26+
UserSessionId,
27+
CreatedAt,
28+
UserPasswordId,
29+
#[iden = "upstream_oauth_authorization_session_id"]
30+
UpstreamOAuthAuthorizationSessionId,
31+
}
32+
2133
#[derive(sea_query::Iden)]
2234
pub enum Users {
2335
Table,

crates/storage-pg/src/user/session.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use mas_storage::{
1717
user::{BrowserSessionFilter, BrowserSessionRepository},
1818
};
1919
use rand::RngCore;
20-
use sea_query::{Expr, PostgresQueryBuilder};
20+
use sea_query::{Expr, PgFunc, PostgresQueryBuilder, Query};
2121
use sea_query_binder::SqlxBinder;
2222
use sqlx::PgConnection;
2323
use ulid::Ulid;
@@ -26,7 +26,7 @@ use uuid::Uuid;
2626
use crate::{
2727
DatabaseError, DatabaseInconsistencyError,
2828
filter::StatementExt,
29-
iden::{UserSessions, Users},
29+
iden::{UserSessionAuthentications, UserSessions, Users},
3030
pagination::QueryBuilderExt,
3131
tracing::ExecuteExt,
3232
};
@@ -145,6 +145,31 @@ impl crate::filter::Filter for BrowserSessionFilter<'_> {
145145
.add_option(self.last_active_before().map(|last_active_before| {
146146
Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
147147
}))
148+
.add_option(self.authenticated_by_upstream_sessions().map(|sessions| {
149+
// For filtering by upstream sessions, we need to hop over the
150+
// `user_session_authentications` table
151+
let session_ids: Vec<_> = sessions
152+
.iter()
153+
.map(|session| Uuid::from(session.id))
154+
.collect();
155+
156+
Expr::col((UserSessions::Table, UserSessions::UserSessionId)).in_subquery(
157+
Query::select()
158+
.expr(Expr::col((
159+
UserSessionAuthentications::Table,
160+
UserSessionAuthentications::UserSessionId,
161+
)))
162+
.from(UserSessionAuthentications::Table)
163+
.and_where(
164+
Expr::col((
165+
UserSessionAuthentications::Table,
166+
UserSessionAuthentications::UpstreamOAuthAuthorizationSessionId,
167+
))
168+
.eq(PgFunc::any(Expr::value(session_ids))),
169+
)
170+
.take(),
171+
)
172+
}))
148173
}
149174
}
150175

crates/storage-pg/src/user/tests.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55
// Please see LICENSE files in the repository root for full details.
66

77
use chrono::Duration;
8+
use mas_iana::jose::JsonWebSignatureAlg;
89
use mas_storage::{
910
Clock, Pagination, RepositoryAccess,
1011
clock::MockClock,
12+
upstream_oauth2::UpstreamOAuthProviderParams,
1113
user::{
1214
BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository,
1315
UserFilter, UserPasswordRepository, UserRepository,
1416
},
1517
};
18+
use oauth2_types::scope::{OPENID, Scope};
1619
use rand::SeedableRng;
1720
use rand_chacha::ChaChaRng;
1821
use sqlx::PgPool;
@@ -717,6 +720,97 @@ async fn test_user_session(pool: PgPool) {
717720
assert_eq!(repo.browser_session().count(all_bob).await.unwrap(), 5);
718721
assert_eq!(repo.browser_session().count(active_bob).await.unwrap(), 0);
719722
assert_eq!(repo.browser_session().count(finished).await.unwrap(), 11);
723+
724+
// Checking the 'authenticaated by upstream sessions' filter
725+
// We need a provider
726+
let provider = repo
727+
.upstream_oauth_provider()
728+
.add(
729+
&mut rng,
730+
&clock,
731+
UpstreamOAuthProviderParams {
732+
issuer: None,
733+
human_name: None,
734+
brand_name: None,
735+
scope: Scope::from_iter([OPENID]),
736+
token_endpoint_auth_method:
737+
mas_data_model::UpstreamOAuthProviderTokenAuthMethod::None,
738+
token_endpoint_signing_alg: None,
739+
id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
740+
fetch_userinfo: false,
741+
userinfo_signed_response_alg: None,
742+
client_id: "client".to_owned(),
743+
encrypted_client_secret: None,
744+
claims_imports: mas_data_model::UpstreamOAuthProviderClaimsImports::default(),
745+
authorization_endpoint_override: None,
746+
token_endpoint_override: None,
747+
userinfo_endpoint_override: None,
748+
jwks_uri_override: None,
749+
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled,
750+
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Disabled,
751+
response_mode: None,
752+
additional_authorization_parameters: Vec::new(),
753+
forward_login_hint: false,
754+
ui_order: 0,
755+
on_backchannel_logout:
756+
mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
757+
},
758+
)
759+
.await
760+
.unwrap();
761+
762+
// Start a authorization session
763+
let upstream_oauth_session = repo
764+
.upstream_oauth_session()
765+
.add(&mut rng, &clock, &provider, "state".to_owned(), None, None)
766+
.await
767+
.unwrap();
768+
769+
// Start a browser session
770+
let session = repo
771+
.browser_session()
772+
.add(&mut rng, &clock, &alice, None)
773+
.await
774+
.unwrap();
775+
776+
// Make the session from alice authenticated by this session
777+
repo.browser_session()
778+
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_oauth_session)
779+
.await
780+
.unwrap();
781+
782+
let session_list = vec![upstream_oauth_session];
783+
let filter = BrowserSessionFilter::new().authenticated_by_upstream_sessions_only(&session_list);
784+
785+
// Now try to look it up
786+
let page = repo
787+
.browser_session()
788+
.list(filter, Pagination::first(10))
789+
.await
790+
.unwrap();
791+
assert_eq!(page.edges.len(), 1);
792+
assert_eq!(page.edges[0].id, session.id);
793+
794+
// Try counting
795+
assert_eq!(repo.browser_session().count(filter).await.unwrap(), 1);
796+
797+
// Try finishing the session
798+
let affected = repo
799+
.browser_session()
800+
.finish_bulk(&clock, filter)
801+
.await
802+
.unwrap();
803+
assert_eq!(affected, 1);
804+
805+
// Lookup the session by its ID
806+
let lookup = repo
807+
.browser_session()
808+
.lookup(session.id)
809+
.await
810+
.unwrap()
811+
.expect("session to be found in the database");
812+
// It should be finished
813+
assert!(lookup.finished_at.is_some());
720814
}
721815

722816
#[sqlx::test(migrator = "crate::MIGRATOR")]

crates/storage/src/user/session.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ pub struct BrowserSessionFilter<'a> {
3939
state: Option<BrowserSessionState>,
4040
last_active_before: Option<DateTime<Utc>>,
4141
last_active_after: Option<DateTime<Utc>>,
42+
authenticated_by_upstream_sessions: Option<&'a [UpstreamOAuthAuthorizationSession]>,
4243
}
4344

4445
impl<'a> BrowserSessionFilter<'a> {
@@ -110,6 +111,25 @@ impl<'a> BrowserSessionFilter<'a> {
110111
pub fn state(&self) -> Option<BrowserSessionState> {
111112
self.state
112113
}
114+
115+
/// Only return browser sessions authenticated by the given upstream OAuth
116+
/// sessions
117+
#[must_use]
118+
pub fn authenticated_by_upstream_sessions_only(
119+
mut self,
120+
upstream_oauth_sessions: &'a [UpstreamOAuthAuthorizationSession],
121+
) -> Self {
122+
self.authenticated_by_upstream_sessions = Some(upstream_oauth_sessions);
123+
self
124+
}
125+
126+
/// Get the upstream OAuth session filter
127+
#[must_use]
128+
pub fn authenticated_by_upstream_sessions(
129+
&self,
130+
) -> Option<&'a [UpstreamOAuthAuthorizationSession]> {
131+
self.authenticated_by_upstream_sessions
132+
}
113133
}
114134

115135
/// A [`BrowserSessionRepository`] helps interacting with [`BrowserSession`]

0 commit comments

Comments
 (0)