1616import logging
1717import re
1818import time
19- from typing import TYPE_CHECKING , Any , Awaitable , Callable , Optional , Tuple , cast
19+ from http import HTTPStatus
20+ from typing import TYPE_CHECKING , Any , Awaitable , Callable , Dict , Optional , Tuple , cast
2021
2122from synapse .api .errors import Codes , FederationDeniedError , SynapseError
2223from synapse .api .urls import FEDERATION_V1_PREFIX
@@ -86,15 +87,24 @@ async def authenticate_request(
8687
8788 if not auth_headers :
8889 raise NoAuthenticationError (
89- 401 , "Missing Authorization headers" , Codes .UNAUTHORIZED
90+ HTTPStatus .UNAUTHORIZED ,
91+ "Missing Authorization headers" ,
92+ Codes .UNAUTHORIZED ,
9093 )
9194
9295 for auth in auth_headers :
9396 if auth .startswith (b"X-Matrix" ):
94- (origin , key , sig ) = _parse_auth_header (auth )
97+ (origin , key , sig , destination ) = _parse_auth_header (auth )
9598 json_request ["origin" ] = origin
9699 json_request ["signatures" ].setdefault (origin , {})[key ] = sig
97100
101+ # if the origin_server sent a destination along it needs to match our own server_name
102+ if destination is not None and destination != self .server_name :
103+ raise AuthenticationError (
104+ HTTPStatus .UNAUTHORIZED ,
105+ "Destination mismatch in auth header" ,
106+ Codes .UNAUTHORIZED ,
107+ )
98108 if (
99109 self .federation_domain_whitelist is not None
100110 and origin not in self .federation_domain_whitelist
@@ -103,7 +113,9 @@ async def authenticate_request(
103113
104114 if origin is None or not json_request ["signatures" ]:
105115 raise NoAuthenticationError (
106- 401 , "Missing Authorization headers" , Codes .UNAUTHORIZED
116+ HTTPStatus .UNAUTHORIZED ,
117+ "Missing Authorization headers" ,
118+ Codes .UNAUTHORIZED ,
107119 )
108120
109121 await self .keyring .verify_json_for_server (
@@ -142,13 +154,14 @@ async def reset_retry_timings(self, origin: str) -> None:
142154 logger .exception ("Error resetting retry timings on %s" , origin )
143155
144156
145- def _parse_auth_header (header_bytes : bytes ) -> Tuple [str , str , str ]:
157+ def _parse_auth_header (header_bytes : bytes ) -> Tuple [str , str , str , Optional [ str ] ]:
146158 """Parse an X-Matrix auth header
147159
148160 Args:
149161 header_bytes: header value
150162
151163 Returns:
164+ origin, key id, signature, destination.
152165 origin, key id, signature.
153166
154167 Raises:
@@ -157,7 +170,9 @@ def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]:
157170 try :
158171 header_str = header_bytes .decode ("utf-8" )
159172 params = header_str .split (" " )[1 ].split ("," )
160- param_dict = {k : v for k , v in (kv .split ("=" , maxsplit = 1 ) for kv in params )}
173+ param_dict : Dict [str , str ] = {
174+ k : v for k , v in [param .split ("=" , maxsplit = 1 ) for param in params ]
175+ }
161176
162177 def strip_quotes (value : str ) -> str :
163178 if value .startswith ('"' ):
@@ -172,15 +187,23 @@ def strip_quotes(value: str) -> str:
172187
173188 key = strip_quotes (param_dict ["key" ])
174189 sig = strip_quotes (param_dict ["sig" ])
175- return origin , key , sig
190+
191+ # get the destination server_name from the auth header if it exists
192+ destination = param_dict .get ("destination" )
193+ if destination is not None :
194+ destination = strip_quotes (destination )
195+ else :
196+ destination = None
197+
198+ return origin , key , sig , destination
176199 except Exception as e :
177200 logger .warning (
178201 "Error parsing auth header '%s': %s" ,
179202 header_bytes .decode ("ascii" , "replace" ),
180203 e ,
181204 )
182205 raise AuthenticationError (
183- 400 , "Malformed Authorization header" , Codes .UNAUTHORIZED
206+ HTTPStatus . BAD_REQUEST , "Malformed Authorization header" , Codes .UNAUTHORIZED
184207 )
185208
186209
0 commit comments