Skip to content

Commit 34f7e49

Browse files
committed
Setup a job to expire OAuth 2.0 sessions
1 parent 917f4d1 commit 34f7e49

File tree

3 files changed

+153
-1
lines changed

3 files changed

+153
-1
lines changed

crates/storage/src/queue/tasks.rs

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
// SPDX-License-Identifier: AGPL-3.0-only
44
// Please see LICENSE in the repository root for full details.
55

6-
use mas_data_model::{Device, User, UserEmailAuthentication, UserRecoverySession};
6+
use chrono::{DateTime, Utc};
7+
use mas_data_model::{Device, Session, User, UserEmailAuthentication, UserRecoverySession};
78
use serde::{Deserialize, Serialize};
89
use ulid::Ulid;
910

1011
use super::InsertableJob;
12+
use crate::{Page, Pagination};
1113

1214
/// This is the previous iteration of the email verification job. It has been
1315
/// replaced by [`SendEmailAuthenticationCodeJob`]. This struct is kept to be
@@ -193,6 +195,15 @@ impl SyncDevicesJob {
193195
Self { user_id: user.id }
194196
}
195197

198+
/// Create a new job to sync the list of devices of a user with the
199+
/// homeserver for the given user ID
200+
///
201+
/// This is useful to use in cases where the [`User`] object isn't loaded
202+
#[must_use]
203+
pub fn new_for_id(user_id: Ulid) -> Self {
204+
Self { user_id }
205+
}
206+
196207
/// The ID of the user to sync the devices for
197208
#[must_use]
198209
pub fn user_id(&self) -> Ulid {
@@ -310,3 +321,60 @@ pub struct CleanupExpiredTokensJob;
310321
impl InsertableJob for CleanupExpiredTokensJob {
311322
const QUEUE_NAME: &'static str = "cleanup-expired-tokens";
312323
}
324+
325+
/// Expire inactive OAuth 2.0 sessions
326+
#[derive(Serialize, Deserialize, Debug, Clone)]
327+
pub struct ExpireInactiveOAuthSessionsJob {
328+
threshold: DateTime<Utc>,
329+
after: Option<Ulid>,
330+
}
331+
332+
impl ExpireInactiveOAuthSessionsJob {
333+
/// Create a new job to expire inactive OAuth 2.0 sessions
334+
///
335+
/// # Parameters
336+
///
337+
/// * `threshold` - The threshold to expire sessions at
338+
#[must_use]
339+
pub fn new(threshold: DateTime<Utc>) -> Self {
340+
Self {
341+
threshold,
342+
after: None,
343+
}
344+
}
345+
346+
/// Get the threshold to expire sessions at
347+
#[must_use]
348+
pub fn threshold(&self) -> DateTime<Utc> {
349+
self.threshold
350+
}
351+
352+
/// Get the pagination cursor
353+
#[must_use]
354+
pub fn pagination(&self, batch_size: usize) -> Pagination {
355+
let pagination = Pagination::first(batch_size);
356+
if let Some(after) = self.after {
357+
pagination.after(after)
358+
} else {
359+
pagination
360+
}
361+
}
362+
363+
/// Get the next job given the page returned by the database
364+
#[must_use]
365+
pub fn next(&self, page: &Page<Session>) -> Option<Self> {
366+
if !page.has_next_page {
367+
return None;
368+
}
369+
370+
let last_edge = page.edges.last()?;
371+
Some(Self {
372+
threshold: self.threshold,
373+
after: Some(last_edge.id),
374+
})
375+
}
376+
}
377+
378+
impl InsertableJob for ExpireInactiveOAuthSessionsJob {
379+
const QUEUE_NAME: &'static str = "expire-inactive-oauth-sessions";
380+
}

crates/tasks/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ mod email;
2222
mod matrix;
2323
mod new_queue;
2424
mod recovery;
25+
mod sessions;
2526
mod user;
2627

2728
static METER: LazyLock<Meter> = LazyLock::new(|| {
@@ -128,6 +129,7 @@ pub async fn init(
128129
.register_handler::<mas_storage::queue::SendEmailAuthenticationCodeJob>()
129130
.register_handler::<mas_storage::queue::SyncDevicesJob>()
130131
.register_handler::<mas_storage::queue::VerifyEmailJob>()
132+
.register_handler::<mas_storage::queue::ExpireInactiveOAuthSessionsJob>()
131133
.add_schedule(
132134
"cleanup-expired-tokens",
133135
"0 0 * * * *".parse()?,

crates/tasks/src/sessions.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright 2025 New Vector Ltd.
2+
//
3+
// SPDX-License-Identifier: AGPL-3.0-only
4+
// Please see LICENSE in the repository root for full details.
5+
6+
use std::collections::HashSet;
7+
8+
use async_trait::async_trait;
9+
use chrono::Duration;
10+
use mas_storage::{
11+
oauth2::OAuth2SessionFilter,
12+
queue::{ExpireInactiveOAuthSessionsJob, QueueJobRepositoryExt, SyncDevicesJob},
13+
};
14+
15+
use crate::{
16+
new_queue::{JobContext, JobError, RunnableJob},
17+
State,
18+
};
19+
20+
#[async_trait]
21+
impl RunnableJob for ExpireInactiveOAuthSessionsJob {
22+
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
23+
let mut repo = state.repository().await.map_err(JobError::retry)?;
24+
let clock = state.clock();
25+
let mut rng = state.rng();
26+
let mut users_synced = HashSet::new();
27+
28+
// This delay is used to space out the device sync jobs
29+
// We add 10 seconds between each device sync, meaning that it will spread out
30+
// the syncs over ~16 minutes max if we get a full batch of 100 users
31+
let mut delay = Duration::minutes(1);
32+
33+
let filter = OAuth2SessionFilter::new()
34+
.with_last_active_before(self.threshold())
35+
.for_any_user()
36+
.active_only();
37+
38+
let pagination = self.pagination(100);
39+
40+
let page = repo
41+
.oauth2_session()
42+
.list(filter, pagination)
43+
.await
44+
.map_err(JobError::retry)?;
45+
46+
if let Some(job) = self.next(&page) {
47+
tracing::info!("Scheduling job to expire the next batch of inactive sessions");
48+
repo.queue_job()
49+
.schedule_job(&mut rng, &clock, job)
50+
.await
51+
.map_err(JobError::retry)?;
52+
}
53+
54+
for edge in page.edges {
55+
if let Some(user_id) = edge.user_id {
56+
let inserted = users_synced.insert(user_id);
57+
if inserted {
58+
tracing::info!(user.id = %user_id, "Scheduling devices sync for user");
59+
repo.queue_job()
60+
.schedule_job_later(
61+
&mut rng,
62+
&clock,
63+
SyncDevicesJob::new_for_id(user_id),
64+
clock.now() + delay,
65+
)
66+
.await
67+
.map_err(JobError::retry)?;
68+
delay += Duration::seconds(10);
69+
}
70+
}
71+
72+
repo.oauth2_session()
73+
.finish(&clock, edge)
74+
.await
75+
.map_err(JobError::retry)?;
76+
}
77+
78+
repo.save().await.map_err(JobError::retry)?;
79+
80+
Ok(())
81+
}
82+
}

0 commit comments

Comments
 (0)