66
77use axum:: {
88 extract:: { Path , Query , State } ,
9- response:: IntoResponse ,
9+ response:: { IntoResponse , Response } ,
1010 Form ,
1111} ;
12+ use axum_extra:: response:: Html ;
1213use hyper:: StatusCode ;
1314use mas_axum_utils:: { cookies:: CookieJar , sentry:: SentryEventID } ;
1415use mas_data_model:: { UpstreamOAuthProvider , UpstreamOAuthProviderResponseMode } ;
@@ -24,8 +25,9 @@ use mas_storage::{
2425 } ,
2526 BoxClock , BoxRepository , BoxRng , Clock ,
2627} ;
28+ use mas_templates:: { FormPostContext , Templates } ;
2729use oauth2_types:: errors:: ClientErrorCode ;
28- use serde:: Deserialize ;
30+ use serde:: { Deserialize , Serialize } ;
2931use thiserror:: Error ;
3032use ulid:: Ulid ;
3133
@@ -35,17 +37,17 @@ use super::{
3537 template:: { environment, AttributeMappingContext } ,
3638 UpstreamSessionsCookie ,
3739} ;
38- use crate :: { impl_from_error_for_route, upstream_oauth2:: cache:: MetadataCache } ;
40+ use crate :: { impl_from_error_for_route, upstream_oauth2:: cache:: MetadataCache , PreferredLanguage } ;
3941
40- #[ derive( Deserialize ) ]
42+ #[ derive( Serialize , Deserialize ) ]
4143pub struct Params {
4244 state : String ,
4345
4446 #[ serde( flatten) ]
4547 code_or_error : CodeOrError ,
4648}
4749
48- #[ derive( Deserialize ) ]
50+ #[ derive( Serialize , Deserialize ) ]
4951#[ serde( untagged) ]
5052enum CodeOrError {
5153 Code {
@@ -115,6 +117,7 @@ pub(crate) enum RouteError {
115117 Internal ( Box < dyn std:: error:: Error > ) ,
116118}
117119
120+ impl_from_error_for_route ! ( mas_templates:: TemplateError ) ;
118121impl_from_error_for_route ! ( mas_storage:: RepositoryError ) ;
119122impl_from_error_for_route ! ( mas_oidc_client:: error:: DiscoveryError ) ;
120123impl_from_error_for_route ! ( mas_oidc_client:: error:: JwksError ) ;
@@ -152,23 +155,38 @@ pub(crate) async fn handler(
152155 State ( encrypter) : State < Encrypter > ,
153156 State ( keystore) : State < Keystore > ,
154157 State ( client) : State < reqwest:: Client > ,
158+ State ( templates) : State < Templates > ,
159+ PreferredLanguage ( locale) : PreferredLanguage ,
155160 cookie_jar : CookieJar ,
156161 Path ( provider_id) : Path < Ulid > ,
157162 query_params : Option < Query < Params > > ,
158163 form_params : Option < Form < Params > > ,
159- ) -> Result < impl IntoResponse , RouteError > {
164+ ) -> Result < Response , RouteError > {
160165 let provider = repo
161166 . upstream_oauth_provider ( )
162167 . lookup ( provider_id)
163168 . await ?
164169 . filter ( UpstreamOAuthProvider :: enabled)
165170 . ok_or ( RouteError :: ProviderNotFound ) ?;
166171
172+ let sessions_cookie = UpstreamSessionsCookie :: load ( & cookie_jar) ;
173+
167174 // Read the parameters from the query or the form, depending on what
168175 // response_mode the provider uses
169176 let params = match ( provider. response_mode , query_params, form_params) {
170- ( UpstreamOAuthProviderResponseMode :: Query , Some ( query_params) , None ) => query_params. 0 ,
171- ( UpstreamOAuthProviderResponseMode :: FormPost , None , Some ( form_params) ) => form_params. 0 ,
177+ ( UpstreamOAuthProviderResponseMode :: Query , Some ( Query ( query_params) ) , None ) => query_params,
178+ ( UpstreamOAuthProviderResponseMode :: FormPost , None , Some ( Form ( form_params) ) ) => {
179+ // We got there from a cross-site form POST, so we need to render a form with
180+ // the same values, which posts back to the same URL
181+ if sessions_cookie. is_empty ( ) {
182+ let context =
183+ FormPostContext :: new_for_current_url ( form_params) . with_language ( & locale) ;
184+ let html = templates. render_form_post ( & context) ?;
185+ return Ok ( Html ( html) . into_response ( ) ) ;
186+ }
187+
188+ form_params
189+ }
172190 ( UpstreamOAuthProviderResponseMode :: Query , None , None ) => {
173191 return Err ( RouteError :: MissingQueryParams )
174192 }
@@ -179,7 +197,6 @@ pub(crate) async fn handler(
179197 ( expected, _, _) => return Err ( RouteError :: InvalidParamsMode { expected } ) ,
180198 } ;
181199
182- let sessions_cookie = UpstreamSessionsCookie :: load ( & cookie_jar) ;
183200 let ( session_id, _post_auth_action) = sessions_cookie
184201 . find_session ( provider_id, & params. state )
185202 . map_err ( |_| RouteError :: MissingCookie ) ?;
@@ -327,5 +344,6 @@ pub(crate) async fn handler(
327344 Ok ( (
328345 cookie_jar,
329346 url_builder. redirect ( & mas_router:: UpstreamOAuth2Link :: new ( link. id ) ) ,
330- ) )
347+ )
348+ . into_response ( ) )
331349}
0 commit comments