3737
3838import logging
3939import random
40- from typing import Tuple
40+ from typing import TYPE_CHECKING , Optional , Tuple
4141
4242from signedjson .sign import sign_json
4343
4444from synapse .api .errors import HttpResponseException , RequestSendFailed , SynapseError
4545from synapse .metrics .background_process_metrics import run_as_background_process
46- from synapse .types import get_domain_from_id
46+ from synapse .types import JsonDict , get_domain_from_id
47+
48+ if TYPE_CHECKING :
49+ from synapse .app .homeserver import HomeServer
4750
4851logger = logging .getLogger (__name__ )
4952
6366class GroupAttestationSigning :
6467 """Creates and verifies group attestations."""
6568
66- def __init__ (self , hs ):
69+ def __init__ (self , hs : "HomeServer" ):
6770 self .keyring = hs .get_keyring ()
6871 self .clock = hs .get_clock ()
6972 self .server_name = hs .hostname
7073 self .signing_key = hs .signing_key
7174
7275 async def verify_attestation (
73- self , attestation , group_id , user_id , server_name = None
74- ):
76+ self ,
77+ attestation : JsonDict ,
78+ group_id : str ,
79+ user_id : str ,
80+ server_name : Optional [str ] = None ,
81+ ) -> None :
7582 """Verifies that the given attestation matches the given parameters.
7683
7784 An optional server_name can be supplied to explicitly set which server's
@@ -100,16 +107,18 @@ async def verify_attestation(
100107 if valid_until_ms < now :
101108 raise SynapseError (400 , "Attestation expired" )
102109
110+ assert server_name is not None
103111 await self .keyring .verify_json_for_server (
104112 server_name , attestation , now , "Group attestation"
105113 )
106114
107- def create_attestation (self , group_id , user_id ) :
115+ def create_attestation (self , group_id : str , user_id : str ) -> JsonDict :
108116 """Create an attestation for the group_id and user_id with default
109117 validity length.
110118 """
111- validity_period = DEFAULT_ATTESTATION_LENGTH_MS
112- validity_period *= random .uniform (* DEFAULT_ATTESTATION_JITTER )
119+ validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random .uniform (
120+ * DEFAULT_ATTESTATION_JITTER
121+ )
113122 valid_until_ms = int (self .clock .time_msec () + validity_period )
114123
115124 return sign_json (
@@ -126,7 +135,7 @@ def create_attestation(self, group_id, user_id):
126135class GroupAttestionRenewer :
127136 """Responsible for sending and receiving attestation updates."""
128137
129- def __init__ (self , hs ):
138+ def __init__ (self , hs : "HomeServer" ):
130139 self .clock = hs .get_clock ()
131140 self .store = hs .get_datastore ()
132141 self .assestations = hs .get_groups_attestation_signing ()
@@ -139,7 +148,9 @@ def __init__(self, hs):
139148 self ._start_renew_attestations , 30 * 60 * 1000
140149 )
141150
142- async def on_renew_attestation (self , group_id , user_id , content ):
151+ async def on_renew_attestation (
152+ self , group_id : str , user_id : str , content : JsonDict
153+ ) -> JsonDict :
143154 """When a remote updates an attestation"""
144155 attestation = content ["attestation" ]
145156
@@ -154,10 +165,10 @@ async def on_renew_attestation(self, group_id, user_id, content):
154165
155166 return {}
156167
157- def _start_renew_attestations (self ):
168+ def _start_renew_attestations (self ) -> None :
158169 return run_as_background_process ("renew_attestations" , self ._renew_attestations )
159170
160- async def _renew_attestations (self ):
171+ async def _renew_attestations (self ) -> None :
161172 """Called periodically to check if we need to update any of our attestations"""
162173
163174 now = self .clock .time_msec ()
@@ -166,7 +177,7 @@ async def _renew_attestations(self):
166177 now + UPDATE_ATTESTATION_TIME_MS
167178 )
168179
169- async def _renew_attestation (group_user : Tuple [str , str ]):
180+ async def _renew_attestation (group_user : Tuple [str , str ]) -> None :
170181 group_id , user_id = group_user
171182 try :
172183 if not self .is_mine_id (group_id ):
0 commit comments