Skip to content

Commit d790d0d

Browse files
authored
Add type hints to user admin API. (matrix-org#9521)
1 parent 0c33042 commit d790d0d

File tree

4 files changed

+63
-35
lines changed

4 files changed

+63
-35
lines changed

changelog.d/9521.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add type hints to user admin API.

synapse/rest/admin/users.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import hmac
1717
import logging
1818
from http import HTTPStatus
19-
from typing import TYPE_CHECKING, Tuple
19+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
2020

2121
from synapse.api.constants import UserTypes
2222
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -47,13 +47,15 @@
4747
class UsersRestServlet(RestServlet):
4848
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
4949

50-
def __init__(self, hs):
50+
def __init__(self, hs: "HomeServer"):
5151
self.hs = hs
5252
self.store = hs.get_datastore()
5353
self.auth = hs.get_auth()
5454
self.admin_handler = hs.get_admin_handler()
5555

56-
async def on_GET(self, request, user_id):
56+
async def on_GET(
57+
self, request: SynapseRequest, user_id: str
58+
) -> Tuple[int, List[JsonDict]]:
5759
target_user = UserID.from_string(user_id)
5860
await assert_requester_is_admin(self.auth, request)
5961

@@ -153,7 +155,7 @@ class UserRestServletV2(RestServlet):
153155
otherwise an error.
154156
"""
155157

156-
def __init__(self, hs):
158+
def __init__(self, hs: "HomeServer"):
157159
self.hs = hs
158160
self.auth = hs.get_auth()
159161
self.admin_handler = hs.get_admin_handler()
@@ -165,7 +167,9 @@ def __init__(self, hs):
165167
self.registration_handler = hs.get_registration_handler()
166168
self.pusher_pool = hs.get_pusherpool()
167169

168-
async def on_GET(self, request, user_id):
170+
async def on_GET(
171+
self, request: SynapseRequest, user_id: str
172+
) -> Tuple[int, JsonDict]:
169173
await assert_requester_is_admin(self.auth, request)
170174

171175
target_user = UserID.from_string(user_id)
@@ -179,7 +183,9 @@ async def on_GET(self, request, user_id):
179183

180184
return 200, ret
181185

182-
async def on_PUT(self, request, user_id):
186+
async def on_PUT(
187+
self, request: SynapseRequest, user_id: str
188+
) -> Tuple[int, JsonDict]:
183189
requester = await self.auth.get_user_by_req(request)
184190
await assert_user_is_admin(self.auth, requester.user)
185191

@@ -273,6 +279,8 @@ async def on_PUT(self, request, user_id):
273279
)
274280

275281
user = await self.admin_handler.get_user(target_user)
282+
assert user is not None
283+
276284
return 200, user
277285

278286
else: # create user
@@ -330,9 +338,10 @@ async def on_PUT(self, request, user_id):
330338
target_user, requester, body["avatar_url"], True
331339
)
332340

333-
ret = await self.admin_handler.get_user(target_user)
341+
user = await self.admin_handler.get_user(target_user)
342+
assert user is not None
334343

335-
return 201, ret
344+
return 201, user
336345

337346

338347
class UserRegisterServlet(RestServlet):
@@ -346,10 +355,10 @@ class UserRegisterServlet(RestServlet):
346355
PATTERNS = admin_patterns("/register")
347356
NONCE_TIMEOUT = 60
348357

349-
def __init__(self, hs):
358+
def __init__(self, hs: "HomeServer"):
350359
self.auth_handler = hs.get_auth_handler()
351360
self.reactor = hs.get_reactor()
352-
self.nonces = {}
361+
self.nonces = {} # type: Dict[str, int]
353362
self.hs = hs
354363

355364
def _clear_old_nonces(self):
@@ -362,7 +371,7 @@ def _clear_old_nonces(self):
362371
if now - v > self.NONCE_TIMEOUT:
363372
del self.nonces[k]
364373

365-
def on_GET(self, request):
374+
def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
366375
"""
367376
Generate a new nonce.
368377
"""
@@ -372,7 +381,7 @@ def on_GET(self, request):
372381
self.nonces[nonce] = int(self.reactor.seconds())
373382
return 200, {"nonce": nonce}
374383

375-
async def on_POST(self, request):
384+
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
376385
self._clear_old_nonces()
377386

378387
if not self.hs.config.registration_shared_secret:
@@ -478,12 +487,14 @@ class WhoisRestServlet(RestServlet):
478487
client_patterns("/admin" + path_regex, v1=True)
479488
)
480489

481-
def __init__(self, hs):
490+
def __init__(self, hs: "HomeServer"):
482491
self.hs = hs
483492
self.auth = hs.get_auth()
484493
self.admin_handler = hs.get_admin_handler()
485494

486-
async def on_GET(self, request, user_id):
495+
async def on_GET(
496+
self, request: SynapseRequest, user_id: str
497+
) -> Tuple[int, JsonDict]:
487498
target_user = UserID.from_string(user_id)
488499
requester = await self.auth.get_user_by_req(request)
489500
auth_user = requester.user
@@ -508,7 +519,9 @@ def __init__(self, hs: "HomeServer"):
508519
self.is_mine = hs.is_mine
509520
self.store = hs.get_datastore()
510521

511-
async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
522+
async def on_POST(
523+
self, request: SynapseRequest, target_user_id: str
524+
) -> Tuple[int, JsonDict]:
512525
requester = await self.auth.get_user_by_req(request)
513526
await assert_user_is_admin(self.auth, requester.user)
514527

@@ -550,7 +563,7 @@ def __init__(self, hs):
550563
self.account_activity_handler = hs.get_account_validity_handler()
551564
self.auth = hs.get_auth()
552565

553-
async def on_POST(self, request):
566+
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
554567
await assert_requester_is_admin(self.auth, request)
555568

556569
body = parse_json_object_from_request(request)
@@ -584,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet):
584597

585598
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
586599

587-
def __init__(self, hs):
600+
def __init__(self, hs: "HomeServer"):
588601
self.store = hs.get_datastore()
589602
self.hs = hs
590603
self.auth = hs.get_auth()
591604
self.auth_handler = hs.get_auth_handler()
592605
self._set_password_handler = hs.get_set_password_handler()
593606

594-
async def on_POST(self, request, target_user_id):
607+
async def on_POST(
608+
self, request: SynapseRequest, target_user_id: str
609+
) -> Tuple[int, JsonDict]:
595610
"""Post request to allow an administrator reset password for a user.
596611
This needs user to have administrator access in Synapse.
597612
"""
@@ -626,12 +641,14 @@ class SearchUsersRestServlet(RestServlet):
626641

627642
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
628643

629-
def __init__(self, hs):
644+
def __init__(self, hs: "HomeServer"):
630645
self.hs = hs
631646
self.store = hs.get_datastore()
632647
self.auth = hs.get_auth()
633648

634-
async def on_GET(self, request, target_user_id):
649+
async def on_GET(
650+
self, request: SynapseRequest, target_user_id: str
651+
) -> Tuple[int, Optional[List[JsonDict]]]:
635652
"""Get request to search user table for specific users according to
636653
search term.
637654
This needs user to have a administrator access in Synapse.
@@ -682,12 +699,14 @@ class UserAdminServlet(RestServlet):
682699

683700
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
684701

685-
def __init__(self, hs):
702+
def __init__(self, hs: "HomeServer"):
686703
self.hs = hs
687704
self.store = hs.get_datastore()
688705
self.auth = hs.get_auth()
689706

690-
async def on_GET(self, request, user_id):
707+
async def on_GET(
708+
self, request: SynapseRequest, user_id: str
709+
) -> Tuple[int, JsonDict]:
691710
await assert_requester_is_admin(self.auth, request)
692711

693712
target_user = UserID.from_string(user_id)
@@ -699,7 +718,9 @@ async def on_GET(self, request, user_id):
699718

700719
return 200, {"admin": is_admin}
701720

702-
async def on_PUT(self, request, user_id):
721+
async def on_PUT(
722+
self, request: SynapseRequest, user_id: str
723+
) -> Tuple[int, JsonDict]:
703724
requester = await self.auth.get_user_by_req(request)
704725
await assert_user_is_admin(self.auth, requester.user)
705726
auth_user = requester.user
@@ -730,12 +751,14 @@ class UserMembershipRestServlet(RestServlet):
730751

731752
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
732753

733-
def __init__(self, hs):
754+
def __init__(self, hs: "HomeServer"):
734755
self.is_mine = hs.is_mine
735756
self.auth = hs.get_auth()
736757
self.store = hs.get_datastore()
737758

738-
async def on_GET(self, request, user_id):
759+
async def on_GET(
760+
self, request: SynapseRequest, user_id: str
761+
) -> Tuple[int, JsonDict]:
739762
await assert_requester_is_admin(self.auth, request)
740763

741764
room_ids = await self.store.get_rooms_for_user(user_id)
@@ -758,7 +781,7 @@ class PushersRestServlet(RestServlet):
758781

759782
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
760783

761-
def __init__(self, hs):
784+
def __init__(self, hs: "HomeServer"):
762785
self.is_mine = hs.is_mine
763786
self.store = hs.get_datastore()
764787
self.auth = hs.get_auth()
@@ -799,7 +822,7 @@ class UserMediaRestServlet(RestServlet):
799822

800823
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
801824

802-
def __init__(self, hs):
825+
def __init__(self, hs: "HomeServer"):
803826
self.is_mine = hs.is_mine
804827
self.auth = hs.get_auth()
805828
self.store = hs.get_datastore()
@@ -891,7 +914,9 @@ def __init__(self, hs: "HomeServer"):
891914
self.auth = hs.get_auth()
892915
self.auth_handler = hs.get_auth_handler()
893916

894-
async def on_POST(self, request, user_id):
917+
async def on_POST(
918+
self, request: SynapseRequest, user_id: str
919+
) -> Tuple[int, JsonDict]:
895920
requester = await self.auth.get_user_by_req(request)
896921
await assert_user_is_admin(self.auth, requester.user)
897922
auth_user = requester.user
@@ -943,7 +968,9 @@ def __init__(self, hs: "HomeServer"):
943968
self.store = hs.get_datastore()
944969
self.auth = hs.get_auth()
945970

946-
async def on_POST(self, request, user_id):
971+
async def on_POST(
972+
self, request: SynapseRequest, user_id: str
973+
) -> Tuple[int, JsonDict]:
947974
await assert_requester_is_admin(self.auth, request)
948975

949976
if not self.hs.is_mine_id(user_id):

synapse/storage/databases/main/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# limitations under the License.
1717

1818
import logging
19-
from typing import Any, Dict, List, Optional, Tuple
19+
from typing import List, Optional, Tuple
2020

2121
from synapse.api.constants import PresenceState
2222
from synapse.config.homeserver import HomeServerConfig
@@ -27,7 +27,7 @@
2727
MultiWriterIdGenerator,
2828
StreamIdGenerator,
2929
)
30-
from synapse.types import get_domain_from_id
30+
from synapse.types import JsonDict, get_domain_from_id
3131
from synapse.util.caches.stream_change_cache import StreamChangeCache
3232

