Skip to content

Commit 33fdfc6

Browse files
committed
Allow filtering upstream sessions by sub and sid claims
1 parent 5db877e commit 33fdfc6

File tree

5 files changed

+154
-9
lines changed

5 files changed

+154
-9
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
-- no-transaction
2+
-- Copyright 2025 New Vector Ltd.
3+
--
4+
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5+
-- Please see LICENSE in the repository root for full details.
6+
7+
-- We'll be requesting authorization sessions by provider, sub and sid, so we'll
8+
-- need to index those columns
9+
CREATE INDEX CONCURRENTLY IF NOT EXISTS
10+
upstream_oauth_authorization_sessions_sub_sid_idx
11+
ON upstream_oauth_authorization_sessions (
12+
upstream_oauth_provider_id,
13+
(id_token_claims->>'sub'),
14+
(id_token_claims->>'sid')
15+
);
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
-- no-transaction
2+
-- Copyright 2025 New Vector Ltd.
3+
--
4+
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5+
-- Please see LICENSE in the repository root for full details.
6+
7+
-- We'll be requesting authorization sessions by provider, sub and sid, so we'll
8+
-- need to index those columns
9+
CREATE INDEX CONCURRENTLY IF NOT EXISTS
10+
upstream_oauth_authorization_sessions_sid_sub_idx
11+
ON upstream_oauth_authorization_sessions (
12+
upstream_oauth_provider_id,
13+
(id_token_claims->>'sid'),
14+
(id_token_claims->>'sub')
15+
);

crates/storage-pg/src/upstream_oauth2/mod.rs

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,15 +499,45 @@ mod tests {
499499
0
500500
);
501501

502+
let mut links = Vec::with_capacity(3);
503+
for subject in ["alice", "bob", "charlie"] {
504+
let link = repo
505+
.upstream_oauth_link()
506+
.add(&mut rng, &clock, &provider, subject.to_owned(), None)
507+
.await
508+
.unwrap();
509+
links.push(link);
510+
}
511+
502512
let mut ids = Vec::with_capacity(20);
513+
let sids = ["one", "two"].into_iter().cycle();
503514
// Create 20 sessions
504-
for idx in 0..20 {
515+
for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) {
505516
let state = format!("state-{idx}");
506517
let session = repo
507518
.upstream_oauth_session()
508519
.add(&mut rng, &clock, &provider, state, None, None)
509520
.await
510521
.unwrap();
522+
let id_token_claims = serde_json::json!({
523+
"sub": link.subject,
524+
"sid": sid,
525+
"aud": provider.client_id,
526+
"iss": "https://example.com/",
527+
});
528+
let session = repo
529+
.upstream_oauth_session()
530+
.complete_with_link(
531+
&clock,
532+
session,
533+
link,
534+
None,
535+
Some(id_token_claims),
536+
None,
537+
None,
538+
)
539+
.await
540+
.unwrap();
511541
ids.push(session.id);
512542
clock.advance(Duration::microseconds(10 * 1000 * 1000));
513543
}
@@ -577,5 +607,41 @@ mod tests {
577607
assert!(!page.has_next_page);
578608
let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
579609
assert_eq!(&edge_ids, &ids[6..11]);
610+
611+
// Check the sub/sid filters
612+
assert_eq!(
613+
repo.upstream_oauth_session()
614+
.count(filter.with_sub_claim("alice").with_sid_claim("one"))
615+
.await
616+
.unwrap(),
617+
4
618+
);
619+
assert_eq!(
620+
repo.upstream_oauth_session()
621+
.count(filter.with_sub_claim("bob").with_sid_claim("two"))
622+
.await
623+
.unwrap(),
624+
4
625+
);
626+
627+
let page = repo
628+
.upstream_oauth_session()
629+
.list(
630+
filter.with_sub_claim("alice").with_sid_claim("one"),
631+
Pagination::first(10),
632+
)
633+
.await
634+
.unwrap();
635+
assert_eq!(page.edges.len(), 4);
636+
for edge in page.edges {
637+
assert_eq!(
638+
edge.id_token_claims().unwrap().get("sub").unwrap().as_str(),
639+
Some("alice")
640+
);
641+
assert_eq!(
642+
edge.id_token_claims().unwrap().get("sid").unwrap().as_str(),
643+
Some("one")
644+
);
645+
}
580646
}
581647
}

