Skip to content

Commit cea130e

Browse files
committed
Pass claims as context instead of globals during attr mapping
This is in preparation for also fetching the claims from the userinfo endpoint
1 parent 93bbfab commit cea130e

File tree

3 files changed

+121
-53
lines changed

3 files changed

+121
-53
lines changed

crates/handlers/src/upstream_oauth2/callback.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ use thiserror::Error;
3030
use ulid::Ulid;
3131

3232
use super::{
33-
cache::LazyProviderInfos, client_credentials_for_provider, template::environment,
33+
cache::LazyProviderInfos,
34+
client_credentials_for_provider,
35+
template::{environment, AttributeMappingContext},
3436
UpstreamSessionsCookie,
3537
};
3638
use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};
@@ -269,15 +271,13 @@ pub(crate) async fn handler(
269271

270272
let (_header, id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts();
271273

272-
let env = {
273-
let mut env = environment();
274-
env.add_global("user", minijinja::Value::from_serialize(&id_token));
275-
env.add_global(
276-
"extra_callback_parameters",
277-
minijinja::Value::from_serialize(&extra_callback_parameters),
278-
);
279-
env
280-
};
274+
let mut context = AttributeMappingContext::new().with_id_token_claims(id_token);
275+
if let Some(extra_callback_parameters) = extra_callback_parameters.clone() {
276+
context = context.with_extra_callback_parameters(extra_callback_parameters);
277+
}
278+
let context = context.build();
279+
280+
let env = environment();
281281

282282
let template = provider
283283
.claims_imports
@@ -286,7 +286,7 @@ pub(crate) async fn handler(
286286
.as_deref()
287287
.unwrap_or("{{ user.sub }}");
288288
let subject = env
289-
.render_str(template, ())
289+
.render_str(template, context)
290290
.map_err(RouteError::ExtractSubject)?;
291291

292292
if subject.is_empty() {

crates/handlers/src/upstream_oauth2/link.rs

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ use thiserror::Error;
3838
use tracing::warn;
3939
use ulid::Ulid;
4040

41-
use super::{template::environment, UpstreamSessionsCookie};
41+
use super::{
42+
template::{environment, AttributeMappingContext},
43+
UpstreamSessionsCookie,
44+
};
4245
use crate::{
4346
impl_from_error_for_route, views::shared::OptionalPostAuthAction, PreferredLanguage, SiteConfig,
4447
};
@@ -130,9 +133,10 @@ impl IntoResponse for RouteError {
130133
fn render_attribute_template(
131134
environment: &Environment,
132135
template: &str,
136+
context: &minijinja::Value,
133137
required: bool,
134138
) -> Result<Option<String>, RouteError> {
135-
match environment.render_str(template, ()) {
139+
match environment.render_str(template, context) {
136140
Ok(value) if value.is_empty() => {
137141
if required {
138142
return Err(RouteError::RequiredAttributeEmpty {
@@ -320,32 +324,27 @@ pub(crate) async fn get(
320324
(None, None) => {
321325
// Session not linked and used not logged in: suggest creating an
322326
// account or logging in an existing user
323-
let id_token = upstream_session
324-
.id_token()
325-
.map(Jwt::<'_, minijinja::Value>::try_from)
326-
.transpose()?;
327+
let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
327328

328329
let provider = repo
329330
.upstream_oauth_provider()
330331
.lookup(link.provider_id)
331332
.await?
332333
.ok_or(RouteError::ProviderNotFound)?;
333334

334-
let payload = id_token
335-
.map(|id_token| id_token.into_parts().1)
336-
.unwrap_or_default();
337-
338335
let ctx = UpstreamRegister::default();
339336

340-
let env = {
341-
let mut e = environment();
342-
e.add_global("user", payload);
343-
e.add_global(
344-
"extra_callback_parameters",
345-
minijinja::Value::from_serialize(upstream_session.extra_callback_parameters()),
346-
);
347-
e
348-
};
337+
let env = environment();
338+
339+
let mut context = AttributeMappingContext::new();
340+
if let Some(id_token) = id_token {
341+
let (_, payload) = id_token.into_parts();
342+
context = context.with_id_token_claims(payload);
343+
}
344+
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
345+
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
346+
}
347+
let context = context.build();
349348

350349
let ctx = if provider.claims_imports.displayname.ignore() {
351350
ctx
@@ -360,6 +359,7 @@ pub(crate) async fn get(
360359
match render_attribute_template(
361360
&env,
362361
template,
362+
&context,
363363
provider.claims_imports.displayname.is_required(),
364364
)? {
365365
Some(value) => ctx
@@ -381,6 +381,7 @@ pub(crate) async fn get(
381381
match render_attribute_template(
382382
&env,
383383
template,
384+
&context,
384385
provider.claims_imports.email.is_required(),
385386
)? {
386387
Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()),
@@ -401,6 +402,7 @@ pub(crate) async fn get(
401402
match render_attribute_template(
402403
&env,
403404
template,
405+
&context,
404406
provider.claims_imports.localpart.is_required(),
405407
)? {
406408
Some(localpart) => {
@@ -561,37 +563,31 @@ pub(crate) async fn post(
561563
let import_display_name = import_display_name.is_some();
562564
let accept_terms = accept_terms.is_some();
563565

564-
let id_token = upstream_session
565-
.id_token()
566-
.map(Jwt::<'_, minijinja::Value>::try_from)
567-
.transpose()?;
566+
let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
568567

569568
let provider = repo
570569
.upstream_oauth_provider()
571570
.lookup(link.provider_id)
572571
.await?
573572
.ok_or(RouteError::ProviderNotFound)?;
574573

575-
let payload = id_token
576-
.map(|id_token| id_token.into_parts().1)
577-
.unwrap_or_default();
574+
// Let's try to import the claims from the ID token
575+
let env = environment();
578576

579-
// Is the email verified according to the upstream provider?
580-
let provider_email_verified = payload
581-
.get_item(&minijinja::Value::from("email_verified"))
582-
.map(|v| v.is_true())
583-
.unwrap_or(false);
577+
let mut context = AttributeMappingContext::new();
578+
if let Some(id_token) = id_token {
579+
let (_, payload) = id_token.into_parts();
580+
context = context.with_id_token_claims(payload);
581+
}
582+
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
583+
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
584+
}
585+
let context = context.build();
584586

585-
// Let's try to import the claims from the ID token
586-
let env = {
587-
let mut e = environment();
588-
e.add_global("user", payload);
589-
e.add_global(
590-
"extra_callback_parameters",
591-
minijinja::Value::from_serialize(upstream_session.extra_callback_parameters()),
592-
);
593-
e
594-
};
587+
// Is the email verified according to the upstream provider?
588+
let provider_email_verified = env
589+
.render_str("{{ user.email_verified | string }}", &context)
590+
.map_or(false, |v| v == "true");
595591

596592
// Create a template context in case we need to re-render because of an error
597593
let ctx = UpstreamRegister::default();
@@ -611,6 +607,7 @@ pub(crate) async fn post(
611607
render_attribute_template(
612608
&env,
613609
template,
610+
&context,
614611
provider.claims_imports.displayname.is_required(),
615612
)?
616613
} else {
@@ -637,6 +634,7 @@ pub(crate) async fn post(
637634
render_attribute_template(
638635
&env,
639636
template,
637+
&context,
640638
provider.claims_imports.email.is_required(),
641639
)?
642640
} else {
@@ -660,6 +658,7 @@ pub(crate) async fn post(
660658
render_attribute_template(
661659
&env,
662660
template,
661+
&context,
663662
provider.claims_imports.email.is_required(),
664663
)?
665664
} else {

crates/handlers/src/upstream_oauth2/template.rs

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,76 @@
77
use std::{collections::HashMap, sync::Arc};
88

99
use base64ct::{Base64, Base64Unpadded, Base64Url, Base64UrlUnpadded, Encoding};
10-
use minijinja::{Environment, Error, ErrorKind, Value};
10+
use minijinja::{
11+
value::{Enumerator, Object},
12+
Environment, Error, ErrorKind, Value,
13+
};
14+
15+
/// Context passed to the attribute mapping template
16+
///
17+
/// The variables available in the template are:
18+
/// - `user`: claims for the user, currently from the ID token. Later, we'll
19+
/// also allow importing from the userinfo endpoint
20+
/// - `id_token_claims`: claims from the ID token
21+
/// - `extra_callback_parameters`: extra parameters passed to the callback
22+
#[derive(Debug, Default)]
23+
pub(crate) struct AttributeMappingContext {
24+
id_token_claims: Option<HashMap<String, serde_json::Value>>,
25+
extra_callback_parameters: Option<serde_json::Value>,
26+
}
27+
28+
impl AttributeMappingContext {
29+
pub fn new() -> Self {
30+
Self::default()
31+
}
32+
33+
pub fn with_id_token_claims(
34+
mut self,
35+
id_token_claims: HashMap<String, serde_json::Value>,
36+
) -> Self {
37+
self.id_token_claims = Some(id_token_claims);
38+
self
39+
}
40+
41+
pub fn with_extra_callback_parameters(
42+
mut self,
43+
extra_callback_parameters: serde_json::Value,
44+
) -> Self {
45+
self.extra_callback_parameters = Some(extra_callback_parameters);
46+
self
47+
}
48+
49+
pub fn build(self) -> Value {
50+
Value::from_object(self)
51+
}
52+
}
53+
54+
impl Object for AttributeMappingContext {
55+
fn get_value(self: &Arc<Self>, name: &Value) -> Option<Value> {
56+
match name.as_str()? {
57+
"user" | "id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
58+
"extra_callback_parameters" => self
59+
.extra_callback_parameters
60+
.as_ref()
61+
.map(Value::from_serialize),
62+
_ => None,
63+
}
64+
}
65+
66+
fn enumerate(self: &Arc<Self>) -> Enumerator {
67+
match (
68+
self.id_token_claims.is_some(),
69+
self.extra_callback_parameters.is_some(),
70+
) {
71+
(true, true) => {
72+
Enumerator::Str(&["user", "id_token_claims", "extra_callback_parameters"])
73+
}
74+
(true, false) => Enumerator::Str(&["user", "id_token_claims"]),
75+
(false, true) => Enumerator::Str(&["extra_callback_parameters"]),
76+
(false, false) => Enumerator::Str(&["user"]),
77+
}
78+
}
79+
}
1180

1281
fn b64decode(value: &str) -> Result<Value, Error> {
1382
// We're not too concerned about the performance of this filter, so we'll just

0 commit comments

Comments
 (0)