Skip to content

Commit 8f404f6

Browse files
committed
Compose filters for batch logging out of browser sessions
Instead of having to load all authentication sessions in memory, we allow composing browser session filters with a upstream auth sessions filter
1 parent 212d2cc commit 8f404f6

File tree

5 files changed

+40
-49
lines changed

5 files changed

+40
-49
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ unsafe_code = "deny"
2424
# We use groups as good defaults, but with a lower priority so that we can override them
2525
all = { level = "deny", priority = -1 }
2626
pedantic = { level = "warn", priority = -1 }
27+
result_large_err = "warn"
28+
too_many_arguments = "warn"
2729

2830
str_to_string = "deny"
2931

crates/handlers/src/upstream_oauth2/backchannel_logout.rs

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use mas_oidc_client::{
2222
requests::jose::{JwtVerificationData, verify_signed_jwt},
2323
};
2424
use mas_storage::{
25-
BoxClock, BoxRepository, Pagination, upstream_oauth2::UpstreamOAuthSessionFilter,
25+
BoxClock, BoxRepository, upstream_oauth2::UpstreamOAuthSessionFilter,
2626
user::BrowserSessionFilter,
2727
};
2828
use oauth2_types::errors::{ClientError, ClientErrorCode};
@@ -222,41 +222,27 @@ pub(crate) async fn post(
222222
claims::NONCE.assert_absent(&claims)?; // (7)
223223

224224
// Find the corresponding upstream OAuth 2.0 sessions
225-
let mut filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);
225+
let mut auth_session_filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);
226226
if let Some(sub) = &sub {
227-
filter = filter.with_sub_claim(sub);
227+
auth_session_filter = auth_session_filter.with_sub_claim(sub);
228228
}
229229
if let Some(sid) = &sid {
230-
filter = filter.with_sid_claim(sid);
230+
auth_session_filter = auth_session_filter.with_sid_claim(sid);
231231
}
232+
let count = repo
233+
.upstream_oauth_session()
234+
.count(auth_session_filter)
235+
.await?;
232236

233-
// Load the corresponding authentication sessions, by batches of 100s. It's
234-
// VERY unlikely that we'll ever have more that 100 sessions for a single
235-
// logout notification, but we'll handle it anyway.
236-
let mut cursor = Pagination::first(100);
237-
let mut sessions = Vec::new();
238-
loop {
239-
let page = repo.upstream_oauth_session().list(filter, cursor).await?;
240-
241-
for session in page.edges {
242-
cursor = cursor.after(session.id);
243-
sessions.push(session);
244-
}
245-
246-
if !page.has_next_page {
247-
break;
248-
}
249-
}
250-
251-
tracing::info!(sub, sid, %provider.id, "Backchannel logout received, found {} corresponding authentication sessions", sessions.len());
237+
tracing::info!(sub, sid, %provider.id, "Backchannel logout received, found {count} corresponding authentication sessions");
252238

