Skip to content

Commit 98f1300

Browse files
committed
Record extra query parameters during upstream callback
And make them available in the templates. This is useful to get the user display name for Sign-in with Apple
1 parent a4421aa commit 98f1300

File tree

12 files changed

+137
-25
lines changed

12 files changed

+137
-25
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/data-model/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ workspace = true
1515
chrono.workspace = true
1616
thiserror.workspace = true
1717
serde.workspace = true
18+
serde_json.workspace = true
1819
url.workspace = true
1920
crc = "3.2.1"
2021
ulid.workspace = true

crates/data-model/src/upstream_oauth2/session.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ pub enum UpstreamOAuthAuthorizationSessionState {
1919
completed_at: DateTime<Utc>,
2020
link_id: Ulid,
2121
id_token: Option<String>,
22+
extra_callback_parameters: Option<serde_json::Value>,
2223
},
2324
Consumed {
2425
completed_at: DateTime<Utc>,
2526
consumed_at: DateTime<Utc>,
2627
link_id: Ulid,
2728
id_token: Option<String>,
29+
extra_callback_parameters: Option<serde_json::Value>,
2830
},
2931
}
3032

@@ -42,12 +44,14 @@ impl UpstreamOAuthAuthorizationSessionState {
4244
completed_at: DateTime<Utc>,
4345
link: &UpstreamOAuthLink,
4446
id_token: Option<String>,
47+
extra_callback_parameters: Option<serde_json::Value>,
4548
) -> Result<Self, InvalidTransitionError> {
4649
match self {
4750
Self::Pending => Ok(Self::Completed {
4851
completed_at,
4952
link_id: link.id,
5053
id_token,
54+
extra_callback_parameters,
5155
}),
5256
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
5357
}
@@ -67,11 +71,13 @@ impl UpstreamOAuthAuthorizationSessionState {
6771
completed_at,
6872
link_id,
6973
id_token,
74+
extra_callback_parameters,
7075
} => Ok(Self::Consumed {
7176
completed_at,
7277
link_id,
7378
consumed_at,
7479
id_token,
80+
extra_callback_parameters,
7581
}),
7682
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
7783
}
@@ -124,6 +130,27 @@ impl UpstreamOAuthAuthorizationSessionState {
124130
}
125131
}
126132

133+
/// Get the extra query parameters that were sent to the upstream provider.
134+
///
135+
/// Returns `None` if the upstream OAuth 2.0 authorization session state is
136+
/// not [`Pending`].
137+
///
138+
/// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
139+
#[must_use]
140+
pub fn extra_callback_parameters(&self) -> Option<&serde_json::Value> {
141+
match self {
142+
Self::Pending => None,
143+
Self::Completed {
144+
extra_callback_parameters,
145+
..
146+
}
147+
| Self::Consumed {
148+
extra_callback_parameters,
149+
..
150+
} => extra_callback_parameters.as_ref(),
151+
}
152+
}
153+
127154
/// Get the time at which the upstream OAuth 2.0 authorization session was
128155
/// consumed.
129156
///
@@ -201,8 +228,11 @@ impl UpstreamOAuthAuthorizationSession {
201228
completed_at: DateTime<Utc>,
202229
link: &UpstreamOAuthLink,
203230
id_token: Option<String>,
231+
extra_callback_parameters: Option<serde_json::Value>,
204232
) -> Result<Self, InvalidTransitionError> {
205-
self.state = self.state.complete(completed_at, link, id_token)?;
233+
self.state =
234+
self.state
235+
.complete(completed_at, link, id_token, extra_callback_parameters)?;
206236
Ok(self)
207237
}
208238

