Skip to content

Commit fa52f7f

Browse files
committed
Add MasWriter support and test for upstream OAuth provider links
1 parent 530c759 commit fa52f7f

7 files changed

+329
-8
lines changed

crates/syn2mas/.sqlx/query-d79fd99ebed9033711f96113005096c848ae87c43b6430246ef3b6a1dc6a7a32.json

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
INSERT INTO upstream_oauth_providers
2+
(
3+
upstream_oauth_provider_id,
4+
scope,
5+
client_id,
6+
token_endpoint_auth_method,
7+
created_at
8+
)
9+
VALUES
10+
(
11+
'00000000-0000-0000-0000-000000000004',
12+
'openid',
13+
'someClientId',
14+
'client_secret_basic',
15+
'2011-12-13 14:15:16Z'
16+
);

crates/syn2mas/src/mas_writer/mod.rs

Lines changed: 152 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
use std::fmt::Display;
1111

1212
use chrono::{DateTime, Utc};
13-
use futures_util::{future::BoxFuture, TryStreamExt};
13+
use futures_util::{future::BoxFuture, FutureExt, TryStreamExt};
1414
use sqlx::{query, query_as, Executor, PgConnection};
1515
use thiserror::Error;
1616
use thiserror_ext::{Construct, ContextInto};
@@ -222,6 +222,14 @@ pub struct MasNewUnsupportedThreepid {
222222
pub created_at: DateTime<Utc>,
223223
}
224224

225+
pub struct MasNewUpstreamOauthLink {
226+
pub link_id: Uuid,
227+
pub user_id: Uuid,
228+
pub upstream_provider_id: Uuid,
229+
pub subject: String,
230+
pub created_at: DateTime<Utc>,
231+
}
232+
225233
/// The 'version' of the password hashing scheme used for passwords when they
226234
/// are migrated from Synapse to MAS.
227235
/// This is version 1, as in the previous syn2mas script.
@@ -234,6 +242,7 @@ pub const MAS_TABLES_AFFECTED_BY_MIGRATION: &[&str] = &[
234242
"user_passwords",
235243
"user_emails",
236244
"user_unsupported_third_party_ids",
245+
"upstream_oauth_links",
237246
];
238247

239248
/// Detect whether a syn2mas migration has started on the given database.
@@ -700,8 +709,6 @@ impl<'conn> MasWriter<'conn> {
700709
created_ats.push(created_at);
701710
}
702711

703-
// `confirmed_at` is going to get removed in a future MAS release,
704-
// so just populate with `created_at`
705712
sqlx::query!(
706713
r#"
707714
INSERT INTO syn2mas__user_unsupported_third_party_ids
@@ -718,6 +725,55 @@ impl<'conn> MasWriter<'conn> {
718725
})
719726
}).await
720727
}
728+
729+
#[tracing::instrument(skip_all, level = Level::DEBUG)]
730+
pub fn write_upstream_oauth_links(
731+
&mut self,
732+
links: Vec<MasNewUpstreamOauthLink>,
733+
) -> BoxFuture<'_, Result<(), Error>> {
734+
if links.is_empty() {
735+
return async { Ok(()) }.boxed();
736+
}
737+
self.writer_pool.spawn_with_connection(move |conn| {
738+
Box::pin(async move {
739+
let mut link_ids: Vec<Uuid> = Vec::with_capacity(links.len());
740+
let mut user_ids: Vec<Uuid> = Vec::with_capacity(links.len());
741+
let mut upstream_provider_ids: Vec<Uuid> = Vec::with_capacity(links.len());
742+
let mut subjects: Vec<String> = Vec::with_capacity(links.len());
743+
let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(links.len());
744+
745+
for MasNewUpstreamOauthLink {
746+
link_id,
747+
user_id,
748+
upstream_provider_id,
749+
subject,
750+
created_at,
751+
} in links
752+
{
753+
link_ids.push(link_id);
754+
user_ids.push(user_id);
755+
upstream_provider_ids.push(upstream_provider_id);
756+
subjects.push(subject);
757+
created_ats.push(created_at);
758+
}
759+
760+
sqlx::query!(
761+
r#"
762+
INSERT INTO syn2mas__upstream_oauth_links
763+
(upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)
764+
SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])
765+
"#,
766+
&link_ids[..],
767+
&user_ids[..],
768+
&upstream_provider_ids[..],
769+
&subjects[..],
770+
&created_ats[..],
771+
).execute(&mut *conn).await.into_database("writing unsupported threepids to MAS")?;
772+
773+
Ok(())
774+
})
775+
}).boxed()
776+
}
721777
}
722778