3333
from .account_data import AccountDataStore
@@ -264,7 +264,7 @@ def _get_active_presence(self, db_conn):
264264

265265
return [UserPresenceState(**row) for row in rows]
266266

267-
async def get_users(self) -> List[Dict[str, Any]]:
267+
async def get_users(self) -> List[JsonDict]:
268268
"""Function to retrieve a list of users in users table.
269269
270270
Returns:
@@ -292,7 +292,7 @@ async def get_users_paginate(
292292
name: Optional[str] = None,
293293
guests: bool = True,
294294
deactivated: bool = False,
295-
) -> Tuple[List[Dict[str, Any]], int]:
295+
) -> Tuple[List[JsonDict], int]:
296296
"""Function to retrieve a paginated list of users from
297297
users list. This will return a json list of users and the
298298
total number of users matching the filter criteria.
@@ -353,7 +353,7 @@ def get_users_paginate_txn(txn):
353353
"get_users_paginate_txn", get_users_paginate_txn
354354
)
355355

356-
async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
356+
async def search_users(self, term: str) -> Optional[List[JsonDict]]:
357357
"""Function to search users list for one or more users with
358358
the matched term.
359359

synapse/storage/databases/main/media_repository.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ async def get_local_media_by_user_paginate(
139139
start: int,
140140
limit: int,
141141
user_id: str,
142-
order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value,
142+
order_by: str = MediaSortOrder.CREATED_TS.value,
143143
direction: str = "f",
144144
) -> Tuple[List[Dict[str, Any]], int]:
145145
"""Get a paginated list of metadata for a local piece of media

0 commit comments

Comments
 (0)