Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit e836279

Browse files
authored
Add type hints to auth and auth_blocking. (#9876)
1 parent a15c003 commit e836279

File tree

4 files changed

+48
-44
lines changed

4 files changed

+48
-44
lines changed

changelog.d/9876.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add type hints to `synapse.api.auth` and `synapse.api.auth_blocking` modules.

synapse/api/auth.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15-
from typing import List, Optional, Tuple
15+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
1616

1717
import pymacaroons
1818
from netaddr import IPAddress
1919

2020
from twisted.web.server import Request
2121

22-
import synapse.types
2322
from synapse import event_auth
2423
from synapse.api.auth_blocking import AuthBlocking
2524
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
@@ -36,11 +35,14 @@
3635
from synapse.http.site import SynapseRequest
3736
from synapse.logging import opentracing as opentracing
3837
from synapse.storage.databases.main.registration import TokenLookupResult
39-
from synapse.types import StateMap, UserID
38+
from synapse.types import Requester, StateMap, UserID, create_requester
4039
from synapse.util.caches.lrucache import LruCache
4140
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
4241
from synapse.util.metrics import Measure
4342

43+
if TYPE_CHECKING:
44+
from synapse.server import HomeServer
45+
4446
logger = logging.getLogger(__name__)
4547

4648

@@ -68,7 +70,7 @@ class Auth:
6870
The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
6971
"""
7072

71-
def __init__(self, hs):
73+
def __init__(self, hs: "HomeServer"):
7274
self.hs = hs
7375
self.clock = hs.get_clock()
7476
self.store = hs.get_datastore()
@@ -88,13 +90,13 @@ def __init__(self, hs):
8890

8991
async def check_from_context(
9092
self, room_version: str, event, context, do_sig_check=True
91-
):
93+
) -> None:
9294
prev_state_ids = await context.get_prev_state_ids()
9395
auth_events_ids = self.compute_auth_events(
9496
event, prev_state_ids, for_verification=True
9597
)
96-
auth_events = await self.store.get_events(auth_events_ids)
97-
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
98+
auth_events_by_id = await self.store.get_events(auth_events_ids)
99+
auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
98100

99101
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
100102
event_auth.check(
@@ -151,17 +153,11 @@ async def check_user_in_room(
151153

152154
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
153155

154-
async def check_host_in_room(self, room_id, host):
156+
async def check_host_in_room(self, room_id: str, host: str) -> bool:
155157
with Measure(self.clock, "check_host_in_room"):
156-
latest_event_ids = await self.store.is_host_joined(room_id, host)
157-
return latest_event_ids
158-
159-
def can_federate(self, event, auth_events):
160-
creation_event = auth_events.get((EventTypes.Create, ""))
158+
return await self.store.is_host_joined(room_id, host)
161159

162-
return creation_event.content.get("m.federate", True) is True
163-
164-
def get_public_keys(self, invite_event):
160+
def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
165161
return event_auth.get_public_keys(invite_event)
166162

167163
async def get_user_by_req(
@@ -170,7 +166,7 @@ async def get_user_by_req(
170166
allow_guest: bool = False,
171167
rights: str = "access",
172168
allow_expired: bool = False,
173-
) -> synapse.types.Requester:
169+
) -> Requester:
174170
"""Get a registered user's ID.
175171
176172
Args:
@@ -196,7 +192,7 @@ async def get_user_by_req(
196192
access_token = self.get_access_token_from_request(request)
197193

198194
user_id, app_service = await self._get_appservice_user_id(request)
199-
if user_id:
195+
if user_id and app_service:
200196
if ip_addr and self._track_appservice_user_ips:
201197
await self.store.insert_client_ip(
202198
user_id=user_id,
@@ -206,9 +202,7 @@ async def get_user_by_req(
206202
device_id="dummy-device", # stubbed
207203
)
208204

209-
requester = synapse.types.create_requester(
210-
user_id, app_service=app_service
211-
)
205+
requester = create_requester(user_id, app_service=app_service)
212206

213207
request.requester = user_id
214208
opentracing.set_tag("authenticated_entity", user_id)
@@ -251,7 +245,7 @@ async def get_user_by_req(
251245
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
252246
)
253247

254-
requester = synapse.types.create_requester(
248+
requester = create_requester(
255249
user_info.user_id,
256250
token_id,
257251
is_guest,
@@ -271,7 +265,9 @@ async def get_user_by_req(
271265
except KeyError:
272266
raise MissingClientTokenError()
273267

274-
async def _get_appservice_user_id(self, request):
268+
async def _get_appservice_user_id(
269+
self, request: Request
270+
) -> Tuple[Optional[str], Optional[ApplicationService]]:
275271
app_service = self.store.get_app_service_by_token(
276272
self.get_access_token_from_request(request)
277273
)
@@ -283,6 +279,9 @@ async def _get_appservice_user_id(self, request):
283279
if ip_address not in app_service.ip_range_whitelist:
284280
return None, None
285281

282+
# This will always be set by the time Twisted calls us.
283+
assert request.args is not None
284+
286285
if b"user_id" not in request.args:
287286
return app_service.sender, app_service
288287

@@ -387,7 +386,9 @@ async def get_user_by_access_token(
387386
logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
388387
raise InvalidClientTokenError("Invalid macaroon passed.")
389388

390-
def _parse_and_validate_macaroon(self, token, rights="access"):
389+
def _parse_and_validate_macaroon(
390+
self, token: str, rights: str = "access"
391+
) -> Tuple[str, bool]:
391392
"""Takes a macaroon and tries to parse and validate it. This is cached
392393
if and only if rights == access and there isn't an expiry.
393394
@@ -432,15 +433,16 @@ def _parse_and_validate_macaroon(self, token, rights="access"):
432433

433434
return user_id, guest
434435

435-
def validate_macaroon(self, macaroon, type_string, user_id):
436+
def validate_macaroon(
437+
self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str
438+
) -> None:
436439
"""
437440
validate that a Macaroon is understood by and was signed by this server.
438441
439442
Args:
440-
macaroon(pymacaroons.Macaroon): The macaroon to validate
441-
type_string(str): The kind of token required (e.g. "access",
442-
"delete_pusher")
443-
user_id (str): The user_id required
443+
macaroon: The macaroon to validate
444+
type_string: The kind of token required (e.g. "access", "delete_pusher")
445+
user_id: The user_id required
444446
"""
445447
v = pymacaroons.Verifier()
446448

@@ -465,9 +467,7 @@ def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
465467
if not service:
466468
logger.warning("Unrecognised appservice access token.")
467469
raise InvalidClientTokenError()
468-
request.requester = synapse.types.create_requester(
469-
service.sender, app_service=service
470-
)
470+
request.requester = create_requester(service.sender, app_service=service)
471471
return service
472472

473473
async def is_server_admin(self, user: UserID) -> bool:
@@ -519,7 +519,7 @@ def compute_auth_events(
519519

520520
return auth_ids
521521

522-
async def check_can_change_room_list(self, room_id: str, user: UserID):
522+
async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
523523
"""Determine whether the user is allowed to edit the room's entry in the
524524
published room list.
525525
@@ -554,11 +554,11 @@ async def check_can_change_room_list(self, room_id: str, user: UserID):
554554
return user_level >= send_level
555555

556556
@staticmethod
557-
def has_access_token(request: Request):
557+
def has_access_token(request: Request) -> bool:
558558
"""Checks if the request has an access_token.
559559
560560
Returns:
561-
bool: False if no access_token was given, True otherwise.
561+
False if no access_token was given, True otherwise.
562562
"""
563563
# This will always be set by the time Twisted calls us.
564564
assert request.args is not None
@@ -568,13 +568,13 @@ def has_access_token(request: Request):
568568
return bool(query_params) or bool(auth_headers)
569569

570570
@staticmethod
571-
def get_access_token_from_request(request: Request):
571+
def get_access_token_from_request(request: Request) -> str:
572572
"""Extracts the access_token from the request.
573573
574574
Args:
575575
request: The http request.
576576
Returns:
577-
unicode: The access_token
577+
The access_token
578578
Raises:
579579
MissingClientTokenError: If there isn't a single access_token in the
580580
request
@@ -649,5 +649,5 @@ async def check_user_in_room_or_world_readable(
649649
% (user_id, room_id),
650650
)
651651

652-
def check_auth_blocking(self, *args, **kwargs):
653-
return self._auth_blocking.check_auth_blocking(*args, **kwargs)
652+
async def check_auth_blocking(self, *args, **kwargs) -> None:
653+
await self._auth_blocking.check_auth_blocking(*args, **kwargs)

synapse/api/auth_blocking.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,21 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Optional
16+
from typing import TYPE_CHECKING, Optional
1717

1818
from synapse.api.constants import LimitBlockingTypes, UserTypes
1919
from synapse.api.errors import Codes, ResourceLimitError
2020
from synapse.config.server import is_threepid_reserved
2121
from synapse.types import Requester
2222

23+
if TYPE_CHECKING:
24+
from synapse.server import HomeServer
25+
2326
logger = logging.getLogger(__name__)
2427

2528

2629
class AuthBlocking:
27-
def __init__(self, hs):
30+
def __init__(self, hs: "HomeServer"):
2831
self.store = hs.get_datastore()
2932

3033
self._server_notices_mxid = hs.config.server_notices_mxid
@@ -43,7 +46,7 @@ async def check_auth_blocking(
4346
threepid: Optional[dict] = None,
4447
user_type: Optional[str] = None,
4548
requester: Optional[Requester] = None,
46-
):
49+
) -> None:
4750
"""Checks if the user should be rejected for some external reason,
4851
such as monthly active user limiting or global disable flag
4952

synapse/event_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616
import logging
17-
from typing import List, Optional, Set, Tuple
17+
from typing import Any, Dict, List, Optional, Set, Tuple
1818

1919
from canonicaljson import encode_canonical_json
2020
from signedjson.key import decode_verify_key_bytes
@@ -688,7 +688,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase
688688
return False
689689

690690

691-
def get_public_keys(invite_event):
691+
def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
692692
public_keys = []
693693
if "public_key" in invite_event.content:
694694
o = {"public_key": invite_event.content["public_key"]}

0 commit comments

Comments
 (0)