5454from synapse .http .server import JsonResource
5555from synapse .module_api import NOT_SPAM , ModuleApi , errors
5656from synapse .server import HomeServer
57- from synapse .storage .database import LoggingTransaction , make_conn
57+ from synapse .storage .database import LoggingTransaction
5858from synapse .types import (
5959 Requester ,
60- RoomID ,
6160 ScheduledTask ,
6261 StateMap ,
6362 TaskStatus ,
8180 MessengerInfoResource ,
8281 MessengerIsInsuranceResource ,
8382)
84- from synapse_invite_checker .store import InviteCheckerStore
8583from synapse_invite_checker .types import (
8684 EpaRoomTimestampResults ,
8785 FederationList ,
@@ -205,8 +203,11 @@ class InviteChecker:
205203
206204 def __init__ (self , config : InviteCheckerConfig , api : ModuleApi ):
207205 self .api = api
208- # Need this for the @measure_func decorator to work
206+ # This needs to be on the Class itself so that metrics functions that measure
207+ # requests and database calls will function. Specifically for @measure_func
208+ self .server_name = api .server_name
209209 self .clock = api ._hs .get_clock ()
210+
210211 self .config = config
211212 # Can not do this as part of parse_config() as there is no access to the server
212213 # name yet
@@ -233,20 +234,6 @@ def __init__(self, config: InviteCheckerConfig, api: ModuleApi):
233234 check_login_for_spam = self .check_login_for_spam
234235 )
235236
236- dbconfig = None
237- for dbconf in api ._store .config .database .databases :
238- if dbconf .name == "master" :
239- dbconfig = dbconf
240-
241- if not dbconfig :
242- msg = "missing database config"
243- raise Exception (msg )
244-
245- with make_conn (
246- dbconfig , api ._store .database_engine , "invite_checker_startup"
247- ) as db_conn :
248- self .store = InviteCheckerStore (api ._store .db_pool , db_conn , api ._hs )
249-
250237 # Make sure this doesn't get initialized until after the default permissions
251238 # were potentially modified to account for the local server template
252239 self .permissions_handler = InviteCheckerPermissionsHandler (
@@ -565,19 +552,70 @@ async def on_upgrade_room(
565552 async def user_may_join_room (
566553 self , user : str , room_id : str , is_invited : bool
567554 ) -> Literal ["NOT_SPAM" ] | errors .Codes :
568- user_domain = UserID .from_string (user ).domain
569- room_domain = RoomID .from_string (room_id ).domain
570- # This only runs for local users, so only try and block remote rooms
571- if user_domain != room_domain :
572- # Block non-invited people from joining this room.
573- if not is_invited :
555+ """
556+ This is used to check that a local user can join a room. Invites to remote
557+ public rooms MUST be denied. Invites to local rooms are allowed(unless it is an
558+ EPA server, in which case it should not get here)
559+ Args:
560+ user:
561+ room_id:
562+ is_invited:
563+
564+ Returns:
565+
566+ """
567+ if not is_invited :
568+ # Do we have the creation event of the room state?
569+ state_mapping : StateMap [EventBase ] = (
570+ await self .api ._storage_controllers .state .get_current_state (
571+ room_id ,
572+ StateFilter .from_types (
573+ [(EventTypes .Create , None ), (EventTypes .JoinRules , None )]
574+ ),
575+ )
576+ )
577+
578+ creation_event = state_mapping .get ((EventTypes .Create , "" ))
579+ if not creation_event :
580+ # This happens because we do not have the state of the room. If this was
581+ # an invite(which includes 'invite_room_state') we would not be here. It
582+ # is highly likely that this means the room is remote. Since remote
583+ # rooms with no invite are not allowed, deny the request. Local rooms
584+ # that do not have state should be in the act of purging, in which case
585+ # we do not want to allow that join anyway.
574586 logger .debug (
575- "Forbidding user (%s) from joining local room (%s )" ,
587+ "Denying join of '%s' to room '%s' because local server has no state(which represents a remote room )" ,
576588 user ,
577589 room_id ,
578590 )
579591 return errors .Codes .FORBIDDEN
580592
593+ # There was no invite, but we already have the state of the room. Deny
594+ # public rooms if the room's creator's domain isn't the same as the local
595+ # server
596+ room_creator = UserID .from_string (creation_event .sender )
597+ if room_creator .domain != self .server_name :
598+ join_rules = state_mapping .get ((EventTypes .JoinRules , "" ))
599+ # all rooms should have join_rules, make sure
600+ if join_rules is None :
601+ logger .warning (
602+ "Room state of '%s' does not contain 'join_rules" , room_id
603+ )
604+ return errors .Codes .FORBIDDEN
605+
606+ if join_rules .content ["join_rule" ] == JoinRules .PUBLIC :
607+ # There are no public remote rooms.
608+ logger .debug (
609+ "Forbidding join of '%s' to remote PUBLIC room '%s'" ,
610+ user ,
611+ room_id ,
612+ )
613+ return errors .Codes .FORBIDDEN
614+
615+ # Room was created by a local user
616+ return NOT_SPAM
617+
618+ else :
581619 # Try and see if the invite event had any initial room state data. For now,
582620 # this requires a database call, but if https://github.com/element-hq/synapse/issues/18230
583621 # becomes a thing, we won't need it anymore. It is possible that room_data
@@ -586,21 +624,42 @@ async def user_may_join_room(
586624 room_data = await self .api ._store .get_invite_for_local_user_in_room (
587625 user , room_id
588626 )
589- if room_data :
590- invite_event = await self .api ._store .get_event (room_data .event_id )
591- for _event in invite_event .unsigned .get ("invite_room_state" , []):
592- if (
593- _event ["type" ] == EventTypes .JoinRules
594- and _event ["content" ]["join_rule" ] == JoinRules .PUBLIC
595- ):
596- return errors .Codes .FORBIDDEN
597- else :
627+ if room_data is None :
628+ # If for some reason this data is missing, deny and bail. Someone is doing
629+ # something fishy
598630 logger .warning (
599- "room_data(RoomsForUser) was None after an invite for user (%s) in room (%s) " ,
631+ "Forbidding join of '%s' to room '%s' because invite data could not be found " ,
600632 user ,
601633 room_id ,
602634 )
603- return NOT_SPAM
635+ return errors .Codes .FORBIDDEN
636+
637+ invite_event = await self .api ._store .get_event (room_data .event_id )
638+
639+ create_event_senders_domain = None
640+ is_public = True
641+
642+ # Sort out the conditions
643+ for _event in invite_event .unsigned .get ("invite_room_state" , []):
644+ if (
645+ _event ["type" ] == EventTypes .JoinRules
646+ and _event ["content" ]["join_rule" ] != JoinRules .PUBLIC
647+ ):
648+ is_public = False
649+ if _event ["type" ] == EventTypes .Create :
650+ create_event_senders_domain = UserID .from_string (
651+ _event ["sender" ]
652+ ).domain
653+
654+ if is_public and create_event_senders_domain != self .server_name :
655+ logger .debug (
656+ "Forbidding joining '%s' to invited room '%s' because room is PUBLIC" ,
657+ user ,
658+ room_id ,
659+ )
660+ return errors .Codes .FORBIDDEN
661+
662+ return NOT_SPAM
604663
605664 async def check_event_allowed (
606665 self , event : EventBase , context : StateMap [EventBase ]
@@ -671,6 +730,7 @@ async def check_event_allowed(
671730 EventContentFields .FEDERATE , True
672731 )
673732 # Remember to account for the override disabler
733+ # TODO: fix this possible reference before assignment
674734 if federated_flag and self .config .override_public_room_federation :
675735 return False , None
676736
@@ -900,7 +960,7 @@ def f(txn: LoggingTransaction) -> set[str]:
900960 txn .execute (sql )
901961 return {room_id for (room_id ,) in txn .fetchall ()}
902962
903- return await self .store .db_pool .runInteraction ("get_rooms" , f )
963+ return await self .api . _store .db_pool .runInteraction ("get_rooms" , f )
904964
905965 @measure_func ("get_timestamps_from_eligible_events_for_epa_room_purge" )
906966 async def get_timestamps_from_eligible_events_for_epa_room_purge (
0 commit comments