1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import logging
15- from typing import List , Optional , Tuple
15+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple
1616
1717import pymacaroons
1818from netaddr import IPAddress
1919
2020from twisted .web .server import Request
2121
22- import synapse .types
2322from synapse import event_auth
2423from synapse .api .auth_blocking import AuthBlocking
2524from synapse .api .constants import EventTypes , HistoryVisibility , Membership
3635from synapse .http .site import SynapseRequest
3736from synapse .logging import opentracing as opentracing
3837from synapse .storage .databases .main .registration import TokenLookupResult
39- from synapse .types import StateMap , UserID
38+ from synapse .types import Requester , StateMap , UserID , create_requester
4039from synapse .util .caches .lrucache import LruCache
4140from synapse .util .macaroons import get_value_from_macaroon , satisfy_expiry
4241from synapse .util .metrics import Measure
4342
43+ if TYPE_CHECKING :
44+ from synapse .server import HomeServer
45+
4446logger = logging .getLogger (__name__ )
4547
4648
@@ -68,7 +70,7 @@ class Auth:
6870 The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
6971 """
7072
71- def __init__ (self , hs ):
73+ def __init__ (self , hs : "HomeServer" ):
7274 self .hs = hs
7375 self .clock = hs .get_clock ()
7476 self .store = hs .get_datastore ()
@@ -88,13 +90,13 @@ def __init__(self, hs):
8890
8991 async def check_from_context (
9092 self , room_version : str , event , context , do_sig_check = True
91- ):
93+ ) -> None :
9294 prev_state_ids = await context .get_prev_state_ids ()
9395 auth_events_ids = self .compute_auth_events (
9496 event , prev_state_ids , for_verification = True
9597 )
96- auth_events = await self .store .get_events (auth_events_ids )
97- auth_events = {(e .type , e .state_key ): e for e in auth_events .values ()}
98+ auth_events_by_id = await self .store .get_events (auth_events_ids )
99+ auth_events = {(e .type , e .state_key ): e for e in auth_events_by_id .values ()}
98100
99101 room_version_obj = KNOWN_ROOM_VERSIONS [room_version ]
100102 event_auth .check (
@@ -151,17 +153,11 @@ async def check_user_in_room(
151153
152154 raise AuthError (403 , "User %s not in room %s" % (user_id , room_id ))
153155
154- async def check_host_in_room (self , room_id , host ) :
156+ async def check_host_in_room (self , room_id : str , host : str ) -> bool :
155157 with Measure (self .clock , "check_host_in_room" ):
156- latest_event_ids = await self .store .is_host_joined (room_id , host )
157- return latest_event_ids
158-
159- def can_federate (self , event , auth_events ):
160- creation_event = auth_events .get ((EventTypes .Create , "" ))
158+ return await self .store .is_host_joined (room_id , host )
161159
162- return creation_event .content .get ("m.federate" , True ) is True
163-
164- def get_public_keys (self , invite_event ):
160+ def get_public_keys (self , invite_event : EventBase ) -> List [Dict [str , Any ]]:
165161 return event_auth .get_public_keys (invite_event )
166162
167163 async def get_user_by_req (
@@ -170,7 +166,7 @@ async def get_user_by_req(
170166 allow_guest : bool = False ,
171167 rights : str = "access" ,
172168 allow_expired : bool = False ,
173- ) -> synapse . types . Requester :
169+ ) -> Requester :
174170 """Get a registered user's ID.
175171
176172 Args:
@@ -196,7 +192,7 @@ async def get_user_by_req(
196192 access_token = self .get_access_token_from_request (request )
197193
198194 user_id , app_service = await self ._get_appservice_user_id (request )
199- if user_id :
195+ if user_id and app_service :
200196 if ip_addr and self ._track_appservice_user_ips :
201197 await self .store .insert_client_ip (
202198 user_id = user_id ,
@@ -206,9 +202,7 @@ async def get_user_by_req(
206202 device_id = "dummy-device" , # stubbed
207203 )
208204
209- requester = synapse .types .create_requester (
210- user_id , app_service = app_service
211- )
205+ requester = create_requester (user_id , app_service = app_service )
212206
213207 request .requester = user_id
214208 opentracing .set_tag ("authenticated_entity" , user_id )
@@ -251,7 +245,7 @@ async def get_user_by_req(
251245 errcode = Codes .GUEST_ACCESS_FORBIDDEN ,
252246 )
253247
254- requester = synapse . types . create_requester (
248+ requester = create_requester (
255249 user_info .user_id ,
256250 token_id ,
257251 is_guest ,
@@ -271,7 +265,9 @@ async def get_user_by_req(
271265 except KeyError :
272266 raise MissingClientTokenError ()
273267
274- async def _get_appservice_user_id (self , request ):
268+ async def _get_appservice_user_id (
269+ self , request : Request
270+ ) -> Tuple [Optional [str ], Optional [ApplicationService ]]:
275271 app_service = self .store .get_app_service_by_token (
276272 self .get_access_token_from_request (request )
277273 )
@@ -283,6 +279,9 @@ async def _get_appservice_user_id(self, request):
283279 if ip_address not in app_service .ip_range_whitelist :
284280 return None , None
285281
282+ # This will always be set by the time Twisted calls us.
283+ assert request .args is not None
284+
286285 if b"user_id" not in request .args :
287286 return app_service .sender , app_service
288287
@@ -387,7 +386,9 @@ async def get_user_by_access_token(
387386 logger .warning ("Invalid macaroon in auth: %s %s" , type (e ), e )
388387 raise InvalidClientTokenError ("Invalid macaroon passed." )
389388
390- def _parse_and_validate_macaroon (self , token , rights = "access" ):
389+ def _parse_and_validate_macaroon (
390+ self , token : str , rights : str = "access"
391+ ) -> Tuple [str , bool ]:
391392 """Takes a macaroon and tries to parse and validate it. This is cached
392393 if and only if rights == access and there isn't an expiry.
393394
@@ -432,15 +433,16 @@ def _parse_and_validate_macaroon(self, token, rights="access"):
432433
433434 return user_id , guest
434435
435- def validate_macaroon (self , macaroon , type_string , user_id ):
436+ def validate_macaroon (
437+ self , macaroon : pymacaroons .Macaroon , type_string : str , user_id : str
438+ ) -> None :
436439 """
437440 validate that a Macaroon is understood by and was signed by this server.
438441
439442 Args:
440- macaroon(pymacaroons.Macaroon): The macaroon to validate
441- type_string(str): The kind of token required (e.g. "access",
442- "delete_pusher")
443- user_id (str): The user_id required
443+ macaroon: The macaroon to validate
444+ type_string: The kind of token required (e.g. "access", "delete_pusher")
445+ user_id: The user_id required
444446 """
445447 v = pymacaroons .Verifier ()
446448
@@ -465,9 +467,7 @@ def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
465467 if not service :
466468 logger .warning ("Unrecognised appservice access token." )
467469 raise InvalidClientTokenError ()
468- request .requester = synapse .types .create_requester (
469- service .sender , app_service = service
470- )
470+ request .requester = create_requester (service .sender , app_service = service )
471471 return service
472472
473473 async def is_server_admin (self , user : UserID ) -> bool :
@@ -519,7 +519,7 @@ def compute_auth_events(
519519
520520 return auth_ids
521521
522- async def check_can_change_room_list (self , room_id : str , user : UserID ):
522+ async def check_can_change_room_list (self , room_id : str , user : UserID ) -> bool :
523523 """Determine whether the user is allowed to edit the room's entry in the
524524 published room list.
525525
@@ -554,11 +554,11 @@ async def check_can_change_room_list(self, room_id: str, user: UserID):
554554 return user_level >= send_level
555555
556556 @staticmethod
557- def has_access_token (request : Request ):
557+ def has_access_token (request : Request ) -> bool :
558558 """Checks if the request has an access_token.
559559
560560 Returns:
561- bool: False if no access_token was given, True otherwise.
561+ False if no access_token was given, True otherwise.
562562 """
563563 # This will always be set by the time Twisted calls us.
564564 assert request .args is not None
@@ -568,13 +568,13 @@ def has_access_token(request: Request):
568568 return bool (query_params ) or bool (auth_headers )
569569
570570 @staticmethod
571- def get_access_token_from_request (request : Request ):
571+ def get_access_token_from_request (request : Request ) -> str :
572572 """Extracts the access_token from the request.
573573
574574 Args:
575575 request: The http request.
576576 Returns:
577- unicode: The access_token
577+ The access_token
578578 Raises:
579579 MissingClientTokenError: If there isn't a single access_token in the
580580 request
@@ -649,5 +649,5 @@ async def check_user_in_room_or_world_readable(
649649 % (user_id , room_id ),
650650 )
651651
652- def check_auth_blocking (self , * args , ** kwargs ):
653- return self ._auth_blocking .check_auth_blocking (* args , ** kwargs )
652+ async def check_auth_blocking (self , * args , ** kwargs ) -> None :
653+ await self ._auth_blocking .check_auth_blocking (* args , ** kwargs )
0 commit comments