723779
// How many entries to buffer at once, before writing a batch of rows to the
@@ -727,6 +783,7 @@ impl<'conn> MasWriter<'conn> {
727783
// stream to two tables at once...)
728784
const WRITE_BUFFER_BATCH_SIZE: usize = 4096;
729785

786+
// TODO replace with just `MasWriteBuffer`
730787
pub struct MasUserWriteBuffer<'writer, 'conn> {
731788
users: Vec<MasNewUser>,
732789
passwords: Vec<MasNewUserPassword>,
@@ -786,6 +843,7 @@ impl<'writer, 'conn> MasUserWriteBuffer<'writer, 'conn> {
786843
}
787844
}
788845

846+
// TODO replace with just `MasWriteBuffer`
789847
pub struct MasThreepidWriteBuffer<'writer, 'conn> {
790848
email: Vec<MasNewEmailThreepid>,
791849
unsupported: Vec<MasNewUnsupportedThreepid>,
@@ -843,6 +901,60 @@ impl<'writer, 'conn> MasThreepidWriteBuffer<'writer, 'conn> {
843901
}
844902
}
845903

904+
/// A function that can accept and flush buffers from a `MasWriteBuffer`.
905+
/// Intended uses are the methods on `MasWriter` such as `write_users`.
906+
type WriteBufferFlusher<'conn, T> =
907+
for<'a> fn(&'a mut MasWriter<'conn>, Vec<T>) -> BoxFuture<'a, Result<(), Error>>;
908+
909+
/// A buffer for writing rows to the MAS database.
910+
/// Generic over the type of rows.
911+
///
912+
/// # Panics
913+
///
914+
/// Panics if dropped before `finish()` has been called.
915+
pub struct MasWriteBuffer<'conn, T> {
916+
rows: Vec<T>,
917+
flusher: WriteBufferFlusher<'conn, T>,
918+
finished: bool,
919+
}
920+
921+
impl<'conn, T> MasWriteBuffer<'conn, T> {
922+
pub fn new(flusher: WriteBufferFlusher<'conn, T>) -> Self {
923+
MasWriteBuffer {
924+
rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
925+
flusher,
926+
finished: false,
927+
}
928+
}
929+
930+
pub async fn finish(mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> {
931+
self.finished = true;
932+
self.flush(writer).await?;
933+
Ok(())
934+
}
935+
936+
pub async fn flush(&mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> {
937+
let rows = std::mem::take(&mut self.rows);
938+
self.rows.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
939+
(self.flusher)(writer, rows).await?;
940+
Ok(())
941+
}
942+
943+
pub async fn write(&mut self, writer: &mut MasWriter<'conn>, row: T) -> Result<(), Error> {
944+
self.rows.push(row);
945+
if self.rows.len() >= WRITE_BUFFER_BATCH_SIZE {
946+
self.flush(writer).await?;
947+
}
948+
Ok(())
949+
}
950+
}
951+
952+
impl<T> Drop for MasWriteBuffer<'_, T> {
953+
fn drop(&mut self) {
954+
assert!(self.finished, "MasWriteBuffer dropped but not finished!");
955+
}
956+
}
957+
846958
#[cfg(test)]
847959
mod test {
848960
use std::collections::{BTreeMap, BTreeSet};
@@ -855,7 +967,8 @@ mod test {
855967

856968
use crate::{
857969
mas_writer::{
858-
MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUser, MasNewUserPassword,
970+
MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
971+
MasNewUserPassword,
859972
},
860973
LockedMasDatabase, MasWriter,
861974
};
@@ -1085,4 +1198,39 @@ mod test {
10851198

10861199
assert_db_snapshot!(&mut conn);
10871200
}
1201+
1202+
/// Tests writing a single user, with a link to an upstream provider.
1203+
/// There needs to be an upstream provider in the database already — in the
1204+
/// real migration, this is done by running a provider sync first.
1205+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))]
1206+
async fn test_write_user_with_upstream_provider_link(pool: PgPool) {
1207+
let mut conn = pool.acquire().await.unwrap();
1208+
let mut writer = make_mas_writer(&pool, &mut conn).await;
1209+
1210+
writer
1211+
.write_users(vec![MasNewUser {
1212+
user_id: Uuid::from_u128(1u128),
1213+
username: "alice".to_owned(),
1214+
created_at: DateTime::default(),
1215+
locked_at: None,
1216+
can_request_admin: false,
1217+
}])
1218+
.await
1219+
.expect("failed to write user");
1220+
1221+
writer
1222+
.write_upstream_oauth_links(vec![MasNewUpstreamOauthLink {
1223+
user_id: Uuid::from_u128(1u128),
1224+
link_id: Uuid::from_u128(3u128),
1225+
upstream_provider_id: Uuid::from_u128(4u128),
1226+
subject: "12345.67890".to_owned(),
1227+
created_at: DateTime::default(),
1228+
}])
1229+
.await
1230+
.expect("failed to write link");
1231+
1232+
writer.finish().await.expect("failed to finish MasWriter");
1233+
1234+
assert_db_snapshot!(&mut conn);
1235+
}
10881236
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
---
2+
source: crates/syn2mas/src/mas_writer/mod.rs
3+
expression: db_snapshot
4+
---
5+
upstream_oauth_links:
6+
- created_at: "1970-01-01 00:00:00+00"
7+
human_account_name: ~
8+
subject: "12345.67890"
9+
upstream_oauth_link_id: 00000000-0000-0000-0000-000000000003
10+
upstream_oauth_provider_id: 00000000-0000-0000-0000-000000000004
11+
user_id: 00000000-0000-0000-0000-000000000001
12+
upstream_oauth_providers:
13+
- additional_parameters: ~
14+
authorization_endpoint_override: ~
15+
brand_name: ~
16+
claims_imports: "{}"
17+
client_id: someClientId
18+
created_at: "2011-12-13 14:15:16+00"
19+
disabled_at: ~
20+
discovery_mode: oidc
21+
encrypted_client_secret: ~
22+
fetch_userinfo: "false"
23+
human_name: ~
24+
id_token_signed_response_alg: RS256
25+
issuer: ~
26+
jwks_uri_override: ~
27+
pkce_mode: auto
28+
response_mode: query
29+
scope: openid
30+
token_endpoint_auth_method: client_secret_basic
31+
token_endpoint_override: ~
32+
token_endpoint_signing_alg: ~
33+
upstream_oauth_provider_id: 00000000-0000-0000-0000-000000000004
34+
userinfo_endpoint_override: ~
35+
userinfo_signed_response_alg: ~
36+
users:
37+
- can_request_admin: "false"
38+
created_at: "1970-01-01 00:00:00+00"
39+
locked_at: ~
40+
primary_user_email_id: ~
41+
user_id: 00000000-0000-0000-0000-000000000001
42+
username: alice

crates/syn2mas/src/mas_writer/syn2mas_revert_temporary_tables.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ ALTER TABLE syn2mas__users RENAME TO users;
1212
ALTER TABLE syn2mas__user_passwords RENAME TO user_passwords;
1313
ALTER TABLE syn2mas__user_emails RENAME TO user_emails;
1414
ALTER TABLE syn2mas__user_unsupported_third_party_ids RENAME TO user_unsupported_third_party_ids;
15+
ALTER TABLE syn2mas__upstream_oauth_links RENAME TO upstream_oauth_links;

crates/syn2mas/src/mas_writer/syn2mas_temporary_tables.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,4 @@ ALTER TABLE users RENAME TO syn2mas__users;
4141
ALTER TABLE user_passwords RENAME TO syn2mas__user_passwords;
4242
ALTER TABLE user_emails RENAME TO syn2mas__user_emails;
4343
ALTER TABLE user_unsupported_third_party_ids RENAME TO syn2mas__user_unsupported_third_party_ids;
44+
ALTER TABLE upstream_oauth_links RENAME TO syn2mas__upstream_oauth_links;

0 commit comments

Comments
 (0)