Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 0ca4172

Browse files
authored
Don't pull out state in compute_event_context for unconflicted state (#13267)
1 parent 599c403 commit 0ca4172

File tree

7 files changed

+95
-136
lines changed

7 files changed

+95
-136
lines changed

changelog.d/13267.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Don't pull out state in `compute_event_context` for unconflicted state.

synapse/handlers/message.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1444,7 +1444,12 @@ async def cache_joined_hosts_for_event(
14441444
if state_entry.state_group in self._external_cache_joined_hosts_updates:
14451445
return
14461446

1447-
joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
1447+
state = await state_entry.get_state(
1448+
self._storage_controllers.state, StateFilter.all()
1449+
)
1450+
joined_hosts = await self.store.get_joined_hosts(
1451+
event.room_id, state, state_entry
1452+
)
14481453

14491454
# Note that the expiry times must be larger than the expiry time in
14501455
# _external_cache_joined_hosts_updates.

synapse/state/__init__.py

Lines changed: 67 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
Sequence,
3232
Set,
3333
Tuple,
34-
Union,
3534
)
3635

3736
import attr
@@ -47,13 +46,15 @@
4746
from synapse.state import v1, v2
4847
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
4948
from synapse.storage.roommember import ProfileInfo
49+
from synapse.storage.state import StateFilter
5050
from synapse.types import StateMap
5151
from synapse.util.async_helpers import Linearizer
5252
from synapse.util.caches.expiringcache import ExpiringCache
5353
from synapse.util.metrics import Measure, measure_func
5454

5555
if TYPE_CHECKING:
5656
from synapse.server import HomeServer
57+
from synapse.storage.controllers import StateStorageController
5758
from synapse.storage.databases.main import DataStore
5859

5960
logger = logging.getLogger(__name__)
@@ -83,17 +84,20 @@ def _gen_state_id() -> str:
8384

8485

8586
class _StateCacheEntry:
86-
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
87+
__slots__ = ["state", "state_group", "prev_group", "delta_ids"]
8788

8889
def __init__(
8990
self,
90-
state: StateMap[str],
91+
state: Optional[StateMap[str]],
9192
state_group: Optional[int],
9293
prev_group: Optional[int] = None,
9394
delta_ids: Optional[StateMap[str]] = None,
9495
):
96+
if state is None and state_group is None:
97+
raise Exception("Either state or state group must be not None")
98+
9599
# A map from (type, state_key) to event_id.
96-
self.state = frozendict(state)
100+
self.state = frozendict(state) if state is not None else None
97101

