1414# limitations under the License.
1515import heapq
1616import logging
17- from collections import defaultdict
17+ from collections import ChainMap , defaultdict
1818from typing import (
1919 TYPE_CHECKING ,
2020 Any ,
@@ -92,8 +92,11 @@ def __init__(
9292 prev_group : Optional [int ] = None ,
9393 delta_ids : Optional [StateMap [str ]] = None ,
9494 ):
95- if state is None and state_group is None :
96- raise Exception ("Either state or state group must be not None" )
95+ if state is None and state_group is None and prev_group is None :
96+ raise Exception ("One of state, state_group or prev_group must be not None" )
97+
98+ if prev_group is not None and delta_ids is None :
99+ raise Exception ("If prev_group is set so must delta_ids" )
97100
98101 # A map from (type, state_key) to event_id.
99102 #
@@ -120,18 +123,48 @@ async def get_state(
120123 if self ._state is not None :
121124 return self ._state
122125
123- assert self .state_group is not None
126+ if self .state_group is not None :
127+ return await state_storage .get_state_ids_for_group (
128+ self .state_group , state_filter
129+ )
130+
131+ assert self .prev_group is not None and self .delta_ids is not None
124132
125- return await state_storage .get_state_ids_for_group (
126- self .state_group , state_filter
133+ prev_state = await state_storage .get_state_ids_for_group (
134+ self .prev_group , state_filter
127135 )
128136
137+ # ChainMap expects MutableMapping, but since we're using it immutably
138+ # its safe to give it immutable maps.
139+ return ChainMap (self .delta_ids , prev_state ) # type: ignore[arg-type]
140+
141+ def set_state_group (self , state_group : int ) -> None :
142+ """Update the state group assigned to this state (e.g. after we've
143+ persisted it).
144+
145+ Note: this will cause the cache entry to drop any stored state.
146+ """
147+
148+ self .state_group = state_group
149+
150+ # We clear out the state as we know longer need to explicitly keep it in
151+ # the `state_cache` (as the store state group cache will do that).
152+ self ._state = None
153+
129154 def __len__ (self ) -> int :
130- # The len should is used to estimate how large this cache entry is, for
131- # cache eviction purposes. This is why if `self.state` is None it's fine
132- # to return 1.
155+ # The len should be used to estimate how large this cache entry is, for
156+ # cache eviction purposes. This is why it's fine to return 1 if we're
157+ # not storing any state.
158+
159+ length = 0
133160
134- return len (self ._state ) if self ._state else 1
161+ if self ._state :
162+ length += len (self ._state )
163+
164+ if self .delta_ids :
165+ length += len (self .delta_ids )
166+
167+ return length or 1 # Make sure its not 0.
135168
136169
137170class StateHandler :
@@ -320,7 +353,7 @@ async def compute_event_context(
320353 current_state_ids = state_ids_before_event ,
321354 )
322355 )
323- entry .state_group = state_group_before_event
356+ entry .set_state_group ( state_group_before_event )
324357 else :
325358 state_group_before_event = entry .state_group
326359
@@ -747,7 +780,7 @@ def _make_state_cache_entry(
747780 old_state_event_ids = set (state .values ())
748781 if new_state_event_ids == old_state_event_ids :
749782 # got an exact match.
750- return _StateCacheEntry (state = new_state , state_group = sg )
783+ return _StateCacheEntry (state = None , state_group = sg )
751784
752785 # TODO: We want to create a state group for this set of events, to
753786 # increase cache hits, but we need to make sure that it doesn't
@@ -769,9 +802,14 @@ def _make_state_cache_entry(
769802 prev_group = old_group
770803 delta_ids = n_delta_ids
771804
772- return _StateCacheEntry (
773- state = new_state , state_group = None , prev_group = prev_group , delta_ids = delta_ids
774- )
805+ if prev_group is not None :
806+ # If we have a prev group and deltas then we can drop the new state from
807+ # the cache (to reduce memory usage).
808+ return _StateCacheEntry (
809+ state = None , state_group = None , prev_group = prev_group , delta_ids = delta_ids
810+ )
811+ else :
812+ return _StateCacheEntry (state = new_state , state_group = None )
775813
776814
777815@attr .s (slots = True , auto_attribs = True )
0 commit comments