@@ -170,9 +170,6 @@ pub enum RouteError {
170170 #[ error( "user not found" ) ]
171171 UserNotFound ,
172172
173- #[ error( "session not found" ) ]
174- SessionNotFound ,
175-
176173 #[ error( "user has no password" ) ]
177174 NoPassword ,
178175
@@ -201,13 +198,11 @@ impl IntoResponse for RouteError {
201198 fn into_response ( self ) -> axum:: response:: Response {
202199 let event_id = sentry:: capture_error ( & self ) ;
203200 let response = match self {
204- Self :: Internal ( _) | Self :: SessionNotFound | Self :: ProvisionDeviceFailed ( _) => {
205- MatrixError {
206- errcode : "M_UNKNOWN" ,
207- error : "Internal server error" ,
208- status : StatusCode :: INTERNAL_SERVER_ERROR ,
209- }
210- }
201+ Self :: Internal ( _) | Self :: ProvisionDeviceFailed ( _) => MatrixError {
202+ errcode : "M_UNKNOWN" ,
203+ error : "Internal server error" ,
204+ status : StatusCode :: INTERNAL_SERVER_ERROR ,
205+ } ,
211206 Self :: RateLimited ( _) => MatrixError {
212207 errcode : "M_LIMIT_EXCEEDED" ,
213208 error : "Too many login attempts" ,
@@ -323,7 +318,17 @@ pub(crate) async fn post(
323318 . await ?
324319 }
325320
326- ( _, Credentials :: Token { token } ) => token_login ( & mut repo, & clock, & token) . await ?,
321+ ( _, Credentials :: Token { token } ) => {
322+ token_login (
323+ & mut repo,
324+ & clock,
325+ & token,
326+ input. device_id ,
327+ & homeserver,
328+ & mut rng,
329+ )
330+ . await ?
331+ }
327332
328333 _ => {
329334 return Err ( RouteError :: Unsupported ) ;
@@ -382,6 +387,9 @@ async fn token_login(
382387 repo : & mut BoxRepository ,
383388 clock : & dyn Clock ,
384389 token : & str ,
390+ requested_device_id : Option < String > ,
391+ homeserver : & dyn HomeserverConnection ,
392+ rng : & mut ( dyn RngCore + Send ) ,
385393) -> Result < ( CompatSession , User ) , RouteError > {
386394 let login = repo
387395 . compat_sso_login ( )
@@ -390,7 +398,7 @@ async fn token_login(
390398 . ok_or ( RouteError :: InvalidLoginToken ) ?;
391399
392400 let now = clock. now ( ) ;
393- let session_id = match login. state {
401+ let browser_session_id = match login. state {
394402 CompatSsoLoginState :: Pending => {
395403 tracing:: error!(
396404 compat_sso_login. id = %login. id,
@@ -400,25 +408,25 @@ async fn token_login(
400408 }
401409 CompatSsoLoginState :: Fulfilled {
402410 fulfilled_at,
403- session_id ,
411+ browser_session_id ,
404412 ..
405413 } => {
406414 if now > fulfilled_at + Duration :: microseconds ( 30 * 1000 * 1000 ) {
407415 return Err ( RouteError :: LoginTookTooLong ) ;
408416 }
409417
410- session_id
418+ browser_session_id
411419 }
412420 CompatSsoLoginState :: Exchanged {
413421 exchanged_at,
414- session_id ,
422+ compat_session_id ,
415423 ..
416424 } => {
417425 if now > exchanged_at + Duration :: microseconds ( 30 * 1000 * 1000 ) {
418426 // TODO: log that session out
419427 tracing:: error!(
420428 compat_sso_login. id = %login. id,
421- compat_session. id = %session_id ,
429+ compat_session. id = %compat_session_id ,
422430 "Login token exchanged a second time more than 30s after"
423431 ) ;
424432 }
@@ -427,22 +435,56 @@ async fn token_login(
427435 }
428436 } ;
429437
430- let session = repo
431- . compat_session ( )
432- . lookup ( session_id)
433- . await ?
434- . ok_or ( RouteError :: SessionNotFound ) ?;
438+ let Some ( browser_session) = repo. browser_session ( ) . lookup ( browser_session_id) . await ? else {
439+ tracing:: error!(
440+ compat_sso_login. id = %login. id,
441+ browser_session. id = %browser_session_id,
442+ "Attempt to exchange login token but no associated browser session found"
443+ ) ;
444+ return Err ( RouteError :: InvalidLoginToken ) ;
445+ } ;
446+ if !browser_session. active ( ) || !browser_session. user . is_valid ( ) {
447+ tracing:: info!(
448+ compat_sso_login. id = %login. id,
449+ browser_session. id = %browser_session_id,
450+ "Attempt to exchange login token but browser session is not active"
451+ ) ;
452+ return Err ( RouteError :: InvalidLoginToken ) ;
453+ }
435454
436- let user = repo
437- . user ( )
438- . lookup ( session. user_id )
439- . await ?
440- . filter ( mas_data_model:: User :: is_valid)
441- . ok_or ( RouteError :: UserNotFound ) ?;
455+ // Lock the user sync to make sure we don't get into a race condition
456+ repo. user ( )
457+ . acquire_lock_for_sync ( & browser_session. user )
458+ . await ?;
442459
443- repo. compat_sso_login ( ) . exchange ( clock, login) . await ?;
460+ let device = if let Some ( requested_device_id) = requested_device_id {
461+ Device :: from ( requested_device_id)
462+ } else {
463+ Device :: generate ( rng)
464+ } ;
465+ let mxid = homeserver. mxid ( & browser_session. user . username ) ;
466+ homeserver
467+ . create_device ( & mxid, device. as_str ( ) )
468+ . await
469+ . map_err ( RouteError :: ProvisionDeviceFailed ) ?;
444470
445- Ok ( ( session, user) )
471+ let compat_session = repo
472+ . compat_session ( )
473+ . add (
474+ rng,
475+ clock,
476+ & browser_session. user ,
477+ device,
478+ Some ( & browser_session) ,
479+ false ,
480+ )
481+ . await ?;
482+
483+ repo. compat_sso_login ( )
484+ . exchange ( clock, login, & compat_session)
485+ . await ?;
486+
487+ Ok ( ( compat_session, browser_session. user ) )
446488}
447489
448490async fn user_password_login (
@@ -1015,7 +1057,7 @@ mod tests {
10151057 }
10161058 "### ) ;
10171059
1018- let ( device , token) = get_login_token ( & state, & user) . await ;
1060+ let token = get_login_token ( & state, & user) . await ;
10191061
10201062 // Try to login with the token.
10211063 let request = Request :: post ( "/_matrix/client/v3/login" ) . json ( serde_json:: json!( {
@@ -1026,14 +1068,13 @@ mod tests {
10261068 response. assert_status ( StatusCode :: OK ) ;
10271069
10281070 let body: serde_json:: Value = response. json ( ) ;
1029- insta:: assert_json_snapshot!( body, @r### "
1071+ insta:: assert_json_snapshot!( body, @r#"
10301072 {
1031- "access_token": "mct_uihy4bk51gxgUbUTa4XIh92RARTPTj_xADEE4 ",
1032- "device_id": "Yp7FM44zJN ",
1073+ "access_token": "mct_bnkWh1tPmm1MZOpygPaXwygX8PfxEY_hE6do1 ",
1074+ "device_id": "O3Ju1MUh3Z ",
10331075 "user_id": "@alice:example.com"
10341076 }
1035- "### ) ;
1036- assert_eq ! ( body[ "device_id" ] , device. to_string( ) ) ;
1077+ "# ) ;
10371078
10381079 // Try again with the same token, it should fail.
10391080 let request = Request :: post ( "/_matrix/client/v3/login" ) . json ( serde_json:: json!( {
@@ -1051,7 +1092,7 @@ mod tests {
10511092 "### ) ;
10521093
10531094 // Try to login, but wait too long before sending the request.
1054- let ( _device , token) = get_login_token ( & state, & user) . await ;
1095+ let token = get_login_token ( & state, & user) . await ;
10551096
10561097 // Advance the clock to make the token expire.
10571098 state
@@ -1079,14 +1120,13 @@ mod tests {
10791120 /// # Panics
10801121 ///
10811122 /// Panics if the repository fails.
1082- async fn get_login_token ( state : & TestState , user : & User ) -> ( Device , String ) {
1123+ async fn get_login_token ( state : & TestState , user : & User ) -> String {
10831124 // XXX: This is a bit manual, but this is what basically the SSO login flow
10841125 // does.
10851126 let mut repo = state. repository ( ) . await . unwrap ( ) ;
10861127
1087- // Generate a device and a token randomly
1128+ // Generate a token randomly
10881129 let token = Alphanumeric . sample_string ( & mut state. rng ( ) , 32 ) ;
1089- let device = Device :: generate ( & mut state. rng ( ) ) ;
10901130
10911131 // Start a compat SSO login flow
10921132 let login = repo
@@ -1100,27 +1140,20 @@ mod tests {
11001140 . await
11011141 . unwrap ( ) ;
11021142
1103- // Complete the flow by fulfilling it with a session
1104- let compat_session = repo
1105- . compat_session ( )
1106- . add (
1107- & mut state. rng ( ) ,
1108- & state. clock ,
1109- user,
1110- device. clone ( ) ,
1111- None ,
1112- false ,
1113- )
1143+ // Advance the flow by fulfilling it with a browser session
1144+ let browser_session = repo
1145+ . browser_session ( )
1146+ . add ( & mut state. rng ( ) , & state. clock , user, None )
11141147 . await
11151148 . unwrap ( ) ;
1116-
1117- repo . compat_sso_login ( )
1118- . fulfill ( & state. clock , login, & compat_session )
1149+ let _login = repo
1150+ . compat_sso_login ( )
1151+ . fulfill ( & state. clock , login, & browser_session )
11191152 . await
11201153 . unwrap ( ) ;
11211154
11221155 repo. save ( ) . await . unwrap ( ) ;
11231156
1124- ( device , token)
1157+ token
11251158 }
11261159}
0 commit comments