98102
# the ID of a state group if one and only one is involved.
99103
# None otherwise?
@@ -102,20 +106,30 @@ def __init__(
102106
self.prev_group = prev_group
103107
self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
104108

105-
# The `state_id` is a unique ID we generate that can be used as ID for
106-
# this collection of state. Usually this would be the same as the
107-
# state group, but on worker instances we can't generate a new state
108-
# group each time we resolve state, so we generate a separate one that
109-
# isn't persisted and is used solely for caches.
110-
# `state_id` is either a state_group (and so an int) or a string. This
111-
# ensures we don't accidentally persist a state_id as a stateg_group
112-
if state_group:
113-
self.state_id: Union[str, int] = state_group
114-
else:
115-
self.state_id = _gen_state_id()
109+
async def get_state(
110+
self,
111+
state_storage: "StateStorageController",
112+
state_filter: Optional["StateFilter"] = None,
113+
) -> StateMap[str]:
114+
"""Get the state map for this entry, either from the in-memory state or
115+
looking up the state group in the DB.
116+
"""
117+
118+
if self.state is not None:
119+
return self.state
120+
121+
assert self.state_group is not None
122+
123+
return await state_storage.get_state_ids_for_group(
124+
self.state_group, state_filter
125+
)
116126

117127
def __len__(self) -> int:
118-
return len(self.state)
128+
# The len is used to estimate how large this cache entry is, for
129+
# cache eviction purposes. This is why if `self.state` is None it's fine
130+
# to return 1.
131+
132+
return len(self.state) if self.state else 1
119133

120134

121135
class StateHandler:
@@ -153,7 +167,7 @@ async def get_current_state_ids(
153167
"""
154168
logger.debug("calling resolve_state_groups from get_current_state_ids")
155169
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
156-
return ret.state
170+
return await ret.get_state(self._state_storage_controller, StateFilter.all())
157171

158172
async def get_current_users_in_room(
159173
self, room_id: str, latest_event_ids: List[str]
@@ -177,7 +191,8 @@ async def get_current_users_in_room(
177191

178192
logger.debug("calling resolve_state_groups from get_current_users_in_room")
179193
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
180-
return await self.store.get_joined_users_from_state(room_id, entry)
194+
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
195+
return await self.store.get_joined_users_from_state(room_id, state, entry)
181196

182197
async def get_hosts_in_room_at_events(
183198
self, room_id: str, event_ids: Collection[str]
@@ -192,7 +207,8 @@ async def get_hosts_in_room_at_events(
192207
The hosts in the room at the given events
193208
"""
194209
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
195-
return await self.store.get_joined_hosts(room_id, entry)
210+
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
211+
return await self.store.get_joined_hosts(room_id, state, entry)
196212

197213
async def compute_event_context(
198214
self,
@@ -227,10 +243,19 @@ async def compute_event_context(
227243
#
228244
if state_ids_before_event:
229245
# if we're given the state before the event, then we use that
230-
state_group_before_event = None
231246
state_group_before_event_prev_group = None
232247
deltas_to_state_group_before_event = None
233-
entry = None
248+
249+
# .. though we need to get a state group for it.
250+
state_group_before_event = (
251+
await self._state_storage_controller.store_state_group(
252+
event.event_id,
253+
event.room_id,
254+
prev_group=None,
255+
delta_ids=None,
256+
current_state_ids=state_ids_before_event,
257+
)
258+
)
234259

235260
else:
236261
# otherwise, we'll need to resolve the state across the prev_events.
@@ -264,36 +289,27 @@ async def compute_event_context(
264289
await_full_state=False,
265290
)
266291

267-
state_ids_before_event = entry.state
268-
state_group_before_event = entry.state_group
269292
state_group_before_event_prev_group = entry.prev_group
270293
deltas_to_state_group_before_event = entry.delta_ids
271294

272-
#
273-
# make sure that we have a state group at that point. If it's not a state event,
274-
# that will be the state group for the new event. If it *is* a state event,
275-
# it might get rejected (in which case we'll need to persist it with the
276-
# previous state group)
277-
#
278-
279-
if not state_group_before_event:
280-
state_group_before_event = (
281-
await self._state_storage_controller.store_state_group(
282-
event.event_id,
283-
event.room_id,
284-
prev_group=state_group_before_event_prev_group,
285-
delta_ids=deltas_to_state_group_before_event,
286-
current_state_ids=state_ids_before_event,
295+
# We make sure that we have a state group assigned to the state.
296+
if entry.state_group is None:
297+
state_ids_before_event = await entry.get_state(
298+
self._state_storage_controller, StateFilter.all()
299+
)
300+
state_group_before_event = (
301+
await self._state_storage_controller.store_state_group(
302+
event.event_id,
303+
event.room_id,
304+
prev_group=state_group_before_event_prev_group,
305+
delta_ids=deltas_to_state_group_before_event,
306+
current_state_ids=state_ids_before_event,
307+
)
287308
)
288-
)
289-
290-
# Assign the new state group to the cached state entry.
291-
#
292-
# Note that this can race in that we could generate multiple state
293-
# groups for the same state entry, but that is just inefficient
294-
# rather than dangerous.
295-
if entry and entry.state_group is None:
296309
entry.state_group = state_group_before_event
310+
else:
311+
state_group_before_event = entry.state_group
312+
state_ids_before_event = None
297313

298314
#
299315
# now if it's not a state event, we're done
@@ -313,6 +329,10 @@ async def compute_event_context(
313329
#
314330
# otherwise, we'll need to create a new state group for after the event
315331
#
332+
if state_ids_before_event is None:
333+
state_ids_before_event = await entry.get_state(
334+
self._state_storage_controller, StateFilter.all()
335+
)
316336

317337
key = (event.type, event.state_key)
318338
if key in state_ids_before_event:
@@ -372,17 +392,14 @@ async def resolve_state_groups_for_events(
372392
state_group_ids_set = set(state_group_ids)
373393
if len(state_group_ids_set) == 1:
374394
(state_group_id,) = state_group_ids_set
375-
state = await self._state_storage_controller.get_state_for_groups(
376-
state_group_ids_set
377-
)
378395
(
379396
prev_group,
380397
delta_ids,
381398
) = await self._state_storage_controller.get_state_group_delta(
382399
state_group_id
383400
)
384401
return _StateCacheEntry(
385-
state=state[state_group_id],
402+
state=None,
386403
state_group=state_group_id,
387404
prev_group=prev_group,
388405
delta_ids=delta_ids,

synapse/storage/controllers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,6 @@ def __init__(self, hs: "HomeServer", stores: Databases):
4343

4444
self.persistence = None
4545
if stores.persist_events:
46-
self.persistence = EventsPersistenceStorageController(hs, stores)
46+
self.persistence = EventsPersistenceStorageController(
47+
hs, stores, self.state
48+
)

synapse/storage/controllers/persist_events.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,11 @@
4848
from synapse.logging import opentracing
4949
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
5050
from synapse.metrics.background_process_metrics import run_as_background_process
51+
from synapse.storage.controllers.state import StateStorageController
5152
from synapse.storage.databases import Databases
5253
from synapse.storage.databases.main.events import DeltaState
5354
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
55+
from synapse.storage.state import StateFilter
5456
from synapse.types import (
5557
PersistedEventPosition,
5658
RoomStreamToken,
@@ -308,7 +310,12 @@ class EventsPersistenceStorageController:
308310
current state and forward extremity changes.
309311
"""
310312

311-
def __init__(self, hs: "HomeServer", stores: Databases):
313+
def __init__(
314+
self,
315+
hs: "HomeServer",
316+
stores: Databases,
317+
state_controller: StateStorageController,
318+
):
312319
# We ultimately want to split out the state store from the main store,
313320
# so we use separate variables here even though they point to the same
314321
# store for now.
@@ -325,6 +332,7 @@ def __init__(self, hs: "HomeServer", stores: Databases):
325332
self._process_event_persist_queue_task
326333
)
327334
self._state_resolution_handler = hs.get_state_resolution_handler()
335+
self._state_controller = state_controller
328336

329337
async def _process_event_persist_queue_task(
330338
self,
@@ -504,7 +512,7 @@ async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
504512
state_res_store=StateResolutionStore(self.main_store),
505513
)
506514

507-
return res.state
515+
return await res.get_state(self._state_controller, StateFilter.all())
508516

509517
async def _persist_event_batch(
510518
self, _room_id: str, task: _PersistEventsTask

synapse/storage/databases/main/roommember.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
from synapse.api.constants import EventTypes, Membership
3333
from synapse.events import EventBase
34-
from synapse.events.snapshot import EventContext
3534
from synapse.metrics import LaterGauge
3635
from synapse.metrics.background_process_metrics import (
3736
run_as_background_process,
@@ -780,26 +779,8 @@ async def get_mutual_rooms_between_users(
780779

781780
return shared_room_ids or frozenset()
782781

783-
async def get_joined_users_from_context(
784-
self, event: EventBase, context: EventContext
785-
) -> Dict[str, ProfileInfo]:
786-
state_group: Union[object, int] = context.state_group
787-
if not state_group:
788-
# If state_group is None it means it has yet to be assigned a
789-
# state group, i.e. we need to make sure that calls with a state_group
790-
# of None don't hit previous cached calls with a None state_group.
791-
# To do this we set the state_group to a new object as object() != object()
792-
state_group = object()
793-
794-
current_state_ids = await context.get_current_state_ids()
795-
assert current_state_ids is not None
796-
assert state_group is not None
797-
return await self._get_joined_users_from_context(
798-
event.room_id, state_group, current_state_ids, event=event, context=context
799-
)
800-
801782
async def get_joined_users_from_state(
802-
self, room_id: str, state_entry: "_StateCacheEntry"
783+
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
803784
) -> Dict[str, ProfileInfo]:
804785
state_group: Union[object, int] = state_entry.state_group
805786
if not state_group:
@@ -812,18 +793,17 @@ async def get_joined_users_from_state(
812793
assert state_group is not None
813794
with Measure(self._clock, "get_joined_users_from_state"):
814795
return await self._get_joined_users_from_context(
815-
room_id, state_group, state_entry.state, context=state_entry
796+
room_id, state_group, state, context=state_entry
816797
)
817798

818-
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
799+
@cached(num_args=2, iterable=True, max_entries=100000)
819800
async def _get_joined_users_from_context(
820801
self,
821802
room_id: str,
822803
state_group: Union[object, int],
823804
current_state_ids: StateMap[str],
824-
cache_context: _CacheContext,
825805
event: Optional[EventBase] = None,
826-
context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
806+
context: Optional["_StateCacheEntry"] = None,
827807
) -> Dict[str, ProfileInfo]:
828808
# We don't use `state_group`, it's there so that we can cache based
829809
# on it. However, it's important that it's never None, since two current_states
@@ -1017,7 +997,7 @@ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
1017997
)
1018998

1019999
async def get_joined_hosts(
1020-
self, room_id: str, state_entry: "_StateCacheEntry"
1000+
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
10211001
) -> FrozenSet[str]:
10221002
state_group: Union[object, int] = state_entry.state_group
10231003
if not state_group:
@@ -1030,14 +1010,15 @@ async def get_joined_hosts(
10301010
assert state_group is not None
10311011
with Measure(self._clock, "get_joined_hosts"):
10321012
return await self._get_joined_hosts(
1033-
room_id, state_group, state_entry=state_entry
1013+
room_id, state_group, state, state_entry=state_entry
10341014
)
10351015

10361016
@cached(num_args=2, max_entries=10000, iterable=True)
10371017
async def _get_joined_hosts(
10381018
self,
10391019
room_id: str,
10401020
state_group: Union[object, int],
1021+
state: StateMap[str],
10411022
state_entry: "_StateCacheEntry",
10421023
) -> FrozenSet[str]:
10431024
# We don't use `state_group`, it's there so that we can cache based on
@@ -1093,7 +1074,7 @@ async def _get_joined_hosts(
10931074
# The cache doesn't match the state group or prev state group,
10941075
# so we calculate the result from first principles.
10951076
joined_users = await self.get_joined_users_from_state(
1096-
room_id, state_entry
1077+
room_id, state, state_entry
10971078
)
10981079

10991080
cache.hosts_to_joined_users = {}

0 commit comments

Comments
 (0)