4444logger = logging .getLogger (__name__ )
4545
4646
47+ class ResolveRoomIdMixin :
48+ def __init__ (self , hs : "HomeServer" ):
49+ self .room_member_handler = hs .get_room_member_handler ()
50+
51+ async def resolve_room_id (
52+ self , room_identifier : str , remote_room_hosts : Optional [List [str ]] = None
53+ ) -> Tuple [str , Optional [List [str ]]]:
54+ """
55+ Resolve a room identifier to a room ID, if necessary.
56+
57+ This also performanes checks to ensure the room ID is of the proper form.
58+
59+ Args:
60+ room_identifier: The room ID or alias.
61+ remote_room_hosts: The potential remote room hosts to use.
62+
63+ Returns:
64+ The resolved room ID.
65+
66+ Raises:
67+ SynapseError if the room ID is of the wrong form.
68+ """
69+ if RoomID .is_valid (room_identifier ):
70+ resolved_room_id = room_identifier
71+ elif RoomAlias .is_valid (room_identifier ):
72+ room_alias = RoomAlias .from_string (room_identifier )
73+ (
74+ room_id ,
75+ remote_room_hosts ,
76+ ) = await self .room_member_handler .lookup_room_alias (room_alias )
77+ resolved_room_id = room_id .to_string ()
78+ else :
79+ raise SynapseError (
80+ 400 , "%s was not legal room ID or room alias" % (room_identifier ,)
81+ )
82+ if not resolved_room_id :
83+ raise SynapseError (
84+ 400 , "Unknown room ID or room alias %s" % room_identifier
85+ )
86+ return resolved_room_id , remote_room_hosts
87+
88+
4789class ShutdownRoomRestServlet (RestServlet ):
4890 """Shuts down a room by removing all local users from the room and blocking
4991 all future invites and joins to the room. Any local aliases will be repointed
@@ -334,14 +376,14 @@ async def on_GET(
334376 return 200 , ret
335377
336378
337- class JoinRoomAliasServlet (RestServlet ):
379+ class JoinRoomAliasServlet (ResolveRoomIdMixin , RestServlet ):
338380
339381 PATTERNS = admin_patterns ("/join/(?P<room_identifier>[^/]*)" )
340382
341383 def __init__ (self , hs : "HomeServer" ):
384+ super ().__init__ (hs )
342385 self .hs = hs
343386 self .auth = hs .get_auth ()
344- self .room_member_handler = hs .get_room_member_handler ()
345387 self .admin_handler = hs .get_admin_handler ()
346388 self .state_handler = hs .get_state_handler ()
347389
@@ -362,22 +404,16 @@ async def on_POST(
362404 if not await self .admin_handler .get_user (target_user ):
363405 raise NotFoundError ("User not found" )
364406
365- if RoomID .is_valid (room_identifier ):
366- room_id = room_identifier
367- try :
368- remote_room_hosts = [
369- x .decode ("ascii" ) for x in request .args [b"server_name" ]
370- ] # type: Optional[List[str]]
371- except Exception :
372- remote_room_hosts = None
373- elif RoomAlias .is_valid (room_identifier ):
374- handler = self .room_member_handler
375- room_alias = RoomAlias .from_string (room_identifier )
376- room_id , remote_room_hosts = await handler .lookup_room_alias (room_alias )
377- else :
378- raise SynapseError (
379- 400 , "%s was not legal room ID or room alias" % (room_identifier ,)
380- )
407+ # Get the room ID from the identifier.
408+ try :
409+ remote_room_hosts = [
410+ x .decode ("ascii" ) for x in request .args [b"server_name" ]
411+ ] # type: Optional[List[str]]
412+ except Exception :
413+ remote_room_hosts = None
414+ room_id , remote_room_hosts = await self .resolve_room_id (
415+ room_identifier , remote_room_hosts
416+ )
381417
382418 fake_requester = create_requester (
383419 target_user , authenticated_entity = requester .authenticated_entity
@@ -412,7 +448,7 @@ async def on_POST(
412448 return 200 , {"room_id" : room_id }
413449
414450
415- class MakeRoomAdminRestServlet (RestServlet ):
451+ class MakeRoomAdminRestServlet (ResolveRoomIdMixin , RestServlet ):
416452 """Allows a server admin to get power in a room if a local user has power in
417453 a room. Will also invite the user if they're not in the room and it's a
418454 private room. Can specify another user (rather than the admin user) to be
@@ -427,29 +463,21 @@ class MakeRoomAdminRestServlet(RestServlet):
427463 PATTERNS = admin_patterns ("/rooms/(?P<room_identifier>[^/]*)/make_room_admin" )
428464
429465 def __init__ (self , hs : "HomeServer" ):
466+ super ().__init__ (hs )
430467 self .hs = hs
431468 self .auth = hs .get_auth ()
432- self .room_member_handler = hs .get_room_member_handler ()
433469 self .event_creation_handler = hs .get_event_creation_handler ()
434470 self .state_handler = hs .get_state_handler ()
435471 self .is_mine_id = hs .is_mine_id
436472
437- async def on_POST (self , request , room_identifier ):
473+ async def on_POST (
474+ self , request : SynapseRequest , room_identifier : str
475+ ) -> Tuple [int , JsonDict ]:
438476 requester = await self .auth .get_user_by_req (request )
439477 await assert_user_is_admin (self .auth , requester .user )
440478 content = parse_json_object_from_request (request , allow_empty_body = True )
441479
442- # Resolve to a room ID, if necessary.
443- if RoomID .is_valid (room_identifier ):
444- room_id = room_identifier
445- elif RoomAlias .is_valid (room_identifier ):
446- room_alias = RoomAlias .from_string (room_identifier )
447- room_id , _ = await self .room_member_handler .lookup_room_alias (room_alias )
448- room_id = room_id .to_string ()
449- else :
450- raise SynapseError (
451- 400 , "%s was not legal room ID or room alias" % (room_identifier ,)
452- )
480+ room_id , _ = await self .resolve_room_id (room_identifier )
453481
454482 # Which user to grant room admin rights to.
455483 user_to_add = content .get ("user_id" , requester .user .to_string ())
@@ -556,7 +584,7 @@ async def on_POST(self, request, room_identifier):
556584 return 200 , {}
557585
558586
559- class ForwardExtremitiesRestServlet (RestServlet ):
587+ class ForwardExtremitiesRestServlet (ResolveRoomIdMixin , RestServlet ):
560588 """Allows a server admin to get or clear forward extremities.
561589
562590 Clearing does not require restarting the server.
@@ -571,43 +599,29 @@ class ForwardExtremitiesRestServlet(RestServlet):
571599 PATTERNS = admin_patterns ("/rooms/(?P<room_identifier>[^/]*)/forward_extremities" )
572600
573601 def __init__ (self , hs : "HomeServer" ):
602+ super ().__init__ (hs )
574603 self .hs = hs
575604 self .auth = hs .get_auth ()
576- self .room_member_handler = hs .get_room_member_handler ()
577605 self .store = hs .get_datastore ()
578606
579- async def resolve_room_id (self , room_identifier : str ) -> str :
580- """Resolve to a room ID, if necessary."""
581- if RoomID .is_valid (room_identifier ):
582- resolved_room_id = room_identifier
583- elif RoomAlias .is_valid (room_identifier ):
584- room_alias = RoomAlias .from_string (room_identifier )
585- room_id , _ = await self .room_member_handler .lookup_room_alias (room_alias )
586- resolved_room_id = room_id .to_string ()
587- else :
588- raise SynapseError (
589- 400 , "%s was not legal room ID or room alias" % (room_identifier ,)
590- )
591- if not resolved_room_id :
592- raise SynapseError (
593- 400 , "Unknown room ID or room alias %s" % room_identifier
594- )
595- return resolved_room_id
596-
597- async def on_DELETE (self , request , room_identifier ):
607+ async def on_DELETE (
608+ self , request : SynapseRequest , room_identifier : str
609+ ) -> Tuple [int , JsonDict ]:
598610 requester = await self .auth .get_user_by_req (request )
599611 await assert_user_is_admin (self .auth , requester .user )
600612
601- room_id = await self .resolve_room_id (room_identifier )
613+ room_id , _ = await self .resolve_room_id (room_identifier )
602614
603615 deleted_count = await self .store .delete_forward_extremities_for_room (room_id )
604616 return 200 , {"deleted" : deleted_count }
605617
606- async def on_GET (self , request , room_identifier ):
618+ async def on_GET (
619+ self , request : SynapseRequest , room_identifier : str
620+ ) -> Tuple [int , JsonDict ]:
607621 requester = await self .auth .get_user_by_req (request )
608622 await assert_user_is_admin (self .auth , requester .user )
609623
610- room_id = await self .resolve_room_id (room_identifier )
624+ room_id , _ = await self .resolve_room_id (room_identifier )
611625
612626 extremities = await self .store .get_forward_extremities_for_room (room_id )
613627 return 200 , {"count" : len (extremities ), "results" : extremities }
@@ -623,14 +637,16 @@ class RoomEventContextServlet(RestServlet):
623637
624638 PATTERNS = admin_patterns ("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" )
625639
626- def __init__ (self , hs ):
640+ def __init__ (self , hs : "HomeServer" ):
627641 super ().__init__ ()
628642 self .clock = hs .get_clock ()
629643 self .room_context_handler = hs .get_room_context_handler ()
630644 self ._event_serializer = hs .get_event_client_serializer ()
631645 self .auth = hs .get_auth ()
632646
633- async def on_GET (self , request , room_id , event_id ):
647+ async def on_GET (
648+ self , request : SynapseRequest , room_id : str , event_id : str
649+ ) -> Tuple [int , JsonDict ]:
634650 requester = await self .auth .get_user_by_req (request , allow_guest = False )
635651 await assert_user_is_admin (self .auth , requester .user )
636652
0 commit comments