3030 Optional ,
3131 Tuple ,
3232 Union ,
33+ cast ,
3334)
3435
3536import attr
7273from synapse .util .threepids import canonicalise_email
7374
7475if TYPE_CHECKING :
76+ from synapse .rest .client .v1 .login import LoginResponse
7577 from synapse .server import HomeServer
7678
7779logger = logging .getLogger (__name__ )
@@ -777,13 +779,116 @@ def _auth_dict_for_flows(
777779 "params" : params ,
778780 }
779781
782+ async def refresh_token (
783+ self ,
784+ refresh_token : str ,
785+ valid_until_ms : Optional [int ],
786+ ) -> Tuple [str , str ]:
787+ """
788+ Consumes a refresh token and generate both a new access token and a new refresh token from it.
789+
790+ The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
791+
792+ Args:
793+ refresh_token: The token to consume.
794+ valid_until_ms: The expiration timestamp of the new access token.
795+
796+ Returns:
797+ A tuple containing the new access token and refresh token
798+ """
799+
800+ # Verify the token signature first before looking up the token
801+ if not self ._verify_refresh_token (refresh_token ):
802+ raise SynapseError (401 , "invalid refresh token" , Codes .UNKNOWN_TOKEN )
803+
804+ existing_token = await self .store .lookup_refresh_token (refresh_token )
805+ if existing_token is None :
806+ raise SynapseError (401 , "refresh token does not exist" , Codes .UNKNOWN_TOKEN )
807+
808+ if (
809+ existing_token .has_next_access_token_been_used
810+ or existing_token .has_next_refresh_token_been_refreshed
811+ ):
812+ raise SynapseError (
813+ 403 , "refresh token isn't valid anymore" , Codes .FORBIDDEN
814+ )
815+
816+ (
817+ new_refresh_token ,
818+ new_refresh_token_id ,
819+ ) = await self .get_refresh_token_for_user_id (
820+ user_id = existing_token .user_id , device_id = existing_token .device_id
821+ )
822+ access_token = await self .get_access_token_for_user_id (
823+ user_id = existing_token .user_id ,
824+ device_id = existing_token .device_id ,
825+ valid_until_ms = valid_until_ms ,
826+ refresh_token_id = new_refresh_token_id ,
827+ )
828+ await self .store .replace_refresh_token (
829+ existing_token .token_id , new_refresh_token_id
830+ )
831+ return access_token , new_refresh_token
832+
833+ def _verify_refresh_token (self , token : str ) -> bool :
834+ """
835+ Verifies the shape of a refresh token.
836+
837+ Args:
838+ token: The refresh token to verify
839+
840+ Returns:
841+ Whether the token has the right shape
842+ """
843+ parts = token .split ("_" , maxsplit = 4 )
844+ if len (parts ) != 4 :
845+ return False
846+
847+ type , localpart , rand , crc = parts
848+
849+ # Refresh tokens are prefixed by "syr_", let's check that
850+ if type != "syr" :
851+ return False
852+
853+ # Check the CRC
854+ base = f"{ type } _{ localpart } _{ rand } "
855+ expected_crc = base62_encode (crc32 (base .encode ("ascii" )), minwidth = 6 )
856+ if crc != expected_crc :
857+ return False
858+
859+ return True
860+
861+ async def get_refresh_token_for_user_id (
862+ self ,
863+ user_id : str ,
864+ device_id : str ,
865+ ) -> Tuple [str , int ]:
866+ """
867+ Creates a new refresh token for the user with the given user ID.
868+
869+ Args:
870+ user_id: canonical user ID
871+ device_id: the device ID to associate with the token.
872+
873+ Returns:
874+ The newly created refresh token and its ID in the database
875+ """
876+ refresh_token = self .generate_refresh_token (UserID .from_string (user_id ))
877+ refresh_token_id = await self .store .add_refresh_token_to_user (
878+ user_id = user_id ,
879+ token = refresh_token ,
880+ device_id = device_id ,
881+ )
882+ return refresh_token , refresh_token_id
883+
780884 async def get_access_token_for_user_id (
781885 self ,
782886 user_id : str ,
783887 device_id : Optional [str ],
784888 valid_until_ms : Optional [int ],
785889 puppets_user_id : Optional [str ] = None ,
786890 is_appservice_ghost : bool = False ,
891+ refresh_token_id : Optional [int ] = None ,
787892 ) -> str :
788893 """
789894 Creates a new access token for the user with the given user ID.
@@ -801,6 +906,8 @@ async def get_access_token_for_user_id(
801906 valid_until_ms: when the token is valid until. None for
802907 no expiry.
803908 is_appservice_ghost: Whether the user is an application ghost user
909+ refresh_token_id: the refresh token ID that will be associated with
910+ this access token.
804911 Returns:
805912 The access token for the user's session.
806913 Raises:
@@ -836,6 +943,7 @@ async def get_access_token_for_user_id(
836943 device_id = device_id ,
837944 valid_until_ms = valid_until_ms ,
838945 puppets_user_id = puppets_user_id ,
946+ refresh_token_id = refresh_token_id ,
839947 )
840948
841949 # the device *should* have been registered before we got here; however,
@@ -928,7 +1036,7 @@ async def validate_login(
9281036 self ,
9291037 login_submission : Dict [str , Any ],
9301038 ratelimit : bool = False ,
931- ) -> Tuple [str , Optional [Callable [[Dict [ str , str ] ], Awaitable [None ]]]]:
1039+ ) -> Tuple [str , Optional [Callable [["LoginResponse" ], Awaitable [None ]]]]:
9321040 """Authenticates the user for the /login API
9331041
9341042 Also used by the user-interactive auth flow to validate auth types which don't
@@ -1073,7 +1181,7 @@ async def _validate_userid_login(
10731181 self ,
10741182 username : str ,
10751183 login_submission : Dict [str , Any ],
1076- ) -> Tuple [str , Optional [Callable [[Dict [ str , str ] ], Awaitable [None ]]]]:
1184+ ) -> Tuple [str , Optional [Callable [["LoginResponse" ], Awaitable [None ]]]]:
10771185 """Helper for validate_login
10781186
10791187 Handles login, once we've mapped 3pids onto userids
@@ -1151,7 +1259,7 @@ async def _validate_userid_login(
11511259
11521260 async def check_password_provider_3pid (
11531261 self , medium : str , address : str , password : str
1154- ) -> Tuple [Optional [str ], Optional [Callable [[Dict [ str , str ] ], Awaitable [None ]]]]:
1262+ ) -> Tuple [Optional [str ], Optional [Callable [["LoginResponse" ], Awaitable [None ]]]]:
11551263 """Check if a password provider is able to validate a thirdparty login
11561264
11571265 Args:
@@ -1215,6 +1323,19 @@ def generate_access_token(self, for_user: UserID) -> str:
12151323 crc = base62_encode (crc32 (base .encode ("ascii" )), minwidth = 6 )
12161324 return f"{ base } _{ crc } "
12171325
1326+ def generate_refresh_token (self , for_user : UserID ) -> str :
1327+ """Generates an opaque string, for use as a refresh token"""
1328+
1329+ # we use the following format for refresh tokens:
1330+ # syr_<base64 local part>_<random string>_<base62 crc check>
1331+
1332+ b64local = unpaddedbase64 .encode_base64 (for_user .localpart .encode ("utf-8" ))
1333+ random_string = stringutils .random_string (20 )
1334+ base = f"syr_{ b64local } _{ random_string } "
1335+
1336+ crc = base62_encode (crc32 (base .encode ("ascii" )), minwidth = 6 )
1337+ return f"{ base } _{ crc } "
1338+
12181339 async def validate_short_term_login_token (
12191340 self , login_token : str
12201341 ) -> LoginTokenAttributes :
@@ -1563,7 +1684,7 @@ def _complete_sso_login(
15631684 )
15641685 respond_with_html (request , 200 , html )
15651686
1566- async def _sso_login_callback (self , login_result : JsonDict ) -> None :
1687+ async def _sso_login_callback (self , login_result : "LoginResponse" ) -> None :
15671688 """
15681689 A login callback which might add additional attributes to the login response.
15691690
@@ -1577,7 +1698,8 @@ async def _sso_login_callback(self, login_result: JsonDict) -> None:
15771698
15781699 extra_attributes = self ._extra_attributes .get (login_result ["user_id" ])
15791700 if extra_attributes :
1580- login_result .update (extra_attributes .extra_attributes )
1701+ login_result_dict = cast (Dict [str , Any ], login_result )
1702+ login_result_dict .update (extra_attributes .extra_attributes )
15811703
15821704 def _expire_sso_extra_attributes (self ) -> None :
15831705 """
0 commit comments