103103
104104@attr .s (slots = True )
105105class _NewEventInfo :
106- """Holds information about a received event, ready for passing to _handle_new_events
106+ """Holds information about a received event, ready for passing to _auth_and_persist_events
107107
108108 Attributes:
109109 event: the received event
@@ -807,7 +807,10 @@ async def _process_received_pdu(
807807 logger .debug ("Processing event: %s" , event )
808808
809809 try :
810- await self ._handle_new_event (origin , event , state = state )
810+ context = await self .state_handler .compute_event_context (
811+ event , old_state = state
812+ )
813+ await self ._auth_and_persist_event (origin , event , context , state = state )
811814 except AuthError as e :
812815 raise FederationError ("ERROR" , e .code , e .msg , affected = event .event_id )
813816
@@ -1010,7 +1013,9 @@ async def backfill(
10101013 )
10111014
10121015 if ev_infos :
1013- await self ._handle_new_events (dest , room_id , ev_infos , backfilled = True )
1016+ await self ._auth_and_persist_events (
1017+ dest , room_id , ev_infos , backfilled = True
1018+ )
10141019
10151020 # Step 2: Persist the rest of the events in the chunk one by one
10161021 events .sort (key = lambda e : e .depth )
@@ -1023,10 +1028,12 @@ async def backfill(
10231028 # non-outliers
10241029 assert not event .internal_metadata .is_outlier ()
10251030
1031+ context = await self .state_handler .compute_event_context (event )
1032+
10261033 # We store these one at a time since each event depends on the
10271034 # previous to work out the state.
10281035 # TODO: We can probably do something more clever here.
1029- await self ._handle_new_event (dest , event , backfilled = True )
1036+ await self ._auth_and_persist_event (dest , event , context , backfilled = True )
10301037
10311038 return events
10321039
@@ -1360,7 +1367,7 @@ async def get_event(event_id: str):
13601367
13611368 event_infos .append (_NewEventInfo (event , None , auth ))
13621369
1363- await self ._handle_new_events (
1370+ await self ._auth_and_persist_events (
13641371 destination ,
13651372 room_id ,
13661373 event_infos ,
@@ -1666,10 +1673,11 @@ async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
16661673 # would introduce the danger of backwards-compatibility problems.
16671674 event .internal_metadata .send_on_behalf_of = origin
16681675
1669- context = await self ._handle_new_event (origin , event )
1676+ context = await self .state_handler .compute_event_context (event )
1677+ context = await self ._auth_and_persist_event (origin , event , context )
16701678
16711679 logger .debug (
1672- "on_send_join_request: After _handle_new_event : %s, sigs: %s" ,
1680+ "on_send_join_request: After _auth_and_persist_event : %s, sigs: %s" ,
16731681 event .event_id ,
16741682 event .signatures ,
16751683 )
@@ -1878,10 +1886,11 @@ async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
18781886
18791887 event .internal_metadata .outlier = False
18801888
1881- await self ._handle_new_event (origin , event )
1889+ context = await self .state_handler .compute_event_context (event )
1890+ await self ._auth_and_persist_event (origin , event , context )
18821891
18831892 logger .debug (
1884- "on_send_leave_request: After _handle_new_event : %s, sigs: %s" ,
1893+ "on_send_leave_request: After _auth_and_persist_event : %s, sigs: %s" ,
18851894 event .event_id ,
18861895 event .signatures ,
18871896 )
@@ -1989,16 +1998,47 @@ async def get_persisted_pdu(
19891998 async def get_min_depth_for_context (self , context : str ) -> int :
19901999 return await self .store .get_min_depth (context )
19912000
1992- async def _handle_new_event (
2001+ async def _auth_and_persist_event (
19932002 self ,
19942003 origin : str ,
19952004 event : EventBase ,
2005+ context : EventContext ,
19962006 state : Optional [Iterable [EventBase ]] = None ,
19972007 auth_events : Optional [MutableStateMap [EventBase ]] = None ,
19982008 backfilled : bool = False ,
19992009 ) -> EventContext :
2000- context = await self ._prep_event (
2001- origin , event , state = state , auth_events = auth_events , backfilled = backfilled
2010+ """
2011+ Process an event by performing auth checks and then persisting to the database.
2012+
2013+ Args:
2014+ origin: The host the event originates from.
2015+ event: The event itself.
2016+ context:
2017+ The event context.
2018+
2019+ NB that this function potentially modifies it.
2020+ state:
2021+ The state events used to check the event for soft-fail. If this is
2022+ not provided the current state events will be used.
2023+ auth_events:
2024+ Map from (event_type, state_key) to event
2025+
2026+ Normally, our calculated auth_events based on the state of the room
2027+ at the event's position in the DAG, though occasionally (eg if the
2028+ event is an outlier), may be the auth events claimed by the remote
2029+ server.
2030+ backfilled: True if the event was backfilled.
2031+
2032+ Returns:
2033+ The event context.
2034+ """
2035+ context = await self ._check_event_auth (
2036+ origin ,
2037+ event ,
2038+ context ,
2039+ state = state ,
2040+ auth_events = auth_events ,
2041+ backfilled = backfilled ,
20022042 )
20032043
20042044 try :
@@ -2022,7 +2062,7 @@ async def _handle_new_event(
20222062
20232063 return context
20242064
2025- async def _handle_new_events (
2065+ async def _auth_and_persist_events (
20262066 self ,
20272067 origin : str ,
20282068 room_id : str ,
@@ -2040,9 +2080,13 @@ async def _handle_new_events(
20402080 async def prep (ev_info : _NewEventInfo ):
20412081 event = ev_info .event
20422082 with nested_logging_context (suffix = event .event_id ):
2043- res = await self ._prep_event (
2083+ res = await self .state_handler .compute_event_context (
2084+ event , old_state = ev_info .state
2085+ )
2086+ res = await self ._check_event_auth (
20442087 origin ,
20452088 event ,
2089+ res ,
20462090 state = ev_info .state ,
20472091 auth_events = ev_info .auth_events ,
20482092 backfilled = backfilled ,
@@ -2177,49 +2221,6 @@ async def _persist_auth_tree(
21772221 room_id , [(event , new_event_context )]
21782222 )
21792223
2180- async def _prep_event (
2181- self ,
2182- origin : str ,
2183- event : EventBase ,
2184- state : Optional [Iterable [EventBase ]],
2185- auth_events : Optional [MutableStateMap [EventBase ]],
2186- backfilled : bool ,
2187- ) -> EventContext :
2188- context = await self .state_handler .compute_event_context (event , old_state = state )
2189-
2190- if not auth_events :
2191- prev_state_ids = await context .get_prev_state_ids ()
2192- auth_events_ids = self .auth .compute_auth_events (
2193- event , prev_state_ids , for_verification = True
2194- )
2195- auth_events_x = await self .store .get_events (auth_events_ids )
2196- auth_events = {(e .type , e .state_key ): e for e in auth_events_x .values ()}
2197-
2198- # This is a hack to fix some old rooms where the initial join event
2199- # didn't reference the create event in its auth events.
2200- if event .type == EventTypes .Member and not event .auth_event_ids ():
2201- if len (event .prev_event_ids ()) == 1 and event .depth < 5 :
2202- c = await self .store .get_event (
2203- event .prev_event_ids ()[0 ], allow_none = True
2204- )
2205- if c and c .type == EventTypes .Create :
2206- auth_events [(c .type , c .state_key )] = c
2207-
2208- context = await self .do_auth (origin , event , context , auth_events = auth_events )
2209-
2210- if not context .rejected :
2211- await self ._check_for_soft_fail (event , state , backfilled )
2212-
2213- if event .type == EventTypes .GuestAccess and not context .rejected :
2214- await self .maybe_kick_guest_users (event )
2215-
2216- # If we are going to send this event over federation we precaclculate
2217- # the joined hosts.
2218- if event .internal_metadata .get_send_on_behalf_of ():
2219- await self .event_creation_handler .cache_joined_hosts_for_event (event )
2220-
2221- return context
2222-
22232224 async def _check_for_soft_fail (
22242225 self , event : EventBase , state : Optional [Iterable [EventBase ]], backfilled : bool
22252226 ) -> None :
@@ -2330,19 +2331,28 @@ async def on_get_missing_events(
23302331
23312332 return missing_events
23322333
2333- async def do_auth (
2334+ async def _check_event_auth (
23342335 self ,
23352336 origin : str ,
23362337 event : EventBase ,
23372338 context : EventContext ,
2338- auth_events : MutableStateMap [EventBase ],
2339+ state : Optional [Iterable [EventBase ]],
2340+ auth_events : Optional [MutableStateMap [EventBase ]],
2341+ backfilled : bool ,
23392342 ) -> EventContext :
23402343 """
2344+ Checks whether an event should be rejected (for failing auth checks).
23412345
23422346 Args:
2343- origin:
2344- event:
2347+ origin: The host the event originates from.
2348+ event: The event itself.
23452349 context:
2350+ The event context.
2351+
2352+ NB that this function potentially modifies it.
2353+ state:
2354+ The state events used to check the event for soft-fail. If this is
2355+ not provided the current state events will be used.
23462356 auth_events:
23472357 Map from (event_type, state_key) to event
23482358
@@ -2352,12 +2362,34 @@ async def do_auth(
23522362 server.
23532363
23542364 Also NB that this function adds entries to it.
2365+
2366+ If this is not provided, it is calculated from the previous state IDs.
2367+ backfilled: True if the event was backfilled.
2368+
23552369 Returns:
2356- updated context object
2370+ The updated context object.
23572371 """
23582372 room_version = await self .store .get_room_version_id (event .room_id )
23592373 room_version_obj = KNOWN_ROOM_VERSIONS [room_version ]
23602374
2375+ if not auth_events :
2376+ prev_state_ids = await context .get_prev_state_ids ()
2377+ auth_events_ids = self .auth .compute_auth_events (
2378+ event , prev_state_ids , for_verification = True
2379+ )
2380+ auth_events_x = await self .store .get_events (auth_events_ids )
2381+ auth_events = {(e .type , e .state_key ): e for e in auth_events_x .values ()}
2382+
2383+ # This is a hack to fix some old rooms where the initial join event
2384+ # didn't reference the create event in its auth events.
2385+ if event .type == EventTypes .Member and not event .auth_event_ids ():
2386+ if len (event .prev_event_ids ()) == 1 and event .depth < 5 :
2387+ c = await self .store .get_event (
2388+ event .prev_event_ids ()[0 ], allow_none = True
2389+ )
2390+ if c and c .type == EventTypes .Create :
2391+ auth_events [(c .type , c .state_key )] = c
2392+
23612393 try :
23622394 context = await self ._update_auth_events_and_context_for_auth (
23632395 origin , event , context , auth_events
@@ -2379,6 +2411,17 @@ async def do_auth(
23792411 logger .warning ("Failed auth resolution for %r because %s" , event , e )
23802412 context .rejected = RejectedReason .AUTH_ERROR
23812413
2414+ if not context .rejected :
2415+ await self ._check_for_soft_fail (event , state , backfilled )
2416+
2417+ if event .type == EventTypes .GuestAccess and not context .rejected :
2418+ await self .maybe_kick_guest_users (event )
2419+
2420+ # If we are going to send this event over federation we precaclculate
2421+ # the joined hosts.
2422+ if event .internal_metadata .get_send_on_behalf_of ():
2423+ await self .event_creation_handler .cache_joined_hosts_for_event (event )
2424+
23822425 return context
23832426
23842427 async def _update_auth_events_and_context_for_auth (
@@ -2388,7 +2431,7 @@ async def _update_auth_events_and_context_for_auth(
23882431 context : EventContext ,
23892432 auth_events : MutableStateMap [EventBase ],
23902433 ) -> EventContext :
2391- """Helper for do_auth . See there for docs.
2434+ """Helper for _check_event_auth . See there for docs.
23922435
23932436 Checks whether a given event has the expected auth events. If it
23942437 doesn't then we talk to the remote server to compare state to see if
@@ -2468,9 +2511,14 @@ async def _update_auth_events_and_context_for_auth(
24682511 e .internal_metadata .outlier = True
24692512
24702513 logger .debug (
2471- "do_auth %s missing_auth: %s" , event .event_id , e .event_id
2514+ "_check_event_auth %s missing_auth: %s" ,
2515+ event .event_id ,
2516+ e .event_id ,
2517+ )
2518+ context = await self .state_handler .compute_event_context (e )
2519+ await self ._auth_and_persist_event (
2520+ origin , e , context , auth_events = auth
24722521 )
2473- await self ._handle_new_event (origin , e , auth_events = auth )
24742522
24752523 if e .event_id in event_auth_events :
24762524 auth_events [(e .type , e .state_key )] = e
0 commit comments