This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit ac95167

Add some type hints to datastore. (#12255)
Parent commit: 4ba55a6

10 files changed: +61 −42 lines


changelog.d/12255.misc

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Add missing type hints for storage.

mypy.ini

Lines changed: 0 additions & 3 deletions
@@ -38,14 +38,11 @@ exclude = (?x)
   |synapse/_scripts/update_synapse_database.py
 
   |synapse/storage/databases/__init__.py
-  |synapse/storage/databases/main/__init__.py
   |synapse/storage/databases/main/cache.py
   |synapse/storage/databases/main/devices.py
   |synapse/storage/databases/main/event_federation.py
   |synapse/storage/databases/main/push_rule.py
-  |synapse/storage/databases/main/receipts.py
   |synapse/storage/databases/main/roommember.py
-  |synapse/storage/databases/main/search.py
   |synapse/storage/databases/main/state.py
   |synapse/storage/schema/
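Dropping `main/__init__.py`, `receipts.py`, and `search.py` from this exclude block is what brings those files under mypy's checks. A rough, hypothetical illustration of how a verbose-mode `(?x)` exclude regex like this one matches file paths (the pattern below is simplified, not the real one):

import re

# Simplified stand-in for the exclude pattern above: (?x) makes the regex
# ignore the whitespace and newlines, so each "|path" alternative can sit on
# its own line.
exclude = re.compile(
    r"""(?x)
    ^(
        synapse/storage/databases/main/cache\.py
        |synapse/storage/schema/
    )
    """
)

print(bool(exclude.match("synapse/storage/databases/main/cache.py")))     # True: still excluded
print(bool(exclude.match("synapse/storage/databases/main/receipts.py")))  # False: now type-checked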

synapse/storage/databases/main/media_repository.py

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ def __init__(
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
 
     async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
         """Get the metadata for a local piece of media

synapse/storage/databases/main/receipts.py

Lines changed: 25 additions & 12 deletions
@@ -24,10 +24,9 @@
     Optional,
     Set,
     Tuple,
+    cast,
 )
 
-from twisted.internet import defer
-
 from synapse.api.constants import ReceiptTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ReceiptsStream
@@ -38,7 +37,11 @@
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -58,6 +61,7 @@ def __init__(
         hs: "HomeServer",
     ):
         self._instance_name = hs.get_instance_name()
+        self._receipts_id_gen: AbstractStreamIdTracker
 
         if isinstance(database.engine, PostgresEngine):
             self._can_write_to_receipts = (
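Declaring `self._receipts_id_gen: AbstractStreamIdTracker` before the branches lets both assignments type-check: the writer path assigns a full id generator and the read-only path a tracker, and both satisfy the declared supertype. A self-contained sketch of the declare-then-assign pattern (class names below are hypothetical):

from abc import ABC, abstractmethod


class AbstractIdTracker(ABC):
    @abstractmethod
    def advance(self, instance_name: str, token: int) -> None:
        ...


class ReadOnlyIdTracker(AbstractIdTracker):
    def __init__(self) -> None:
        self._token = 0

    def advance(self, instance_name: str, token: int) -> None:
        self._token = max(self._token, token)


class WriterIdGenerator(ReadOnlyIdTracker):
    def get_next(self) -> int:
        self._token += 1
        return self._token


class Store:
    def __init__(self, can_write: bool) -> None:
        # Annotate once with the common supertype so either branch satisfies mypy.
        self._id_gen: AbstractIdTracker
        if can_write:
            self._id_gen = WriterIdGenerator()
        else:
            self._id_gen = ReadOnlyIdTracker()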
@@ -161,7 +165,7 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
                 " AND user_id = ?"
             )
             txn.execute(sql, (user_id,))
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, int, int]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction(
             "get_receipts_for_user_with_orderings", f
@@ -257,7 +261,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
         if not rows:
             return []
 
-        content = {}
+        content: JsonDict = {}
         for row in rows:
             content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
                 row["user_id"]
@@ -305,7 +309,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
             "_get_linearized_receipts_for_rooms", f
         )
 
-        results = {}
+        results: JsonDict = {}
         for row in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
@@ -370,7 +374,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
             "get_linearized_receipts_for_all_rooms", f
         )
 
-        results = {}
+        results: JsonDict = {}
         for row in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
@@ -399,7 +403,7 @@ async def get_users_sent_receipts_between(
         """
 
         if last_id == current_id:
-            return defer.succeed([])
+            return []
 
         def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
             sql = """
@@ -453,7 +457,10 @@ def get_all_updated_receipts_txn(
             """
             txn.execute(sql, (last_id, current_id, limit))
 
-            updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
+            updates = cast(
+                List[Tuple[int, list]],
+                [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
+            )
 
             limited = False
             upper_bound = current_id
@@ -496,7 +503,13 @@ def invalidate_caches_for_receipt(
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == ReceiptsStream.NAME:
             self._receipts_id_gen.advance(instance_name, token)
             for row in rows:
@@ -584,7 +597,7 @@ def insert_linearized_receipt_txn(
         )
 
         if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
-            self._remove_old_push_actions_before_txn(
+            self._remove_old_push_actions_before_txn(  # type: ignore[attr-defined]
                 txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
            )
 
@@ -637,7 +650,7 @@ def graph_to_linear(txn: LoggingTransaction) -> str:
             "insert_receipt_conv", graph_to_linear
         )
 
-        async with self._receipts_id_gen.get_next() as stream_id:
+        async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
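Both `# type: ignore[attr-defined]` comments above are narrow suppressions rather than real fixes: `_remove_old_push_actions_before_txn` lives on a sibling store class that is only combined with this one in the final `DataStore` via multiple inheritance, and `get_next()` exists on the writer's id generator but not on the `AbstractStreamIdTracker` annotation. A rough illustration of the first case (classes and method names hypothetical):

class EventPushActionsWorkerStore:
    def _remove_old_push_actions(self, user_id: str) -> None:
        print(f"pruning push actions for {user_id}")


class ReceiptsWorkerStore:
    def on_receipt(self, user_id: str) -> None:
        # mypy checks ReceiptsWorkerStore in isolation and cannot see the
        # sibling's method; at runtime the combined DataStore provides it.
        self._remove_old_push_actions(user_id)  # type: ignore[attr-defined]


class DataStore(ReceiptsWorkerStore, EventPushActionsWorkerStore):
    """The concrete store class that mixes the worker stores together."""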

synapse/storage/databases/main/registration.py

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
+from synapse.config.homeserver import HomeServerConfig
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.database import (
     DatabasePool,
@@ -123,7 +124,7 @@ def __init__(
     ):
         super().__init__(database, db_conn, hs)
 
-        self.config = hs.config
+        self.config: HomeServerConfig = hs.config
 
         # Note: we don't check this sequence for consistency as we'd have to
         # call `find_max_generated_user_id_localpart` each time, which is

synapse/storage/databases/main/room.py

Lines changed: 2 additions & 1 deletion
@@ -34,6 +34,7 @@
 from synapse.api.constants import EventContentFields, EventTypes, JoinRules
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.config.homeserver import HomeServerConfig
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -98,7 +99,7 @@ def __init__(
     ):
         super().__init__(database, db_conn, hs)
 
-        self.config = hs.config
+        self.config: HomeServerConfig = hs.config
 
     async def store_room(
         self,

synapse/storage/databases/main/search.py

Lines changed: 13 additions & 13 deletions
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
 
 import attr
 
@@ -74,7 +74,7 @@ def store_search_entries_txn(
                 " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
             )
 
-            args = (
+            args1 = (
                 (
                     entry.event_id,
                     entry.room_id,
@@ -86,14 +86,14 @@ def store_search_entries_txn(
                 for entry in entries
             )
 
-            txn.execute_batch(sql, args)
+            txn.execute_batch(sql, args1)
 
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
                 "INSERT INTO event_search (event_id, room_id, key, value)"
                 " VALUES (?,?,?,?)"
             )
-            args = (
+            args2 = (
                 (
                     entry.event_id,
                     entry.room_id,
@@ -102,7 +102,7 @@ def store_search_entries_txn(
                 )
                 for entry in entries
             )
-            txn.execute_batch(sql, args)
+            txn.execute_batch(sql, args2)
 
         else:
             # This should be unreachable.
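The rename from `args` to `args1`/`args2` avoids re-binding one variable to generators with differently shaped row tuples, which mypy reports as an incompatible assignment. The same idea in isolation (values hypothetical):

entries = [("$event1", "!room1:example.org", "content.body", "hello", 7, 1_647_000_000)]

# Postgres rows carry six columns, SQLite rows four; giving each its own name
# keeps every variable at a single, consistent type.
args1 = ((e[0], e[1], e[2], e[3], e[4], e[5]) for e in entries)
args2 = ((e[0], e[1], e[2], e[3]) for e in entries)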
@@ -427,7 +427,7 @@ async def search_msgs(
 
         search_query = _parse_query(self.database_engine, search_term)
 
-        args = []
+        args: List[Any] = []
 
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
@@ -496,7 +496,7 @@ async def search_msgs(
 
         # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
         # search results (which is a data leak)
-        events = await self.get_events_as_list(
+        events = await self.get_events_as_list(  # type: ignore[attr-defined]
             [r["event_id"] for r in results],
             redact_behaviour=EventRedactBehaviour.BLOCK,
         )
@@ -530,7 +530,7 @@ async def search_rooms(
         room_ids: Collection[str],
         search_term: str,
         keys: Iterable[str],
-        limit,
+        limit: int,
         pagination_token: Optional[str] = None,
     ) -> JsonDict:
         """Performs a full text search over events with given keys.
@@ -549,7 +549,7 @@ async def search_rooms(
 
         search_query = _parse_query(self.database_engine, search_term)
 
-        args = []
+        args: List[Any] = []
 
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
@@ -573,9 +573,9 @@ async def search_rooms(
 
         if pagination_token:
             try:
-                origin_server_ts, stream = pagination_token.split(",")
-                origin_server_ts = int(origin_server_ts)
-                stream = int(stream)
+                origin_server_ts_str, stream_str = pagination_token.split(",")
+                origin_server_ts = int(origin_server_ts_str)
+                stream = int(stream_str)
             except Exception:
                 raise SynapseError(400, "Invalid pagination token")
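Similarly, `origin_server_ts` used to hold the `str` from `split(",")` and then be re-assigned an `int`; giving the string intermediates their own names keeps each variable at one type. The same fix in isolation:

pagination_token = "1647000000,42"  # example value

origin_server_ts_str, stream_str = pagination_token.split(",")
origin_server_ts = int(origin_server_ts_str)
stream = int(stream_str)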

@@ -654,7 +654,7 @@ async def search_rooms(
 
         # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
         # search results (which is a data leak)
-        events = await self.get_events_as_list(
+        events = await self.get_events_as_list(  # type: ignore[attr-defined]
             [r["event_id"] for r in results],
             redact_behaviour=EventRedactBehaviour.BLOCK,
         )

synapse/storage/databases/main/state.py

Lines changed: 15 additions & 9 deletions
@@ -14,7 +14,7 @@
 # limitations under the License.
 import collections.abc
 import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Set
+from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +29,7 @@
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
-from synapse.types import StateMap
+from synapse.types import JsonDict, StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -241,7 +241,9 @@ async def get_filtered_current_state_ids(
             # We delegate to the cached version
             return await self.get_current_state_ids(room_id)
 
-        def _get_filtered_current_state_ids_txn(txn):
+        def _get_filtered_current_state_ids_txn(
+            txn: LoggingTransaction,
+        ) -> StateMap[str]:
             results = {}
             sql = """
                 SELECT type, state_key, event_id FROM current_state_events
@@ -281,11 +283,11 @@ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
 
         event_id = state.get((EventTypes.CanonicalAlias, ""))
         if not event_id:
-            return
+            return None
 
         event = await self.get_event(event_id, allow_none=True)
         if not event:
-            return
+            return None
 
         return event.content.get("canonical_alias")
 
@@ -304,7 +306,7 @@ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
         list_name="event_ids",
         num_args=1,
     )
-    async def _get_state_group_for_events(self, event_ids):
+    async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
         """Returns mapping event_id -> state_group"""
         rows = await self.db_pool.simple_select_many_batch(
             table="event_to_state_groups",
@@ -355,7 +357,7 @@ def __init__(
     ):
         super().__init__(database, db_conn, hs)
 
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
 
         self.db_pool.updates.register_background_index_update(
             self.CURRENT_STATE_INDEX_UPDATE_NAME,
@@ -375,15 +377,19 @@ def __init__(
             self._background_remove_left_rooms,
         )
 
-    async def _background_remove_left_rooms(self, progress, batch_size):
+    async def _background_remove_left_rooms(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to delete rows from `current_state_events` and
         `event_forward_extremities` tables of rooms that the server is no
         longer joined to.
         """
 
         last_room_id = progress.get("last_room_id", "")
 
-        def _background_remove_left_rooms_txn(txn):
+        def _background_remove_left_rooms_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[bool, Set[str]]:
             # get a batch of room ids to consider
             sql = """
                 SELECT DISTINCT room_id FROM current_state_events

synapse/storage/databases/main/stats.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def __init__(
     ):
         super().__init__(database, db_conn, hs)
 
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
         self.clock = self.hs.get_clock()
         self.stats_enabled = hs.config.stats.stats_enabled

synapse/storage/databases/main/user_directory.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def __init__(
     ) -> None:
         super().__init__(database, db_conn, hs)
 
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
 
         self.db_pool.updates.register_background_update_handler(
             "populate_user_directory_createtables",