crates/handlers/src/upstream_oauth2/callback.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ pub struct Params {
4848
enum CodeOrError {
4949
Code {
5050
code: String,
51+
52+
#[serde(flatten)]
53+
extra_callback_parameters: Option<serde_json::Value>,
5154
},
5255
Error {
5356
error: ClientErrorCode,
@@ -201,7 +204,7 @@ pub(crate) async fn handler(
201204
}
202205

203206
// Let's extract the code from the params, and return if there was an error
204-
let code = match params.code_or_error {
207+
let (code, extra_callback_parameters) = match params.code_or_error {
205208
CodeOrError::Error {
206209
error,
207210
error_description,
@@ -212,7 +215,10 @@ pub(crate) async fn handler(
212215
error_description,
213216
})
214217
}
215-
CodeOrError::Code { code } => code,
218+
CodeOrError::Code {
219+
code,
220+
extra_callback_parameters,
221+
} => (code, extra_callback_parameters),
216222
};
217223

218224
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
@@ -266,6 +272,10 @@ pub(crate) async fn handler(
266272
let env = {
267273
let mut env = environment();
268274
env.add_global("user", minijinja::Value::from_serialize(&id_token));
275+
env.add_global(
276+
"extra_callback_parameters",
277+
minijinja::Value::from_serialize(&extra_callback_parameters),
278+
);
269279
env
270280
};
271281

@@ -299,7 +309,13 @@ pub(crate) async fn handler(
299309

300310
let session = repo
301311
.upstream_oauth_session()
302-
.complete_with_link(&clock, session, &link, response.id_token)
312+
.complete_with_link(
313+
&clock,
314+
session,
315+
&link,
316+
response.id_token,
317+
extra_callback_parameters,
318+
)
303319
.await?;
304320

305321
let cookie_jar = sessions_cookie

crates/handlers/src/upstream_oauth2/link.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,10 @@ pub(crate) async fn get(
340340
let env = {
341341
let mut e = environment();
342342
e.add_global("user", payload);
343+
e.add_global(
344+
"extra_callback_parameters",
345+
minijinja::Value::from_serialize(upstream_session.extra_callback_parameters()),
346+
);
343347
e
344348
};
345349

@@ -582,6 +586,10 @@ pub(crate) async fn post(
582586
let env = {
583587
let mut e = environment();
584588
e.add_global("user", payload);
589+
e.add_global(
590+
"extra_callback_parameters",
591+
minijinja::Value::from_serialize(upstream_session.extra_callback_parameters()),
592+
);
585593
e
586594
};
587595

@@ -945,7 +953,13 @@ mod tests {
945953

946954
let session = repo
947955
.upstream_oauth_session()
948-
.complete_with_link(&state.clock, session, &link, Some(id_token.into_string()))
956+
.complete_with_link(
957+
&state.clock,
958+
session,
959+
&link,
960+
Some(id_token.into_string()),
961+
None,
962+
)
949963
.await
950964
.unwrap();
951965

crates/handlers/src/upstream_oauth2/template.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,18 @@ fn string(value: &Value) -> String {
6868
value.to_string()
6969
}
7070

71+
fn from_json(value: &str) -> Result<Value, minijinja::Error> {
72+
let value: serde_json::Value = serde_json::from_str(value).map_err(|e| {
73+
minijinja::Error::new(
74+
minijinja::ErrorKind::InvalidOperation,
75+
"Failed to decode JSON",
76+
)
77+
.with_source(e)
78+
})?;
79+
80+
Ok(Value::from_serialize(value))
81+
}
82+
7183
pub fn environment() -> Environment<'static> {
7284
let mut env = Environment::new();
7385

@@ -77,6 +89,7 @@ pub fn environment() -> Environment<'static> {
7789
env.add_filter("b64encode", b64encode);
7890
env.add_filter("tlvdecode", tlvdecode);
7991
env.add_filter("string", string);
92+
env.add_filter("from_json", from_json);
8093

8194
env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
8295

crates/storage-pg/.sqlx/query-b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64.json renamed to crates/storage-pg/.sqlx/query-5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d.json

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/storage-pg/.sqlx/query-67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c.json renamed to crates/storage-pg/.sqlx/query-7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30.json

Lines changed: 10 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
-- Copyright 2024 New Vector Ltd.
2+
--
3+
-- SPDX-License-Identifier: AGPL-3.0-only
4+
-- Please see LICENSE in the repository root for full details.
5+
6+
-- Add a column to the upstream_oauth_authorization_sessions table to store
7+
-- extra query parameters
8+
ALTER TABLE "upstream_oauth_authorization_sessions"
9+
ADD COLUMN "extra_callback_parameters" JSONB;

crates/storage-pg/src/upstream_oauth2/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ mod tests {
145145

146146
let session = repo
147147
.upstream_oauth_session()
148-
.complete_with_link(&clock, session, &link, None)
148+
.complete_with_link(&clock, session, &link, None, None)
149149
.await
150150
.unwrap();
151151
// Reload the session

0 commit comments

Comments
 (0)