253239
match provider.on_backchannel_logout {
254240
UpstreamOAuthProviderOnBackchannelLogout::DoNothing => {
255241
tracing::warn!(%provider.id, "Provider configured to do nothing on backchannel logout");
256242
}
257243
UpstreamOAuthProviderOnBackchannelLogout::LogoutBrowserOnly => {
258-
let filter =
259-
BrowserSessionFilter::new().authenticated_by_upstream_sessions_only(&sessions);
244+
let filter = BrowserSessionFilter::new()
245+
.authenticated_by_upstream_sessions_only(auth_session_filter);
260246
let affected = repo.browser_session().finish_bulk(&clock, filter).await?;
261247
tracing::info!("Finished {affected} browser sessions");
262248
}

crates/storage-pg/src/user/session.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use mas_storage::{
1717
user::{BrowserSessionFilter, BrowserSessionRepository},
1818
};
1919
use rand::RngCore;
20-
use sea_query::{Expr, PgFunc, PostgresQueryBuilder, Query};
20+
use sea_query::{Expr, PostgresQueryBuilder, Query};
2121
use sea_query_binder::SqlxBinder;
2222
use sqlx::PgConnection;
2323
use ulid::Ulid;
@@ -26,7 +26,7 @@ use uuid::Uuid;
2626
use crate::{
2727
DatabaseError, DatabaseInconsistencyError,
2828
filter::StatementExt,
29-
iden::{UserSessionAuthentications, UserSessions, Users},
29+
iden::{UpstreamOAuthAuthorizationSessions, UserSessionAuthentications, UserSessions, Users},
3030
pagination::QueryBuilderExt,
3131
tracing::ExecuteExt,
3232
};
@@ -145,13 +145,17 @@ impl crate::filter::Filter for BrowserSessionFilter<'_> {
145145
.add_option(self.last_active_before().map(|last_active_before| {
146146
Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
147147
}))
148-
.add_option(self.authenticated_by_upstream_sessions().map(|sessions| {
148+
.add_option(self.authenticated_by_upstream_sessions().map(|filter| {
149149
// For filtering by upstream sessions, we need to hop over the
150150
// `user_session_authentications` table
151-
let session_ids: Vec<_> = sessions
152-
.iter()
153-
.map(|session| Uuid::from(session.id))
154-
.collect();
151+
let join_expr = Expr::col((
152+
UserSessionAuthentications::Table,
153+
UserSessionAuthentications::UpstreamOAuthAuthorizationSessionId,
154+
))
155+
.eq(Expr::col((
156+
UpstreamOAuthAuthorizationSessions::Table,
157+
UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
158+
)));
155159

156160
Expr::col((UserSessions::Table, UserSessions::UserSessionId)).in_subquery(
157161
Query::select()
@@ -160,13 +164,8 @@ impl crate::filter::Filter for BrowserSessionFilter<'_> {
160164
UserSessionAuthentications::UserSessionId,
161165
)))
162166
.from(UserSessionAuthentications::Table)
163-
.and_where(
164-
Expr::col((
165-
UserSessionAuthentications::Table,
166-
UserSessionAuthentications::UpstreamOAuthAuthorizationSessionId,
167-
))
168-
.eq(PgFunc::any(Expr::value(session_ids))),
169-
)
167+
.inner_join(UpstreamOAuthAuthorizationSessions::Table, join_expr)
168+
.apply_filter(filter)
170169
.take(),
171170
)
172171
}))

crates/storage-pg/src/user/tests.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use mas_iana::jose::JsonWebSignatureAlg;
99
use mas_storage::{
1010
Clock, Pagination, RepositoryAccess,
1111
clock::MockClock,
12-
upstream_oauth2::UpstreamOAuthProviderParams,
12+
upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthSessionFilter},
1313
user::{
1414
BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository,
1515
UserFilter, UserPasswordRepository, UserRepository,
@@ -779,8 +779,11 @@ async fn test_user_session(pool: PgPool) {
779779
.await
780780
.unwrap();
781781

782-
let session_list = vec![upstream_oauth_session];
783-
let filter = BrowserSessionFilter::new().authenticated_by_upstream_sessions_only(&session_list);
782+
// This will match all authorization sessions, which matches exactly that one
783+
// authorization session
784+
let upstream_oauth_session_filter = UpstreamOAuthSessionFilter::new();
785+
let filter = BrowserSessionFilter::new()
786+
.authenticated_by_upstream_sessions_only(upstream_oauth_session_filter);
784787

785788
// Now try to look it up
786789
let page = repo

crates/storage/src/user/session.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ use mas_data_model::{
1414
use rand_core::RngCore;
1515
use ulid::Ulid;
1616

17-
use crate::{Clock, Pagination, pagination::Page, repository_impl};
17+
use crate::{
18+
Clock, Pagination, pagination::Page, repository_impl,
19+
upstream_oauth2::UpstreamOAuthSessionFilter,
20+
};
1821

1922
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
2023
pub enum BrowserSessionState {
@@ -39,7 +42,7 @@ pub struct BrowserSessionFilter<'a> {
3942
state: Option<BrowserSessionState>,
4043
last_active_before: Option<DateTime<Utc>>,
4144
last_active_after: Option<DateTime<Utc>>,
42-
authenticated_by_upstream_sessions: Option<&'a [UpstreamOAuthAuthorizationSession]>,
45+
authenticated_by_upstream_sessions: Option<UpstreamOAuthSessionFilter<'a>>,
4346
}
4447

4548
impl<'a> BrowserSessionFilter<'a> {
@@ -117,17 +120,15 @@ impl<'a> BrowserSessionFilter<'a> {
117120
#[must_use]
118121
pub fn authenticated_by_upstream_sessions_only(
119122
mut self,
120-
upstream_oauth_sessions: &'a [UpstreamOAuthAuthorizationSession],
123+
filter: UpstreamOAuthSessionFilter<'a>,
121124
) -> Self {
122-
self.authenticated_by_upstream_sessions = Some(upstream_oauth_sessions);
125+
self.authenticated_by_upstream_sessions = Some(filter);
123126
self
124127
}
125128

126129
/// Get the upstream OAuth session filter
127130
#[must_use]
128-
pub fn authenticated_by_upstream_sessions(
129-
&self,
130-
) -> Option<&'a [UpstreamOAuthAuthorizationSession]> {
131+
pub fn authenticated_by_upstream_sessions(&self) -> Option<UpstreamOAuthSessionFilter<'a>> {
131132
self.authenticated_by_upstream_sessions
132133
}
133134
}

0 commit comments

Comments
 (0)