2424if TYPE_CHECKING :
2525 from synapse .storage import Storage
2626 from synapse .storage .databases .main import DataStore
27+ from synapse .storage .state import StateFilter
2728
2829
2930@attr .s (slots = True , auto_attribs = True )
@@ -196,14 +197,19 @@ def state_group(self) -> Optional[int]:
196197
197198 return self ._state_group
198199
199- async def get_current_state_ids (self ) -> Optional [StateMap [str ]]:
200+ async def get_current_state_ids (
201+ self , state_filter : Optional ["StateFilter" ] = None
202+ ) -> Optional [StateMap [str ]]:
200203 """
201204 Gets the room state map, including this event - ie, the state in ``state_group``
202205
203206 It is an error to access this for a rejected event, since rejected state should
204207 not make it into the room state. This method will raise an exception if
205208 ``rejected`` is set.
206209
210+ Arg:
211+ state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
212+
207213 Returns:
208214 Returns None if state_group is None, which happens when the associated
209215 event is an outlier.
@@ -216,20 +222,25 @@ async def get_current_state_ids(self) -> Optional[StateMap[str]]:
216222
217223 assert self ._state_delta_due_to_event is not None
218224
219- prev_state_ids = await self .get_prev_state_ids ()
225+ prev_state_ids = await self .get_prev_state_ids (state_filter )
220226
221227 if self ._state_delta_due_to_event :
222228 prev_state_ids = dict (prev_state_ids )
223229 prev_state_ids .update (self ._state_delta_due_to_event )
224230
225231 return prev_state_ids
226232
227- async def get_prev_state_ids (self ) -> StateMap [str ]:
233+ async def get_prev_state_ids (
234+ self , state_filter : Optional ["StateFilter" ] = None
235+ ) -> StateMap [str ]:
228236 """
229237 Gets the room state map, excluding this event.
230238
231239 For a non-state event, this will be the same as get_current_state_ids().
232240
241+ Args:
242+ state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
243+
233244 Returns:
234245 Returns {} if state_group is None, which happens when the associated
235246 event is an outlier.
@@ -239,7 +250,7 @@ async def get_prev_state_ids(self) -> StateMap[str]:
239250 """
240251 assert self .state_group_before_event is not None
241252 return await self ._storage .state .get_state_ids_for_group (
242- self .state_group_before_event
253+ self .state_group_before_event , state_filter
243254 )
244255
245256
0 commit comments