Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 3eba047

Browse files
authored
Add type hints to state database module. (#10823)
1 parent b932590 commit 3eba047

File tree

6 files changed

+133
-72
lines changed

6 files changed

+133
-72
lines changed

changelog.d/10823.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add type hints to the state database.

mypy.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ files =
6060
synapse/storage/databases/main/session.py,
6161
synapse/storage/databases/main/stream.py,
6262
synapse/storage/databases/main/ui_auth.py,
63+
synapse/storage/databases/state,
6364
synapse/storage/database.py,
6465
synapse/storage/engines,
6566
synapse/storage/keys.py,

synapse/storage/databases/state/bg_updates.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,20 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Optional
16+
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
1717

1818
from synapse.storage._base import SQLBaseStore
19-
from synapse.storage.database import DatabasePool
19+
from synapse.storage.database import (
20+
DatabasePool,
21+
LoggingDatabaseConnection,
22+
LoggingTransaction,
23+
)
2024
from synapse.storage.engines import PostgresEngine
2125
from synapse.storage.state import StateFilter
26+
from synapse.types import MutableStateMap, StateMap
27+
28+
if TYPE_CHECKING:
29+
from synapse.server import HomeServer
2230

2331
logger = logging.getLogger(__name__)
2432

@@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
3139
updates.
3240
"""
3341

34-
def _count_state_group_hops_txn(self, txn, state_group):
42+
def _count_state_group_hops_txn(
43+
self, txn: LoggingTransaction, state_group: int
44+
) -> int:
3545
"""Given a state group, count how many hops there are in the tree.
3646
3747
This is used to ensure the delta chains don't get too long.
@@ -56,7 +66,7 @@ def _count_state_group_hops_txn(self, txn, state_group):
5666
else:
5767
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
5868
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
59-
next_group = state_group
69+
next_group: Optional[int] = state_group
6070
count = 0
6171

6272
while next_group:
@@ -73,11 +83,14 @@ def _count_state_group_hops_txn(self, txn, state_group):
7383
return count
7484

7585
def _get_state_groups_from_groups_txn(
76-
self, txn, groups, state_filter: Optional[StateFilter] = None
77-
):
86+
self,
87+
txn: LoggingTransaction,
88+
groups: List[int],
89+
state_filter: Optional[StateFilter] = None,
90+
) -> Mapping[int, StateMap[str]]:
7891
state_filter = state_filter or StateFilter.all()
7992

80-
results = {group: {} for group in groups}
93+
results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
8194

8295
where_clause, where_args = state_filter.make_sql_filter_clause()
8396

@@ -117,7 +130,7 @@ def _get_state_groups_from_groups_txn(
117130
"""
118131

119132
for group in groups:
120-
args = [group]
133+
args: List[Union[int, str]] = [group]
121134
args.extend(where_args)
122135

123136
txn.execute(sql % (where_clause,), args)
@@ -131,7 +144,7 @@ def _get_state_groups_from_groups_txn(
131144
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
132145
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
133146
for group in groups:
134-
next_group = group
147+
next_group: Optional[int] = group
135148

136149
while next_group:
137150
# We did this before by getting the list of group ids, and
@@ -173,6 +186,7 @@ def _get_state_groups_from_groups_txn(
173186
allow_none=True,
174187
)
175188

189+
# The results shouldn't be considered mutable.
176190
return results
177191

178192

@@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
182196
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
183197
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
184198

185-
def __init__(self, database: DatabasePool, db_conn, hs):
199+
def __init__(
200+
self,
201+
database: DatabasePool,
202+
db_conn: LoggingDatabaseConnection,
203+
hs: "HomeServer",
204+
):
186205
super().__init__(database, db_conn, hs)
187206
self.db_pool.updates.register_background_update_handler(
188207
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
@@ -198,7 +217,9 @@ def __init__(self, database: DatabasePool, db_conn, hs):
198217
columns=["room_id"],
199218
)
200219

201-
async def _background_deduplicate_state(self, progress, batch_size):
220+
async def _background_deduplicate_state(
221+
self, progress: dict, batch_size: int
222+
) -> int:
202223
"""This background update will slowly deduplicate state by reencoding
203224
them as deltas.
204225
"""
@@ -218,7 +239,7 @@ async def _background_deduplicate_state(self, progress, batch_size):
218239
)
219240
max_group = rows[0][0]
220241

221-
def reindex_txn(txn):
242+
def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
222243
new_last_state_group = last_state_group
223244
for count in range(batch_size):
224245
txn.execute(
@@ -251,7 +272,8 @@ def reindex_txn(txn):
251272
" WHERE id < ? AND room_id = ?",
252273
(state_group, room_id),
253274
)
254-
(prev_group,) = txn.fetchone()
275+
# There will be a result due to the coalesce.
276+
(prev_group,) = txn.fetchone() # type: ignore
255277
new_last_state_group = state_group
256278

257279
if prev_group:
@@ -261,15 +283,15 @@ def reindex_txn(txn):
261283
# otherwise read performance degrades.
262284
continue
263285

264-
prev_state = self._get_state_groups_from_groups_txn(
286+
prev_state_by_group = self._get_state_groups_from_groups_txn(
265287
txn, [prev_group]
266288
)
267-
prev_state = prev_state[prev_group]
289+
prev_state = prev_state_by_group[prev_group]
268290

269-
curr_state = self._get_state_groups_from_groups_txn(
291+
curr_state_by_group = self._get_state_groups_from_groups_txn(
270292
txn, [state_group]
271293
)
272-
curr_state = curr_state[state_group]
294+
curr_state = curr_state_by_group[state_group]
273295

274296
if not set(prev_state.keys()) - set(curr_state.keys()):
275297
# We can only do a delta if the current has a strict super set
@@ -340,8 +362,8 @@ def reindex_txn(txn):
340362

341363
return result * BATCH_SIZE_SCALE_FACTOR
342364

343-
async def _background_index_state(self, progress, batch_size):
344-
def reindex_txn(conn):
365+
async def _background_index_state(self, progress: dict, batch_size: int) -> int:
366+
def reindex_txn(conn: LoggingDatabaseConnection) -> None:
345367
conn.rollback()
346368
if isinstance(self.database_engine, PostgresEngine):
347369
# postgres insists on autocommit for the index

0 commit comments

Comments
 (0)