1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import TYPE_CHECKING , Dict , Iterable , List , Tuple
15+ from typing import TYPE_CHECKING , Dict , Iterable , List , Tuple , cast
1616
1717from synapse .api .presence import PresenceState , UserPresenceState
1818from synapse .replication .tcp .streams import PresenceStream
1919from synapse .storage ._base import SQLBaseStore , make_in_list_sql_clause
20- from synapse .storage .database import DatabasePool , LoggingDatabaseConnection
20+ from synapse .storage .database import (
21+ DatabasePool ,
22+ LoggingDatabaseConnection ,
23+ LoggingTransaction ,
24+ )
2125from synapse .storage .engines import PostgresEngine
2226from synapse .storage .types import Connection
23- from synapse .storage .util .id_generators import MultiWriterIdGenerator , StreamIdGenerator
27+ from synapse .storage .util .id_generators import (
28+ AbstractStreamIdGenerator ,
29+ MultiWriterIdGenerator ,
30+ StreamIdGenerator ,
31+ )
2432from synapse .util .caches .descriptors import cached , cachedList
2533from synapse .util .caches .stream_change_cache import StreamChangeCache
2634from synapse .util .iterutils import batch_iter
@@ -35,7 +43,7 @@ def __init__(
3543 database : DatabasePool ,
3644 db_conn : LoggingDatabaseConnection ,
3745 hs : "HomeServer" ,
38- ):
46+ ) -> None :
3947 super ().__init__ (database , db_conn , hs )
4048
4149 # Used by `PresenceStore._get_active_presence()`
@@ -54,11 +62,14 @@ def __init__(
5462 database : DatabasePool ,
5563 db_conn : LoggingDatabaseConnection ,
5664 hs : "HomeServer" ,
57- ):
65+ ) -> None :
5866 super ().__init__ (database , db_conn , hs )
5967
68+ self ._instance_name = hs .get_instance_name ()
69+ self ._presence_id_gen : AbstractStreamIdGenerator
70+
6071 self ._can_persist_presence = (
61- hs . get_instance_name () in hs .config .worker .writers .presence
72+ self . _instance_name in hs .config .worker .writers .presence
6273 )
6374
6475 if isinstance (database .engine , PostgresEngine ):
@@ -109,7 +120,9 @@ async def update_presence(self, presence_states) -> Tuple[int, int]:
109120
110121 return stream_orderings [- 1 ], self ._presence_id_gen .get_current_token ()
111122
112- def _update_presence_txn (self , txn , stream_orderings , presence_states ):
123+ def _update_presence_txn (
124+ self , txn : LoggingTransaction , stream_orderings , presence_states
125+ ) -> None :
113126 for stream_id , state in zip (stream_orderings , presence_states ):
114127 txn .call_after (
115128 self .presence_stream_cache .entity_has_changed , state .user_id , stream_id
@@ -183,19 +196,23 @@ async def get_all_presence_updates(
183196 if last_id == current_id :
184197 return [], current_id , False
185198
186- def get_all_presence_updates_txn (txn ):
199+ def get_all_presence_updates_txn (
200+ txn : LoggingTransaction ,
201+ ) -> Tuple [List [Tuple [int , list ]], int , bool ]:
187202 sql = """
188203 SELECT stream_id, user_id, state, last_active_ts,
189204 last_federation_update_ts, last_user_sync_ts,
190- status_msg,
191- currently_active
205+ status_msg, currently_active
192206 FROM presence_stream
193207 WHERE ? < stream_id AND stream_id <= ?
194208 ORDER BY stream_id ASC
195209 LIMIT ?
196210 """
197211 txn .execute (sql , (last_id , current_id , limit ))
198- updates = [(row [0 ], row [1 :]) for row in txn ]
212+ updates = cast (
213+ List [Tuple [int , list ]],
214+ [(row [0 ], row [1 :]) for row in txn ],
215+ )
199216
200217 upper_bound = current_id
201218 limited = False
@@ -210,15 +227,17 @@ def get_all_presence_updates_txn(txn):
210227 )
211228
212229 @cached ()
213- def _get_presence_for_user (self , user_id ) :
230+ def _get_presence_for_user (self , user_id : str ) -> None :
214231 raise NotImplementedError ()
215232
216233 @cachedList (
217234 cached_method_name = "_get_presence_for_user" ,
218235 list_name = "user_ids" ,
219236 num_args = 1 ,
220237 )
221- async def get_presence_for_users (self , user_ids ):
238+ async def get_presence_for_users (
239+ self , user_ids : Iterable [str ]
240+ ) -> Dict [str , UserPresenceState ]:
222241 rows = await self .db_pool .simple_select_many_batch (
223242 table = "presence_stream" ,
224243 column = "user_id" ,
@@ -257,7 +276,9 @@ async def should_user_receive_full_presence_with_token(
257276 True if the user should have full presence sent to them, False otherwise.
258277 """
259278
260- def _should_user_receive_full_presence_with_token_txn (txn ):
279+ def _should_user_receive_full_presence_with_token_txn (
280+ txn : LoggingTransaction ,
281+ ) -> bool :
261282 sql = """
262283 SELECT 1 FROM users_to_send_full_presence_to
263284 WHERE user_id = ?
@@ -271,7 +292,7 @@ def _should_user_receive_full_presence_with_token_txn(txn):
271292 _should_user_receive_full_presence_with_token_txn ,
272293 )
273294
274- async def add_users_to_send_full_presence_to (self , user_ids : Iterable [str ]):
295+ async def add_users_to_send_full_presence_to (self , user_ids : Iterable [str ]) -> None :
275296 """Adds to the list of users who should receive a full snapshot of presence
276297 upon their next sync.
277298
@@ -353,10 +374,10 @@ async def get_presence_for_all_users(
353374
354375 return users_to_state
355376
356- def get_current_presence_token (self ):
377+ def get_current_presence_token (self ) -> int :
357378 return self ._presence_id_gen .get_current_token ()
358379
359- def _get_active_presence (self , db_conn : Connection ):
380+ def _get_active_presence (self , db_conn : Connection ) -> List [ UserPresenceState ] :
360381 """Fetch non-offline presence from the database so that we can register
361382 the appropriate time outs.
362383 """
@@ -379,12 +400,12 @@ def _get_active_presence(self, db_conn: Connection):
379400
380401 return [UserPresenceState (** row ) for row in rows ]
381402
382- def take_presence_startup_info (self ):
403+ def take_presence_startup_info (self ) -> List [ UserPresenceState ] :
383404 active_on_startup = self ._presence_on_startup
384- self ._presence_on_startup = None
405+ self ._presence_on_startup = []
385406 return active_on_startup
386407
387- def process_replication_rows (self , stream_name , instance_name , token , rows ):
408+ def process_replication_rows (self , stream_name , instance_name , token , rows ) -> None :
388409 if stream_name == PresenceStream .NAME :
389410 self ._presence_id_gen .advance (instance_name , token )
390411 for row in rows :
0 commit comments