Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 13341dd

Browse files
authored
Don't hold onto full state in state cache (#13324)
1 parent 10e4093 commit 13341dd

File tree

2 files changed

+54
-15
lines changed

2 files changed

+54
-15
lines changed

changelog.d/13324.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Reduce the amount of state we store in the `state_cache`.

synapse/state/__init__.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
import heapq
1616
import logging
17-
from collections import defaultdict
17+
from collections import ChainMap, defaultdict
1818
from typing import (
1919
TYPE_CHECKING,
2020
Any,
@@ -92,8 +92,11 @@ def __init__(
9292
prev_group: Optional[int] = None,
9393
delta_ids: Optional[StateMap[str]] = None,
9494
):
95-
if state is None and state_group is None:
96-
raise Exception("Either state or state group must be not None")
95+
if state is None and state_group is None and prev_group is None:
96+
raise Exception("One of state, state_group or prev_group must be not None")
97+
98+
if prev_group is not None and delta_ids is None:
99+
raise Exception("If prev_group is set so must delta_ids")
97100

98101
# A map from (type, state_key) to event_id.
99102
#
@@ -120,18 +123,48 @@ async def get_state(
120123
if self._state is not None:
121124
return self._state
122125

123-
assert self.state_group is not None
126+
if self.state_group is not None:
127+
return await state_storage.get_state_ids_for_group(
128+
self.state_group, state_filter
129+
)
130+
131+
assert self.prev_group is not None and self.delta_ids is not None
124132

125-
return await state_storage.get_state_ids_for_group(
126-
self.state_group, state_filter
133+
prev_state = await state_storage.get_state_ids_for_group(
134+
self.prev_group, state_filter
127135
)
128136

137+
# ChainMap expects MutableMapping, but since we're using it immutably
138+
# it's safe to give it immutable maps.
139+
return ChainMap(self.delta_ids, prev_state) # type: ignore[arg-type]
140+
141+
def set_state_group(self, state_group: int) -> None:
142+
"""Update the state group assigned to this state (e.g. after we've
143+
persisted it).
144+
145+
Note: this will cause the cache entry to drop any stored state.
146+
"""
147+
148+
self.state_group = state_group
149+
150+
# We clear out the state as we no longer need to explicitly keep it in
151+
# the `state_cache` (as the store state group cache will do that).
152+
self._state = None
153+
129154
def __len__(self) -> int:
130-
# The len should is used to estimate how large this cache entry is, for
131-
# cache eviction purposes. This is why if `self.state` is None it's fine
132-
# to return 1.
155+
# The len should be used to estimate how large this cache entry is, for
156+
# cache eviction purposes. This is why it's fine to return 1 if we're
157+
# not storing any state.
158+
159+
length = 0
133160

134-
return len(self._state) if self._state else 1
161+
if self._state:
162+
length += len(self._state)
163+
164+
if self.delta_ids:
165+
length += len(self.delta_ids)
166+
167+
return length or 1 # Make sure its not 0.
135168

136169

137170
class StateHandler:
@@ -320,7 +353,7 @@ async def compute_event_context(
320353
current_state_ids=state_ids_before_event,
321354
)
322355
)
323-
entry.state_group = state_group_before_event
356+
entry.set_state_group(state_group_before_event)
324357
else:
325358
state_group_before_event = entry.state_group
326359

@@ -747,7 +780,7 @@ def _make_state_cache_entry(
747780
old_state_event_ids = set(state.values())
748781
if new_state_event_ids == old_state_event_ids:
749782
# got an exact match.
750-
return _StateCacheEntry(state=new_state, state_group=sg)
783+
return _StateCacheEntry(state=None, state_group=sg)
751784

752785
# TODO: We want to create a state group for this set of events, to
753786
# increase cache hits, but we need to make sure that it doesn't
@@ -769,9 +802,14 @@ def _make_state_cache_entry(
769802
prev_group = old_group
770803
delta_ids = n_delta_ids
771804

772-
return _StateCacheEntry(
773-
state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
774-
)
805+
if prev_group is not None:
806+
# If we have a prev group and deltas then we can drop the new state from
807+
# the cache (to reduce memory usage).
808+
return _StateCacheEntry(
809+
state=None, state_group=None, prev_group=prev_group, delta_ids=delta_ids
810+
)
811+
else:
812+
return _StateCacheEntry(state=new_state, state_group=None)
775813

776814

777815
@attr.s(slots=True, auto_attribs=True)

0 commit comments

Comments
 (0)