@@ -106,6 +106,10 @@ def __init__(self, hs: "HomeServer"):
106106 self .store = hs .get_datastore ()
107107 self .auth = hs .get_auth ()
108108
109+ # Used by `RulesForRoom` to ensure only one thing mutates the cache at a
110+ # time. Keyed off room_id.
111+ self ._rules_linearizer = Linearizer (name = "rules_for_room" )
112+
109113 self .room_push_rule_cache_metrics = register_cache (
110114 "cache" ,
111115 "room_push_rule_cache" ,
@@ -123,7 +127,16 @@ async def _get_rules_for_event(
123127 dict of user_id -> push_rules
124128 """
125129 room_id = event .room_id
126- rules_for_room = self ._get_rules_for_room (room_id )
130+
131+ rules_for_room_data = self ._get_rules_for_room (room_id )
132+ rules_for_room = RulesForRoom (
133+ hs = self .hs ,
134+ room_id = room_id ,
135+ rules_for_room_cache = self ._get_rules_for_room .cache ,
136+ room_push_rule_cache_metrics = self .room_push_rule_cache_metrics ,
137+ linearizer = self ._rules_linearizer ,
138+ cached_data = rules_for_room_data ,
139+ )
127140
128141 rules_by_user = await rules_for_room .get_rules (event , context )
129142
@@ -142,17 +155,12 @@ async def _get_rules_for_event(
142155 return rules_by_user
143156
144157 @lru_cache ()
145- def _get_rules_for_room (self , room_id : str ) -> "RulesForRoom " :
146- """Get the current RulesForRoom object for the given room id"""
147- # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
158+ def _get_rules_for_room (self , room_id : str ) -> "RulesForRoomData " :
159+ """Get the current RulesForRoomData object for the given room id"""
160+ # It's important that the RulesForRoomData object gets added to self._get_rules_for_room.cache
148161 # before any lookup methods get called on it as otherwise there may be
149162 # a race if invalidate_all gets called (which assumes it's in the cache)
150- return RulesForRoom (
151- self .hs ,
152- room_id ,
153- self ._get_rules_for_room .cache ,
154- self .room_push_rule_cache_metrics ,
155- )
163+ return RulesForRoomData ()
156164
157165 async def _get_power_levels_and_sender_level (
158166 self , event : EventBase , context : EventContext
@@ -282,11 +290,49 @@ def _condition_checker(
282290 return True
283291
284292
293+ @attr .s (slots = True )
294+ class RulesForRoomData :
295+ """The data stored in the cache by `RulesForRoom`.
296+
297+ We don't store `RulesForRoom` directly in the cache as we want our caches to
298+ *only* include data, and not references to e.g. the data stores.
299+ """
300+
301+ # event_id -> (user_id, state)
302+ member_map = attr .ib (type = Dict [str , Tuple [str , str ]], factory = dict )
303+ # user_id -> rules
304+ rules_by_user = attr .ib (type = Dict [str , List [Dict [str , dict ]]], factory = dict )
305+
306+ # The last state group we updated the caches for. If a new event comes
307+ # along with the same state_group, we know that we can just return the
308+ # cached result.
309+ # On invalidation of the rules themselves (if the user changes them),
310+ # we invalidate everything and set state_group to `object()`
311+ state_group = attr .ib (type = Union [object , int ], factory = object )
312+
313+ # A sequence number to keep track of when we're allowed to update the
314+ # cache. We bump the sequence number when we invalidate the cache. If
315+ # the sequence number changes while we're calculating stuff we should
316+ # not update the cache with it.
317+ sequence = attr .ib (type = int , default = 0 )
318+
319+ # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
320+ # owned by AS's, or remote users, etc. (I.e. users we will never need to
321+ # calculate push for)
322+ # These never need to be invalidated as we will never set up push for
323+ # them.
324+ uninteresting_user_set = attr .ib (type = Set [str ], factory = set )
325+
326+
285327class RulesForRoom :
286328 """Caches push rules for users in a room.
287329
288330 This efficiently handles users joining/leaving the room by not invalidating
289331 the entire cache for the room.
332+
333+ A new instance is constructed for each call to
334+ `BulkPushRuleEvaluator._get_rules_for_event`, with the cached data from
335+ previous calls passed in.
290336 """
291337
292338 def __init__ (
@@ -295,6 +341,8 @@ def __init__(
295341 room_id : str ,
296342 rules_for_room_cache : LruCache ,
297343 room_push_rule_cache_metrics : CacheMetric ,
344+ linearizer : Linearizer ,
345+ cached_data : RulesForRoomData ,
298346 ):
299347 """
300348 Args:
@@ -303,38 +351,21 @@ def __init__(
303351 rules_for_room_cache: The cache object that caches these
304352 RoomsForUser objects.
305353 room_push_rule_cache_metrics: The metrics object
354+ linearizer: The linearizer used to ensure only one thing mutates
355+ the cache at a time. Keyed off room_id.
356+ cached_data: Cached data from previous calls to `self.get_rules`;
357+ may be mutated.
306358 """
307359 self .room_id = room_id
308360 self .is_mine_id = hs .is_mine_id
309361 self .store = hs .get_datastore ()
310362 self .room_push_rule_cache_metrics = room_push_rule_cache_metrics
311363
312- self .linearizer = Linearizer (name = "rules_for_room" )
313-
314- # event_id -> (user_id, state)
315- self .member_map = {} # type: Dict[str, Tuple[str, str]]
316- # user_id -> rules
317- self .rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
318-
319- # The last state group we updated the caches for. If the state_group of
320- # a new event comes along, we know that we can just return the cached
321- # result.
322- # On invalidation of the rules themselves (if the user changes them),
323- # we invalidate everything and set state_group to `object()`
324- self .state_group = object ()
325-
326- # A sequence number to keep track of when we're allowed to update the
327- # cache. We bump the sequence number when we invalidate the cache. If
328- # the sequence number changes while we're calculating stuff we should
329- # not update the cache with it.
330- self .sequence = 0
331-
332- # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
333- # owned by AS's, or remote users, etc. (I.e. users we will never need to
334- # calculate push for)
335- # These never need to be invalidated as we will never set up push for
336- # them.
337- self .uninteresting_user_set = set () # type: Set[str]
364+ # Used to ensure only one thing mutates the cache at a time. Keyed off
365+ # room_id.
366+ self .linearizer = linearizer
367+
368+ self .data = cached_data
338369
339370 # We need to be clever on the invalidating caches callbacks, as
340371 # otherwise the invalidation callback holds a reference to the object,
@@ -352,25 +383,25 @@ async def get_rules(
352383 """
353384 state_group = context .state_group
354385
355- if state_group and self .state_group == state_group :
386+ if state_group and self .data . state_group == state_group :
356387 logger .debug ("Using cached rules for %r" , self .room_id )
357388 self .room_push_rule_cache_metrics .inc_hits ()
358- return self .rules_by_user
389+ return self .data . rules_by_user
359390
360- with (await self .linearizer .queue (() )):
361- if state_group and self .state_group == state_group :
391+ with (await self .linearizer .queue (self . room_id )):
392+ if state_group and self .data . state_group == state_group :
362393 logger .debug ("Using cached rules for %r" , self .room_id )
363394 self .room_push_rule_cache_metrics .inc_hits ()
364- return self .rules_by_user
395+ return self .data . rules_by_user
365396
366397 self .room_push_rule_cache_metrics .inc_misses ()
367398
368399 ret_rules_by_user = {}
369400 missing_member_event_ids = {}
370- if state_group and self .state_group == context .prev_group :
401+ if state_group and self .data . state_group == context .prev_group :
371402 # If we have a simple delta then we can reuse most of the previous
372403 # results.
373- ret_rules_by_user = self .rules_by_user
404+ ret_rules_by_user = self .data . rules_by_user
374405 current_state_ids = context .delta_ids
375406
376407 push_rules_delta_state_cache_metric .inc_hits ()
@@ -393,24 +424,24 @@ async def get_rules(
393424 if typ != EventTypes .Member :
394425 continue
395426
396- if user_id in self .uninteresting_user_set :
427+ if user_id in self .data . uninteresting_user_set :
397428 continue
398429
399430 if not self .is_mine_id (user_id ):
400- self .uninteresting_user_set .add (user_id )
431+ self .data . uninteresting_user_set .add (user_id )
401432 continue
402433
403434 if self .store .get_if_app_services_interested_in_user (user_id ):
404- self .uninteresting_user_set .add (user_id )
435+ self .data . uninteresting_user_set .add (user_id )
405436 continue
406437
407438 event_id = current_state_ids [key ]
408439
409- res = self .member_map .get (event_id , None )
440+ res = self .data . member_map .get (event_id , None )
410441 if res :
411442 user_id , state = res
412443 if state == Membership .JOIN :
413- rules = self .rules_by_user .get (user_id , None )
444+ rules = self .data . rules_by_user .get (user_id , None )
414445 if rules :
415446 ret_rules_by_user [user_id ] = rules
416447 continue
@@ -430,7 +461,7 @@ async def get_rules(
430461 else :
431462 # The push rules didn't change but let's update the cache anyway
432463 self .update_cache (
433- self .sequence ,
464+ self .data . sequence ,
434465 members = {}, # There were no membership changes
435466 rules_by_user = ret_rules_by_user ,
436467 state_group = state_group ,
@@ -461,7 +492,7 @@ async def _update_rules_with_member_event_ids(
461492 for. Used when updating the cache.
462493 event: The event we are currently computing push rules for.
463494 """
464- sequence = self .sequence
495+ sequence = self .data . sequence
465496
466497 rows = await self .store .get_membership_from_event_ids (member_event_ids .values ())
467498
@@ -501,23 +532,11 @@ async def _update_rules_with_member_event_ids(
501532
502533 self .update_cache (sequence , members , ret_rules_by_user , state_group )
503534
504- def invalidate_all (self ) -> None :
505- # Note: Don't hand this function directly to an invalidation callback
506- # as it keeps a reference to self and will stop this instance from being
507- # GC'd if it gets dropped from the rules_to_user cache. Instead use
508- # `self.invalidate_all_cb`
509- logger .debug ("Invalidating RulesForRoom for %r" , self .room_id )
510- self .sequence += 1
511- self .state_group = object ()
512- self .member_map = {}
513- self .rules_by_user = {}
514- push_rules_invalidation_counter .inc ()
515-
516535 def update_cache (self , sequence , members , rules_by_user , state_group ) -> None :
517- if sequence == self .sequence :
518- self .member_map .update (members )
519- self .rules_by_user = rules_by_user
520- self .state_group = state_group
536+ if sequence == self .data . sequence :
537+ self .data . member_map .update (members )
538+ self .data . rules_by_user = rules_by_user
539+ self .data . state_group = state_group
521540
522541
523542@attr .attrs (slots = True , frozen = True )
@@ -535,6 +554,10 @@ class _Invalidation:
535554 room_id = attr .ib (type = str )
536555
537556 def __call__ (self ) -> None :
538- rules = self .cache .get (self .room_id , None , update_metrics = False )
539- if rules :
540- rules .invalidate_all ()
557+ rules_data = self .cache .get (self .room_id , None , update_metrics = False )
558+ if rules_data :
559+ rules_data .sequence += 1
560+ rules_data .state_group = object ()
561+ rules_data .member_map = {}
562+ rules_data .rules_by_user = {}
563+ push_rules_invalidation_counter .inc ()
0 commit comments