1313# limitations under the License.
1414
1515import logging
16- from typing import Optional
16+ from typing import TYPE_CHECKING , Dict , List , Mapping , Optional , Tuple , Union
1717
1818from synapse .storage ._base import SQLBaseStore
19- from synapse .storage .database import DatabasePool
19+ from synapse .storage .database import (
20+ DatabasePool ,
21+ LoggingDatabaseConnection ,
22+ LoggingTransaction ,
23+ )
2024from synapse .storage .engines import PostgresEngine
2125from synapse .storage .state import StateFilter
26+ from synapse .types import MutableStateMap , StateMap
27+
28+ if TYPE_CHECKING :
29+ from synapse .server import HomeServer
2230
2331logger = logging .getLogger (__name__ )
2432
@@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
3139 updates.
3240 """
3341
34- def _count_state_group_hops_txn (self , txn , state_group ):
42+ def _count_state_group_hops_txn (
43+ self , txn : LoggingTransaction , state_group : int
44+ ) -> int :
3545 """Given a state group, count how many hops there are in the tree.
3646
3747 This is used to ensure the delta chains don't get too long.
@@ -56,7 +66,7 @@ def _count_state_group_hops_txn(self, txn, state_group):
5666 else :
5767 # We don't use WITH RECURSIVE on sqlite3 as there are distributions
5868 # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
59- next_group = state_group
69+ next_group : Optional [ int ] = state_group
6070 count = 0
6171
6272 while next_group :
@@ -73,11 +83,14 @@ def _count_state_group_hops_txn(self, txn, state_group):
7383 return count
7484
7585 def _get_state_groups_from_groups_txn (
76- self , txn , groups , state_filter : Optional [StateFilter ] = None
77- ):
86+ self ,
87+ txn : LoggingTransaction ,
88+ groups : List [int ],
89+ state_filter : Optional [StateFilter ] = None ,
90+ ) -> Mapping [int , StateMap [str ]]:
7891 state_filter = state_filter or StateFilter .all ()
7992
80- results = {group : {} for group in groups }
93+ results : Dict [ int , MutableStateMap [ str ]] = {group : {} for group in groups }
8194
8295 where_clause , where_args = state_filter .make_sql_filter_clause ()
8396
@@ -117,7 +130,7 @@ def _get_state_groups_from_groups_txn(
117130 """
118131
119132 for group in groups :
120- args = [group ]
133+ args : List [ Union [ int , str ]] = [group ]
121134 args .extend (where_args )
122135
123136 txn .execute (sql % (where_clause ,), args )
@@ -131,7 +144,7 @@ def _get_state_groups_from_groups_txn(
131144 # We don't use WITH RECURSIVE on sqlite3 as there are distributions
132145 # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
133146 for group in groups :
134- next_group = group
147+ next_group : Optional [ int ] = group
135148
136149 while next_group :
137150 # We did this before by getting the list of group ids, and
@@ -173,6 +186,7 @@ def _get_state_groups_from_groups_txn(
173186 allow_none = True ,
174187 )
175188
189+ # The results shouldn't be considered mutable.
176190 return results
177191
178192
@@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
182196 STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
183197 STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
184198
185- def __init__ (self , database : DatabasePool , db_conn , hs ):
199+ def __init__ (
200+ self ,
201+ database : DatabasePool ,
202+ db_conn : LoggingDatabaseConnection ,
203+ hs : "HomeServer" ,
204+ ):
186205 super ().__init__ (database , db_conn , hs )
187206 self .db_pool .updates .register_background_update_handler (
188207 self .STATE_GROUP_DEDUPLICATION_UPDATE_NAME ,
@@ -198,7 +217,9 @@ def __init__(self, database: DatabasePool, db_conn, hs):
198217 columns = ["room_id" ],
199218 )
200219
201- async def _background_deduplicate_state (self , progress , batch_size ):
220+ async def _background_deduplicate_state (
221+ self , progress : dict , batch_size : int
222+ ) -> int :
202223 """This background update will slowly deduplicate state by reencoding
203224 them as deltas.
204225 """
@@ -218,7 +239,7 @@ async def _background_deduplicate_state(self, progress, batch_size):
218239 )
219240 max_group = rows [0 ][0 ]
220241
221- def reindex_txn (txn ) :
242+ def reindex_txn (txn : LoggingTransaction ) -> Tuple [ bool , int ] :
222243 new_last_state_group = last_state_group
223244 for count in range (batch_size ):
224245 txn .execute (
@@ -251,7 +272,8 @@ def reindex_txn(txn):
251272 " WHERE id < ? AND room_id = ?" ,
252273 (state_group , room_id ),
253274 )
254- (prev_group ,) = txn .fetchone ()
275+ # There will be a result due to the coalesce.
276+ (prev_group ,) = txn .fetchone () # type: ignore
255277 new_last_state_group = state_group
256278
257279 if prev_group :
@@ -261,15 +283,15 @@ def reindex_txn(txn):
261283 # otherwise read performance degrades.
262284 continue
263285
264- prev_state = self ._get_state_groups_from_groups_txn (
286+ prev_state_by_group = self ._get_state_groups_from_groups_txn (
265287 txn , [prev_group ]
266288 )
267- prev_state = prev_state [prev_group ]
289+ prev_state = prev_state_by_group [prev_group ]
268290
269- curr_state = self ._get_state_groups_from_groups_txn (
291+ curr_state_by_group = self ._get_state_groups_from_groups_txn (
270292 txn , [state_group ]
271293 )
272- curr_state = curr_state [state_group ]
294+ curr_state = curr_state_by_group [state_group ]
273295
274296 if not set (prev_state .keys ()) - set (curr_state .keys ()):
275297 # We can only do a delta if the current has a strict super set
@@ -340,8 +362,8 @@ def reindex_txn(txn):
340362
341363 return result * BATCH_SIZE_SCALE_FACTOR
342364
343- async def _background_index_state (self , progress , batch_size ) :
344- def reindex_txn (conn ) :
365+ async def _background_index_state (self , progress : dict , batch_size : int ) -> int :
366+ def reindex_txn (conn : LoggingDatabaseConnection ) -> None :
345367 conn .rollback ()
346368 if isinstance (self .database_engine , PostgresEngine ):
347369 # postgres insists on autocommit for the index
0 commit comments