crates/storage-pg/src/upstream_oauth2/session.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use mas_storage::{
1515
upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository},
1616
};
1717
use rand::RngCore;
18-
use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
18+
use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
1919
use sea_query_binder::SqlxBinder;
2020
use sqlx::PgConnection;
2121
use ulid::Ulid;
@@ -31,13 +31,30 @@ use crate::{
3131

3232
impl Filter for UpstreamOAuthSessionFilter<'_> {
3333
fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
34-
sea_query::Condition::all().add_option(self.provider().map(|provider| {
35-
Expr::col((
36-
UpstreamOAuthAuthorizationSessions::Table,
37-
UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
38-
))
39-
.eq(Uuid::from(provider.id))
40-
}))
34+
sea_query::Condition::all()
35+
.add_option(self.provider().map(|provider| {
36+
Expr::col((
37+
UpstreamOAuthAuthorizationSessions::Table,
38+
UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
39+
))
40+
.eq(Uuid::from(provider.id))
41+
}))
42+
.add_option(self.sub_claim().map(|sub| {
43+
Expr::col((
44+
UpstreamOAuthAuthorizationSessions::Table,
45+
UpstreamOAuthAuthorizationSessions::IdTokenClaims,
46+
))
47+
.cast_json_field("sub")
48+
.eq(sub)
49+
}))
50+
.add_option(self.sid_claim().map(|sid| {
51+
Expr::col((
52+
UpstreamOAuthAuthorizationSessions::Table,
53+
UpstreamOAuthAuthorizationSessions::IdTokenClaims,
54+
))
55+
.cast_json_field("sid")
56+
.eq(sid)
57+
}))
4158
}
4259
}
4360

crates/storage/src/upstream_oauth2/session.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use crate::{Clock, Pagination, pagination::Page, repository_impl};
1515
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
1616
pub struct UpstreamOAuthSessionFilter<'a> {
1717
provider: Option<&'a UpstreamOAuthProvider>,
18+
sub_claim: Option<&'a str>,
19+
sid_claim: Option<&'a str>,
1820
}
1921

2022
impl<'a> UpstreamOAuthSessionFilter<'a> {
@@ -38,6 +40,36 @@ impl<'a> UpstreamOAuthSessionFilter<'a> {
3840
pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
3941
self.provider
4042
}
43+
44+
/// Set the `sub` claim to filter by
45+
#[must_use]
46+
pub fn with_sub_claim(mut self, sub_claim: &'a str) -> Self {
47+
self.sub_claim = Some(sub_claim);
48+
self
49+
}
50+
51+
/// Get the `sub` claim filter
52+
///
53+
/// Returns [`None`] if no filter was set
54+
#[must_use]
55+
pub fn sub_claim(&self) -> Option<&str> {
56+
self.sub_claim
57+
}
58+
59+
/// Set the `sid` claim to filter by
60+
#[must_use]
61+
pub fn with_sid_claim(mut self, sid_claim: &'a str) -> Self {
62+
self.sid_claim = Some(sid_claim);
63+
self
64+
}
65+
66+
/// Get the `sid` claim filter
67+
///
68+
/// Returns [`None`] if no filter was set
69+
#[must_use]
70+
pub fn sid_claim(&self) -> Option<&str> {
71+
self.sid_claim
72+
}
4173
}
4274

4375
/// An [`UpstreamOAuthSessionRepository`] helps interacting with

0 commit comments

Comments
 (0)