Skip to content

Commit 04e6960

Browse files
committed
Avoid using SameSite=None by re-submitting incoming form data
1 parent ac70632 commit 04e6960

File tree

10 files changed

+113
-47
lines changed

10 files changed

+113
-47
lines changed

crates/axum-utils/src/cookies.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,7 @@ impl CookieOption {
101101
cookie.set_http_only(true);
102102
cookie.set_secure(self.secure());
103103
cookie.set_path(self.path().to_owned());
104-
105-
// The `form_post` callback requires that, as it means a 3rd party origin will
106-
// POST to MAS. This is presumably fine, as our forms are protected with a CSRF
107-
// token
108-
cookie.set_same_site(if self.secure() {
109-
SameSite::None
110-
} else {
111-
SameSite::Lax
112-
});
104+
cookie.set_same_site(SameSite::Lax);
113105
cookie
114106
}
115107
}

crates/handlers/src/oauth2/authorization/callback.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::collections::HashMap;
1010

1111
use axum::response::{Html, IntoResponse, Redirect, Response};
1212
use mas_data_model::AuthorizationGrant;
13+
use mas_i18n::DataLocale;
1314
use mas_templates::{FormPostContext, Templates};
1415
use oauth2_types::requests::ResponseMode;
1516
use serde::Serialize;
@@ -103,6 +104,7 @@ impl CallbackDestination {
103104
pub async fn go<T: Serialize + Send + Sync>(
104105
self,
105106
templates: &Templates,
107+
locale: &DataLocale,
106108
params: T,
107109
) -> Result<Response, CallbackDestinationError> {
108110
#[derive(Serialize)]
@@ -155,7 +157,7 @@ impl CallbackDestination {
155157
state,
156158
params,
157159
};
158-
let ctx = FormPostContext::new(redirect_uri, merged);
160+
let ctx = FormPostContext::new_for_url(redirect_uri, merged).with_language(locale);
159161
let rendered = templates.render_form_post(&ctx)?;
160162
Ok(Html(rendered).into_response())
161163
}

crates/handlers/src/oauth2/authorization/complete.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ pub(crate) async fn get(
141141
.await
142142
{
143143
Ok(params) => {
144-
let res = callback_destination.go(&templates, params).await?;
144+
let res = callback_destination.go(&templates, &locale, params).await?;
145145
Ok((cookie_jar, res).into_response())
146146
}
147147
Err(GrantCompletionError::RequiresReauth) => Ok((

crates/handlers/src/oauth2/authorization/mod.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ pub(crate) async fn get(
170170
let res: Result<Response, RouteError> = ({
171171
let templates = templates.clone();
172172
let callback_destination = callback_destination.clone();
173+
let locale = locale.clone();
173174
async move {
174175
let maybe_session = session_info.load_session(&mut repo).await?;
175176
let prompt = params.auth.prompt.as_deref().unwrap_or_default();
@@ -180,6 +181,7 @@ pub(crate) async fn get(
180181
return Ok(callback_destination
181182
.go(
182183
&templates,
184+
&locale,
183185
ClientError::from(ClientErrorCode::RequestNotSupported),
184186
)
185187
.await?);
@@ -189,6 +191,7 @@ pub(crate) async fn get(
189191
return Ok(callback_destination
190192
.go(
191193
&templates,
194+
&locale,
192195
ClientError::from(ClientErrorCode::RequestUriNotSupported),
193196
)
194197
.await?);
@@ -200,6 +203,7 @@ pub(crate) async fn get(
200203
return Ok(callback_destination
201204
.go(
202205
&templates,
206+
&locale,
203207
ClientError::from(ClientErrorCode::UnsupportedResponseType),
204208
)
205209
.await?);
@@ -211,6 +215,7 @@ pub(crate) async fn get(
211215
return Ok(callback_destination
212216
.go(
213217
&templates,
218+
&locale,
214219
ClientError::from(ClientErrorCode::UnauthorizedClient),
215220
)
216221
.await?);
@@ -220,6 +225,7 @@ pub(crate) async fn get(
220225
return Ok(callback_destination
221226
.go(
222227
&templates,
228+
&locale,
223229
ClientError::from(ClientErrorCode::RegistrationNotSupported),
224230
)
225231
.await?);
@@ -230,6 +236,7 @@ pub(crate) async fn get(
230236
return Ok(callback_destination
231237
.go(
232238
&templates,
239+
&locale,
233240
ClientError::from(ClientErrorCode::LoginRequired),
234241
)
235242
.await?);
@@ -241,6 +248,7 @@ pub(crate) async fn get(
241248
return Ok(callback_destination
242249
.go(
243250
&templates,
251+
&locale,
244252
ClientError::from(ClientErrorCode::UnauthorizedClient),
245253
)
246254
.await?);
@@ -266,6 +274,7 @@ pub(crate) async fn get(
266274
return Ok(callback_destination
267275
.go(
268276
&templates,
277+
&locale,
269278
ClientError::from(ClientErrorCode::InvalidRequest),
270279
)
271280
.await?);
@@ -350,11 +359,12 @@ pub(crate) async fn get(
350359
)
351360
.await
352361
{
353-
Ok(params) => callback_destination.go(&templates, params).await?,
362+
Ok(params) => callback_destination.go(&templates, &locale, params).await?,
354363
Err(GrantCompletionError::RequiresConsent) => {
355364
callback_destination
356365
.go(
357366
&templates,
367+
&locale,
358368
ClientError::from(ClientErrorCode::ConsentRequired),
359369
)
360370
.await?
@@ -363,13 +373,14 @@ pub(crate) async fn get(
363373
callback_destination
364374
.go(
365375
&templates,
376+
&locale,
366377
ClientError::from(ClientErrorCode::InteractionRequired),
367378
)
368379
.await?
369380
}
370381
Err(GrantCompletionError::PolicyViolation(_grant, _res)) => {
371382
callback_destination
372-
.go(&templates, ClientError::from(ClientErrorCode::AccessDenied))
383+
.go(&templates, &locale, ClientError::from(ClientErrorCode::AccessDenied))
373384
.await?
374385
}
375386
Err(GrantCompletionError::Internal(e)) => {
@@ -400,7 +411,7 @@ pub(crate) async fn get(
400411
)
401412
.await
402413
{
403-
Ok(params) => callback_destination.go(&templates, params).await?,
414+
Ok(params) => callback_destination.go(&templates, &locale, params).await?,
404415
Err(GrantCompletionError::RequiresConsent) => {
405416
url_builder.redirect(&mas_router::Consent(grant_id)).into_response()
406417
}
@@ -440,7 +451,11 @@ pub(crate) async fn get(
440451
Err(err) => {
441452
tracing::error!(%err);
442453
callback_destination
443-
.go(&templates, ClientError::from(ClientErrorCode::ServerError))
454+
.go(
455+
&templates,
456+
&locale,
457+
ClientError::from(ClientErrorCode::ServerError),
458+
)
444459
.await?
445460
}
446461
};

crates/handlers/src/upstream_oauth2/callback.rs

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
use axum::{
88
extract::{Path, Query, State},
9-
response::IntoResponse,
9+
response::{IntoResponse, Response},
1010
Form,
1111
};
12+
use axum_extra::response::Html;
1213
use hyper::StatusCode;
1314
use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID};
1415
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode};
@@ -24,8 +25,9 @@ use mas_storage::{
2425
},
2526
BoxClock, BoxRepository, BoxRng, Clock,
2627
};
28+
use mas_templates::{FormPostContext, Templates};
2729
use oauth2_types::errors::ClientErrorCode;
28-
use serde::Deserialize;
30+
use serde::{Deserialize, Serialize};
2931
use thiserror::Error;
3032
use ulid::Ulid;
3133

@@ -35,17 +37,17 @@ use super::{
3537
template::{environment, AttributeMappingContext},
3638
UpstreamSessionsCookie,
3739
};
38-
use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};
40+
use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache, PreferredLanguage};
3941

40-
#[derive(Deserialize)]
42+
#[derive(Serialize, Deserialize)]
4143
pub struct Params {
4244
state: String,
4345

4446
#[serde(flatten)]
4547
code_or_error: CodeOrError,
4648
}
4749

48-
#[derive(Deserialize)]
50+
#[derive(Serialize, Deserialize)]
4951
#[serde(untagged)]
5052
enum CodeOrError {
5153
Code {
@@ -115,6 +117,7 @@ pub(crate) enum RouteError {
115117
Internal(Box<dyn std::error::Error>),
116118
}
117119

120+
impl_from_error_for_route!(mas_templates::TemplateError);
118121
impl_from_error_for_route!(mas_storage::RepositoryError);
119122
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
120123
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
@@ -152,23 +155,38 @@ pub(crate) async fn handler(
152155
State(encrypter): State<Encrypter>,
153156
State(keystore): State<Keystore>,
154157
State(client): State<reqwest::Client>,
158+
State(templates): State<Templates>,
159+
PreferredLanguage(locale): PreferredLanguage,
155160
cookie_jar: CookieJar,
156161
Path(provider_id): Path<Ulid>,
157162
query_params: Option<Query<Params>>,
158163
form_params: Option<Form<Params>>,
159-
) -> Result<impl IntoResponse, RouteError> {
164+
) -> Result<Response, RouteError> {
160165
let provider = repo
161166
.upstream_oauth_provider()
162167
.lookup(provider_id)
163168
.await?
164169
.filter(UpstreamOAuthProvider::enabled)
165170
.ok_or(RouteError::ProviderNotFound)?;
166171

172+
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
173+
167174
// Read the parameters from the query or the form, depending on what
168175
// response_mode the provider uses
169176
let params = match (provider.response_mode, query_params, form_params) {
170-
(UpstreamOAuthProviderResponseMode::Query, Some(query_params), None) => query_params.0,
171-
(UpstreamOAuthProviderResponseMode::FormPost, None, Some(form_params)) => form_params.0,
177+
(UpstreamOAuthProviderResponseMode::Query, Some(Query(query_params)), None) => query_params,
178+
(UpstreamOAuthProviderResponseMode::FormPost, None, Some(Form(form_params))) => {
179+
// We got there from a cross-site form POST, so we need to render a form with
180+
// the same values, which posts back to the same URL
181+
if sessions_cookie.is_empty() {
182+
let context =
183+
FormPostContext::new_for_current_url(form_params).with_language(&locale);
184+
let html = templates.render_form_post(&context)?;
185+
return Ok(Html(html).into_response());
186+
}
187+
188+
form_params
189+
}
172190
(UpstreamOAuthProviderResponseMode::Query, None, None) => {
173191
return Err(RouteError::MissingQueryParams)
174192
}
@@ -179,7 +197,6 @@ pub(crate) async fn handler(
179197
(expected, _, _) => return Err(RouteError::InvalidParamsMode { expected }),
180198
};
181199

182-
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
183200
let (session_id, _post_auth_action) = sessions_cookie
184201
.find_session(provider_id, &params.state)
185202
.map_err(|_| RouteError::MissingCookie)?;
@@ -327,5 +344,6 @@ pub(crate) async fn handler(
327344
Ok((
328345
cookie_jar,
329346
url_builder.redirect(&mas_router::UpstreamOAuth2Link::new(link.id)),
330-
))
347+
)
348+
.into_response())
331349
}

crates/handlers/src/upstream_oauth2/cookie.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ impl UpstreamSessions {
6161
}
6262
}
6363

64+
/// Returns true if the cookie is empty
65+
pub fn is_empty(&self) -> bool {
66+
self.0.is_empty()
67+
}
68+
6469
/// Save the upstreams sessions to the cookie jar
6570
pub fn save<C>(self, cookie_jar: CookieJar, clock: &C) -> CookieJar
6671
where

crates/templates/src/context.rs

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,7 +1460,7 @@ impl TemplateContext for DeviceConsentContext {
14601460
/// Context used by the `form_post.html` template
14611461
#[derive(Serialize)]
14621462
pub struct FormPostContext<T> {
1463-
redirect_uri: Url,
1463+
redirect_uri: Option<Url>,
14641464
params: T,
14651465
}
14661466

@@ -1473,21 +1473,42 @@ impl<T: TemplateContext> TemplateContext for FormPostContext<T> {
14731473
sample_params
14741474
.into_iter()
14751475
.map(|params| FormPostContext {
1476-
redirect_uri: "https://example.com/callback".parse().unwrap(),
1476+
redirect_uri: "https://example.com/callback".parse().ok(),
14771477
params,
14781478
})
14791479
.collect()
14801480
}
14811481
}
14821482

14831483
impl<T> FormPostContext<T> {
1484-
/// Constructs a context for the `form_post` response mode form
1485-
pub fn new(redirect_uri: Url, params: T) -> Self {
1484+
/// Constructs a context for the `form_post` response mode form for a given
1485+
/// URL
1486+
pub fn new_for_url(redirect_uri: Url, params: T) -> Self {
14861487
Self {
1487-
redirect_uri,
1488+
redirect_uri: Some(redirect_uri),
14881489
params,
14891490
}
14901491
}
1492+
1493+
/// Constructs a context for the `form_post` response mode form for the
1494+
/// current URL
1495+
pub fn new_for_current_url(params: T) -> Self {
1496+
Self {
1497+
redirect_uri: None,
1498+
params,
1499+
}
1500+
}
1501+
1502+
/// Add the language to the context
1503+
///
1504+
/// This is usually implemented by the [`TemplateContext`] trait, but it is
1505+
/// annoying to make it work because of the generic parameter
1506+
pub fn with_language(self, lang: &DataLocale) -> WithLanguage<Self> {
1507+
WithLanguage {
1508+
lang: lang.to_string(),
1509+
inner: self,
1510+
}
1511+
}
14911512
}
14921513

14931514
/// Context used by the `error.html` template

crates/templates/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ register_templates! {
368368
pub fn render_reauth(WithLanguage<WithCsrf<WithSession<ReauthContext>>>) { "pages/reauth.html" }
369369

370370
/// Render the form used by the form_post response mode
371-
pub fn render_form_post<T: Serialize>(FormPostContext<T>) { "form_post.html" }
371+
pub fn render_form_post<T: Serialize>(WithLanguage<FormPostContext<T>>) { "form_post.html" }
372372

373373
/// Render the HTML error page
374374
pub fn render_error(ErrorContext) { "pages/error.html" }

0 commit comments

Comments
 (0)