33// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
44// Please see LICENSE files in the repository root for full details.
55
6- use std:: collections:: HashMap ;
6+ use std:: collections:: { HashMap , HashSet } ;
77
88use axum:: {
99 Form , Json ,
@@ -22,7 +22,11 @@ use mas_oidc_client::{
2222 requests:: jose:: { JwtVerificationData , verify_signed_jwt} ,
2323} ;
2424use mas_storage:: {
25- BoxClock , BoxRepository , upstream_oauth2:: UpstreamOAuthSessionFilter ,
25+ BoxClock , BoxRepository , BoxRng , Pagination ,
26+ compat:: CompatSessionFilter ,
27+ oauth2:: OAuth2SessionFilter ,
28+ queue:: { QueueJobRepositoryExt as _, SyncDevicesJob } ,
29+ upstream_oauth2:: UpstreamOAuthSessionFilter ,
2630 user:: BrowserSessionFilter ,
2731} ;
2832use oauth2_types:: errors:: { ClientError , ClientErrorCode } ;
@@ -131,6 +135,7 @@ const EVENTS: Claim<LogoutTokenEvents> = Claim::new("events");
131135) ]
132136pub ( crate ) async fn post (
133137 clock : BoxClock ,
138+ mut rng : BoxRng ,
134139 mut repo : BoxRepository ,
135140 State ( metadata_cache) : State < MetadataCache > ,
136141 State ( client) : State < reqwest:: Client > ,
@@ -242,10 +247,67 @@ pub(crate) async fn post(
242247 }
243248 UpstreamOAuthProviderOnBackchannelLogout :: LogoutBrowserOnly => {
244249 let filter = BrowserSessionFilter :: new ( )
245- . authenticated_by_upstream_sessions_only ( auth_session_filter) ;
250+ . authenticated_by_upstream_sessions_only ( auth_session_filter)
251+ . active_only ( ) ;
246252 let affected = repo. browser_session ( ) . finish_bulk ( & clock, filter) . await ?;
247253 tracing:: info!( "Finished {affected} browser sessions" ) ;
248254 }
255+ UpstreamOAuthProviderOnBackchannelLogout :: LogoutAll => {
256+ let browser_session_filter = BrowserSessionFilter :: new ( )
257+ . authenticated_by_upstream_sessions_only ( auth_session_filter) ;
258+
259+ // We need to loop through all the browser sessions to find all the
260+ // users affected so that we can trigger a device sync job for them
261+ let mut cursor = Pagination :: first ( 1000 ) ;
262+ let mut user_ids = HashSet :: new ( ) ;
263+ loop {
264+ let browser_sessions = repo
265+ . browser_session ( )
266+ . list ( browser_session_filter, cursor)
267+ . await ?;
268+ for browser_session in browser_sessions. edges {
269+ user_ids. insert ( browser_session. user . id ) ;
270+ cursor = cursor. after ( browser_session. id ) ;
271+ }
272+
273+ if !browser_sessions. has_next_page {
274+ break ;
275+ }
276+ }
277+
278+ let browser_sessions_affected = repo
279+ . browser_session ( )
280+ . finish_bulk ( & clock, browser_session_filter. active_only ( ) )
281+ . await ?;
282+
283+ let oauth2_session_filter = OAuth2SessionFilter :: new ( )
284+ . active_only ( )
285+ . for_browser_sessions ( browser_session_filter) ;
286+
287+ let oauth2_sessions_affected = repo
288+ . oauth2_session ( )
289+ . finish_bulk ( & clock, oauth2_session_filter)
290+ . await ?;
291+
292+ let compat_session_filter = CompatSessionFilter :: new ( )
293+ . active_only ( )
294+ . for_browser_sessions ( browser_session_filter) ;
295+
296+ let compat_sessions_affected = repo
297+ . compat_session ( )
298+ . finish_bulk ( & clock, compat_session_filter)
299+ . await ?;
300+
301+ tracing:: info!(
302+ "Finished {browser_sessions_affected} browser sessions, {oauth2_sessions_affected} OAuth 2.0 sessions and {compat_sessions_affected} compatibility sessions"
303+ ) ;
304+
305+ for user_id in user_ids {
306+ tracing:: info!( user. id = %user_id, "Queueing a device sync job for user" ) ;
307+ let job = SyncDevicesJob :: new_for_id ( user_id) ;
308+ repo. queue_job ( ) . schedule_job ( & mut rng, & clock, job) . await ?;
309+ }
310+ }
249311 }
250312
251313 repo. save ( ) . await ?;
0 commit comments