3131 Sequence ,
3232 Set ,
3333 Tuple ,
34- Union ,
3534)
3635
3736import attr
4746from synapse .state import v1 , v2
4847from synapse .storage .databases .main .events_worker import EventRedactBehaviour
4948from synapse .storage .roommember import ProfileInfo
49+ from synapse .storage .state import StateFilter
5050from synapse .types import StateMap
5151from synapse .util .async_helpers import Linearizer
5252from synapse .util .caches .expiringcache import ExpiringCache
5353from synapse .util .metrics import Measure , measure_func
5454
5555if TYPE_CHECKING :
5656 from synapse .server import HomeServer
57+ from synapse .storage .controllers import StateStorageController
5758 from synapse .storage .databases .main import DataStore
5859
5960logger = logging .getLogger (__name__ )
@@ -83,17 +84,20 @@ def _gen_state_id() -> str:
8384
8485
8586class _StateCacheEntry :
86- __slots__ = ["state" , "state_group" , "state_id" , " prev_group" , "delta_ids" ]
87+ __slots__ = ["state" , "state_group" , "prev_group" , "delta_ids" ]
8788
8889 def __init__ (
8990 self ,
90- state : StateMap [str ],
91+ state : Optional [ StateMap [str ] ],
9192 state_group : Optional [int ],
9293 prev_group : Optional [int ] = None ,
9394 delta_ids : Optional [StateMap [str ]] = None ,
9495 ):
96+ if state is None and state_group is None :
97+ raise Exception ("Either state or state group must be not None" )
98+
9599 # A map from (type, state_key) to event_id.
96- self .state = frozendict (state )
100+ self .state = frozendict (state ) if state is not None else None
97101
98102 # the ID of a state group if one and only one is involved.
99103 # otherwise, None otherwise?
@@ -102,20 +106,30 @@ def __init__(
102106 self .prev_group = prev_group
103107 self .delta_ids = frozendict (delta_ids ) if delta_ids is not None else None
104108
105- # The `state_id` is a unique ID we generate that can be used as ID for
106- # this collection of state. Usually this would be the same as the
107- # state group, but on worker instances we can't generate a new state
108- # group each time we resolve state, so we generate a separate one that
109- # isn't persisted and is used solely for caches.
110- # `state_id` is either a state_group (and so an int) or a string. This
111- # ensures we don't accidentally persist a state_id as a stateg_group
112- if state_group :
113- self .state_id : Union [str , int ] = state_group
114- else :
115- self .state_id = _gen_state_id ()
109+ async def get_state (
110+ self ,
111+ state_storage : "StateStorageController" ,
112+ state_filter : Optional ["StateFilter" ] = None ,
113+ ) -> StateMap [str ]:
114+ """Get the state map for this entry, either from the in-memory state or
115+ looking up the state group in the DB.
116+ """
117+
118+ if self .state is not None :
119+ return self .state
120+
121+ assert self .state_group is not None
122+
123+ return await state_storage .get_state_ids_for_group (
124+ self .state_group , state_filter
125+ )
116126
117127 def __len__ (self ) -> int :
118- return len (self .state )
128+ # The len should is used to estimate how large this cache entry is, for
129+ # cache eviction purposes. This is why if `self.state` is None it's fine
130+ # to return 1.
131+
132+ return len (self .state ) if self .state else 1
119133
120134
121135class StateHandler :
@@ -153,7 +167,7 @@ async def get_current_state_ids(
153167 """
154168 logger .debug ("calling resolve_state_groups from get_current_state_ids" )
155169 ret = await self .resolve_state_groups_for_events (room_id , latest_event_ids )
156- return ret .state
170+ return await ret .get_state ( self . _state_storage_controller , StateFilter . all ())
157171
158172 async def get_current_users_in_room (
159173 self , room_id : str , latest_event_ids : List [str ]
@@ -177,7 +191,8 @@ async def get_current_users_in_room(
177191
178192 logger .debug ("calling resolve_state_groups from get_current_users_in_room" )
179193 entry = await self .resolve_state_groups_for_events (room_id , latest_event_ids )
180- return await self .store .get_joined_users_from_state (room_id , entry )
194+ state = await entry .get_state (self ._state_storage_controller , StateFilter .all ())
195+ return await self .store .get_joined_users_from_state (room_id , state , entry )
181196
182197 async def get_hosts_in_room_at_events (
183198 self , room_id : str , event_ids : Collection [str ]
@@ -192,7 +207,8 @@ async def get_hosts_in_room_at_events(
192207 The hosts in the room at the given events
193208 """
194209 entry = await self .resolve_state_groups_for_events (room_id , event_ids )
195- return await self .store .get_joined_hosts (room_id , entry )
210+ state = await entry .get_state (self ._state_storage_controller , StateFilter .all ())
211+ return await self .store .get_joined_hosts (room_id , state , entry )
196212
197213 async def compute_event_context (
198214 self ,
@@ -227,10 +243,19 @@ async def compute_event_context(
227243 #
228244 if state_ids_before_event :
229245 # if we're given the state before the event, then we use that
230- state_group_before_event = None
231246 state_group_before_event_prev_group = None
232247 deltas_to_state_group_before_event = None
233- entry = None
248+
249+ # .. though we need to get a state group for it.
250+ state_group_before_event = (
251+ await self ._state_storage_controller .store_state_group (
252+ event .event_id ,
253+ event .room_id ,
254+ prev_group = None ,
255+ delta_ids = None ,
256+ current_state_ids = state_ids_before_event ,
257+ )
258+ )
234259
235260 else :
236261 # otherwise, we'll need to resolve the state across the prev_events.
@@ -264,36 +289,27 @@ async def compute_event_context(
264289 await_full_state = False ,
265290 )
266291
267- state_ids_before_event = entry .state
268- state_group_before_event = entry .state_group
269292 state_group_before_event_prev_group = entry .prev_group
270293 deltas_to_state_group_before_event = entry .delta_ids
271294
272- #
273- # make sure that we have a state group at that point. If it's not a state event,
274- # that will be the state group for the new event. If it *is* a state event,
275- # it might get rejected (in which case we'll need to persist it with the
276- # previous state group)
277- #
278-
279- if not state_group_before_event :
280- state_group_before_event = (
281- await self ._state_storage_controller .store_state_group (
282- event .event_id ,
283- event .room_id ,
284- prev_group = state_group_before_event_prev_group ,
285- delta_ids = deltas_to_state_group_before_event ,
286- current_state_ids = state_ids_before_event ,
295+ # We make sure that we have a state group assigned to the state.
296+ if entry .state_group is None :
297+ state_ids_before_event = await entry .get_state (
298+ self ._state_storage_controller , StateFilter .all ()
299+ )
300+ state_group_before_event = (
301+ await self ._state_storage_controller .store_state_group (
302+ event .event_id ,
303+ event .room_id ,
304+ prev_group = state_group_before_event_prev_group ,
305+ delta_ids = deltas_to_state_group_before_event ,
306+ current_state_ids = state_ids_before_event ,
307+ )
287308 )
288- )
289-
290- # Assign the new state group to the cached state entry.
291- #
292- # Note that this can race in that we could generate multiple state
293- # groups for the same state entry, but that is just inefficient
294- # rather than dangerous.
295- if entry and entry .state_group is None :
296309 entry .state_group = state_group_before_event
310+ else :
311+ state_group_before_event = entry .state_group
312+ state_ids_before_event = None
297313
298314 #
299315 # now if it's not a state event, we're done
@@ -313,6 +329,10 @@ async def compute_event_context(
313329 #
314330 # otherwise, we'll need to create a new state group for after the event
315331 #
332+ if state_ids_before_event is None :
333+ state_ids_before_event = await entry .get_state (
334+ self ._state_storage_controller , StateFilter .all ()
335+ )
316336
317337 key = (event .type , event .state_key )
318338 if key in state_ids_before_event :
@@ -372,17 +392,14 @@ async def resolve_state_groups_for_events(
372392 state_group_ids_set = set (state_group_ids )
373393 if len (state_group_ids_set ) == 1 :
374394 (state_group_id ,) = state_group_ids_set
375- state = await self ._state_storage_controller .get_state_for_groups (
376- state_group_ids_set
377- )
378395 (
379396 prev_group ,
380397 delta_ids ,
381398 ) = await self ._state_storage_controller .get_state_group_delta (
382399 state_group_id
383400 )
384401 return _StateCacheEntry (
385- state = state [ state_group_id ] ,
402+ state = None ,
386403 state_group = state_group_id ,
387404 prev_group = prev_group ,
388405 delta_ids = delta_ids ,
0 commit comments