diff --git a/changelog.d/18520.misc b/changelog.d/18520.misc new file mode 100644 index 00000000000..d005d3b7f78 --- /dev/null +++ b/changelog.d/18520.misc @@ -0,0 +1 @@ +Dedicated internal API for Matrix Authentication Service to Synapse communication. diff --git a/synapse/_pydantic_compat.py b/synapse/_pydantic_compat.py index f0eedf5c6d1..e9b43aebe32 100644 --- a/synapse/_pydantic_compat.py +++ b/synapse/_pydantic_compat.py @@ -48,6 +48,7 @@ conint, constr, parse_obj_as, + root_validator, validator, ) from pydantic.v1.error_wrappers import ErrorWrapper @@ -68,6 +69,7 @@ conint, constr, parse_obj_as, + root_validator, validator, ) from pydantic.error_wrappers import ErrorWrapper @@ -92,4 +94,5 @@ "StrictStr", "ValidationError", "validator", + "root_validator", ) diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py index a584ef9ab37..567f2e834ce 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py @@ -369,6 +369,12 @@ async def _introspect_token( async def is_server_admin(self, requester: Requester) -> bool: return "urn:synapse:admin:*" in requester.scope + def _is_access_token_the_admin_token(self, token: str) -> bool: + admin_token = self._admin_token() + if admin_token is None: + return False + return token == admin_token + async def get_user_by_req( self, request: SynapseRequest, @@ -434,7 +440,7 @@ async def _wrapped_get_user_by_req( requester = await self.get_user_by_access_token(access_token, allow_expired) # Do not record requests from MAS using the virtual `__oidc_admin` user. - if access_token != self._admin_token(): + if not self._is_access_token_the_admin_token(access_token): await self._record_request(request, requester) if not allow_guest and requester.is_guest: @@ -470,13 +476,25 @@ async def get_user_by_req_experimental_feature( raise UnrecognizedRequestError(code=404) + def is_request_using_the_admin_token(self, request: SynapseRequest) -> bool: + """ + Check if the request is using the admin token. + + Args: + request: The request to check. + + Returns: + True if the request is using the admin token, False otherwise. + """ + access_token = self.get_access_token_from_request(request) + return self._is_access_token_the_admin_token(access_token) + async def get_user_by_access_token( self, token: str, allow_expired: bool = False, ) -> Requester: - admin_token = self._admin_token() - if admin_token is not None and token == admin_token: + if self._is_access_token_the_admin_token(token): # XXX: This is a temporary solution so that the admin API can be called by # the OIDC provider. This will be removed once we have OIDC client # credentials grant support in matrix-authentication-service. diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 7b5bfc0421e..043c5083799 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -30,6 +30,7 @@ from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource from synapse.rest.synapse.client.sso_register import SsoRegisterResource from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource +from synapse.rest.synapse.mas import MasResource if TYPE_CHECKING: from synapse.server import HomeServer @@ -60,6 +61,7 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc from synapse.rest.synapse.client.jwks import JwksResource resources["/_synapse/jwks"] = JwksResource(hs) + resources["/_synapse/mas"] = MasResource(hs) # provider-specific SSO bits. Only load these if they are enabled, since they # rely on optional dependencies. diff --git a/synapse/rest/synapse/mas/__init__.py b/synapse/rest/synapse/mas/__init__.py new file mode 100644 index 00000000000..8115c563d28 --- /dev/null +++ b/synapse/rest/synapse/mas/__init__.py @@ -0,0 +1,71 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# + + +import logging +from typing import TYPE_CHECKING + +from twisted.web.resource import Resource + +from synapse.rest.synapse.mas.devices import ( + MasDeleteDeviceResource, + MasSyncDevicesResource, + MasUpdateDeviceDisplayNameResource, + MasUpsertDeviceResource, +) +from synapse.rest.synapse.mas.users import ( + MasAllowCrossSigningResetResource, + MasDeleteUserResource, + MasIsLocalpartAvailableResource, + MasProvisionUserResource, + MasQueryUserResource, + MasReactivateUserResource, + MasSetDisplayNameResource, + MasUnsetDisplayNameResource, +) + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + + +class MasResource(Resource): + """ + Provides endpoints for MAS to manage user accounts and devices. + + All endpoints are mounted under the path `/_synapse/mas/` and only work + using the MAS admin token. + """ + + def __init__(self, hs: "HomeServer"): + Resource.__init__(self) + self.putChild(b"query_user", MasQueryUserResource(hs)) + self.putChild(b"provision_user", MasProvisionUserResource(hs)) + self.putChild(b"is_localpart_available", MasIsLocalpartAvailableResource(hs)) + self.putChild(b"delete_user", MasDeleteUserResource(hs)) + self.putChild(b"upsert_device", MasUpsertDeviceResource(hs)) + self.putChild(b"delete_device", MasDeleteDeviceResource(hs)) + self.putChild( + b"update_device_display_name", MasUpdateDeviceDisplayNameResource(hs) + ) + self.putChild(b"sync_devices", MasSyncDevicesResource(hs)) + self.putChild(b"reactivate_user", MasReactivateUserResource(hs)) + self.putChild(b"set_displayname", MasSetDisplayNameResource(hs)) + self.putChild(b"unset_displayname", MasUnsetDisplayNameResource(hs)) + self.putChild( + b"allow_cross_signing_reset", MasAllowCrossSigningResetResource(hs) + ) diff --git a/synapse/rest/synapse/mas/_base.py b/synapse/rest/synapse/mas/_base.py new file mode 100644 index 00000000000..caf392fc3ae --- /dev/null +++ b/synapse/rest/synapse/mas/_base.py @@ -0,0 +1,47 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# + + +from typing import TYPE_CHECKING, cast + +from synapse.api.errors import SynapseError +from synapse.http.server import DirectServeJsonResource + +if TYPE_CHECKING: + from synapse.app.generic_worker import GenericWorkerStore + from synapse.http.site import SynapseRequest + from synapse.server import HomeServer + + +class MasBaseResource(DirectServeJsonResource): + def __init__(self, hs: "HomeServer"): + # Importing this module requires authlib, which is an optional + # dependency but required if msc3861 is enabled + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + DirectServeJsonResource.__init__(self, extract_context=True) + auth = hs.get_auth() + assert isinstance(auth, MSC3861DelegatedAuth) + self.msc3861_auth = auth + self.store = cast("GenericWorkerStore", hs.get_datastores().main) + self.hostname = hs.hostname + + def assert_request_is_from_mas(self, request: "SynapseRequest") -> None: + """Assert that the request is coming from MAS itself, not a regular user. + + Throws a 403 if the request is not coming from MAS. + """ + if not self.msc3861_auth.is_request_using_the_admin_token(request): + raise SynapseError(403, "This endpoint must only be called by MAS") diff --git a/synapse/rest/synapse/mas/devices.py b/synapse/rest/synapse/mas/devices.py new file mode 100644 index 00000000000..6cc11535906 --- /dev/null +++ b/synapse/rest/synapse/mas/devices.py @@ -0,0 +1,238 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# + +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING, Optional, Tuple + +from synapse._pydantic_compat import StrictStr +from synapse.api.errors import NotFoundError +from synapse.http.servlet import parse_and_validate_json_object_from_request +from synapse.types import JsonDict, UserID +from synapse.types.rest import RequestBodyModel + +if TYPE_CHECKING: + from synapse.http.site import SynapseRequest + from synapse.server import HomeServer + + +from ._base import MasBaseResource + +logger = logging.getLogger(__name__) + + +class MasUpsertDeviceResource(MasBaseResource): + """ + Endpoint for MAS to create or update user devices. + + Takes a localpart, device ID, and optional display name to create new devices + or update existing ones. + + POST /_synapse/mas/upsert_device + {"localpart": "alice", "device_id": "DEVICE123", "display_name": "Alice's Phone"} + """ + + def __init__(self, hs: "HomeServer"): + MasBaseResource.__init__(self, hs) + + self.device_handler = hs.get_device_handler() + + class PostBody(RequestBodyModel): + localpart: StrictStr + device_id: StrictStr + display_name: Optional[StrictStr] + + async def _async_render_POST( + self, request: "SynapseRequest" + ) -> Tuple[int, JsonDict]: + self.assert_request_is_from_mas(request) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + user_id = UserID(body.localpart, self.hostname) + + # Check the user exists + user = await self.store.get_user_by_id(user_id=str(user_id)) + if user is None: + raise NotFoundError("User not found") + + inserted = await self.device_handler.upsert_device( + user_id=str(user_id), + device_id=body.device_id, + display_name=body.display_name, + ) + + return HTTPStatus.CREATED if inserted else HTTPStatus.OK, {} + + +class MasDeleteDeviceResource(MasBaseResource): + """ + Endpoint for MAS to delete user devices. + + Takes a localpart and device ID to remove the specified device from the user's account. + + POST /_synapse/mas/delete_device + {"localpart": "alice", "device_id": "DEVICE123"} + """ + + def __init__(self, hs: "HomeServer"): + MasBaseResource.__init__(self, hs) + + self.device_handler = hs.get_device_handler() + + class PostBody(RequestBodyModel): + localpart: StrictStr + device_id: StrictStr + + async def _async_render_POST( + self, request: "SynapseRequest" + ) -> Tuple[int, JsonDict]: + self.assert_request_is_from_mas(request) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + user_id = UserID(body.localpart, self.hostname) + + # Check the user exists + user = await self.store.get_user_by_id(user_id=str(user_id)) + if user is None: + raise NotFoundError("User not found") + + await self.device_handler.delete_devices( + user_id=str(user_id), + device_ids=[body.device_id], + ) + + return HTTPStatus.NO_CONTENT, {} + + +class MasUpdateDeviceDisplayNameResource(MasBaseResource): + """ + Endpoint for MAS to update a device's display name. + + Takes a localpart, device ID, and new display name to update the device's name. + + POST /_synapse/mas/update_device_display_name + {"localpart": "alice", "device_id": "DEVICE123", "display_name": "Alice's New Phone"} + """ + + def __init__(self, hs: "HomeServer"): + MasBaseResource.__init__(self, hs) + + self.device_handler = hs.get_device_handler() + + class PostBody(RequestBodyModel): + localpart: StrictStr + device_id: StrictStr + display_name: StrictStr + + async def _async_render_POST( + self, request: "SynapseRequest" + ) -> Tuple[int, JsonDict]: + self.assert_request_is_from_mas(request) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + user_id = UserID(body.localpart, self.hostname) + + # Check the user exists + user = await self.store.get_user_by_id(user_id=str(user_id)) + if user is None: + raise NotFoundError("User not found") + + await self.device_handler.update_device( + user_id=str(user_id), + device_id=body.device_id, + content={"display_name": body.display_name}, + ) + + return HTTPStatus.OK, {} + + +class MasSyncDevicesResource(MasBaseResource): + """ + Endpoint for MAS to synchronize a user's complete device list. + + Takes a localpart and a set of device IDs to ensure the user's device list + matches the provided set by adding missing devices and removing extra ones. + + POST /_synapse/mas/sync_devices + {"localpart": "alice", "devices": ["DEVICE123", "DEVICE456"]} + """ + + def __init__(self, hs: "HomeServer"): + MasBaseResource.__init__(self, hs) + + self.device_handler = hs.get_device_handler() + + class PostBody(RequestBodyModel): + localpart: StrictStr + devices: set[StrictStr] + + async def _async_render_POST( + self, request: "SynapseRequest" + ) -> Tuple[int, JsonDict]: + self.assert_request_is_from_mas(request) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + user_id = UserID(body.localpart, self.hostname) + + # Check the user exists + user = await self.store.get_user_by_id(user_id=str(user_id)) + if user is None: + raise NotFoundError("User not found") + + current_devices = await self.store.get_devices_by_user(user_id=str(user_id)) + current_devices_list = set(current_devices.keys()) + target_device_list = set(body.devices) + + to_add = target_device_list - current_devices_list + to_delete = current_devices_list - target_device_list + + # Log what we're about to do to make it easier to debug if it stops + # mid-way, as this can be a long operation if there are a lot of devices + # to delete or to add. + if to_add and to_delete: + logger.info( + "Syncing %d devices for user %s will add %d devices and delete %d devices", + len(target_device_list), + user_id, + len(to_add), + len(to_delete), + ) + elif to_add: + logger.info( + "Syncing %d devices for user %s will add %d devices", + len(target_device_list), + user_id, + len(to_add), + ) + elif to_delete: + logger.info( + "Syncing %d devices for user %s will delete %d devices", + len(target_device_list), + user_id, + len(to_delete), + ) + + if to_delete: + await self.device_handler.delete_devices( + user_id=str(user_id), device_ids=to_delete + ) + + for device_id in to_add: + await self.device_handler.upsert_device( + user_id=str(user_id), + device_id=device_id, + ) + + return 200, {} diff --git a/synapse/rest/synapse/mas/users.py b/synapse/rest/synapse/mas/users.py new file mode 100644 index 00000000000..09aa13bebbc --- /dev/null +++ b/synapse/rest/synapse/mas/users.py @@ -0,0 +1,467 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# + +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING, Any, Optional, Tuple, TypedDict + +from synapse._pydantic_compat import StrictBool, StrictStr, root_validator +from synapse.api.errors import NotFoundError, SynapseError +from synapse.http.servlet import ( + parse_and_validate_json_object_from_request, + parse_string, +) +from synapse.types import JsonDict, UserID, UserInfo, create_requester +from synapse.types.rest import RequestBodyModel + +if TYPE_CHECKING: + from synapse.http.site import SynapseRequest + from synapse.server import HomeServer + + +from ._base import MasBaseResource + +logger = logging.getLogger(__name__) + + +class MasQueryUserResource(MasBaseResource): + """ + Endpoint for MAS to query user information by localpart. + + Takes a localpart parameter and returns user profile data including display name, + avatar URL, and account status (suspended/deactivated). + + GET /_synapse/mas/query_user?localpart=alice + """ + + def __init__(self, hs: "HomeServer"): + MasBaseResource.__init__(self, hs) + + class Response(TypedDict): + user_id: str + display_name: Optional[str] + avatar_url: Optional[str] + is_suspended: bool + is_deactivated: bool + + async def _async_render_GET( + self, request: "SynapseRequest" + ) -> Tuple[int, Response]: + self.assert_request_is_from_mas(request) + + localpart = parse_string(request, "localpart", required=True) + user_id = UserID(localpart, self.hostname) + + user: Optional[UserInfo] = await self.store.get_user_by_id(user_id=str(user_id)) + if user is None: + raise NotFoundError("User not found") + + profile = await self.store.get_profileinfo(user_id=user_id) + + return HTTPStatus.OK, self.Response( + user_id=user_id.to_string(), + display_name=profile.display_name, + avatar_url=profile.avatar_url, + is_suspended=user.suspended, + is_deactivated=user.is_deactivated, + ) + + +class MasProvisionUserResource(MasBaseResource): + """ + Endpoint for MAS to create or update user accounts and their profile data. + + Takes a localpart and optional profile fields (display name, avatar URL, email addresses). + Can create new users or update existing ones by setting or unsetting profile fields. + + POST /_synapse/mas/provision_user + {"localpart": "alice", "set_displayname": "Alice", "set_emails": ["alice@example.com"]} + """ + + def __init__(self, hs: "HomeServer"): + MasBaseResource.__init__(self, hs) + self.registration_handler = hs.get_registration_handler() + self.identity_handler = hs.get_identity_handler() + self.auth_handler = hs.get_auth_handler() + self.profile_handler = hs.get_profile_handler() + self.clock = hs.get_clock() + self.auth = hs.get_auth() + + class PostBody(RequestBodyModel): + localpart: StrictStr + + unset_displayname: StrictBool = False + set_displayname: Optional[StrictStr] = None + + unset_avatar_url: StrictBool = False + set_avatar_url: Optional[StrictStr] = None + + unset_emails: StrictBool = False + set_emails: Optional[list[StrictStr]] = None + + @root_validator(pre=True) + def validate_exclusive(cls, values: Any) -> Any: + if "unset_displayname" in values and "set_displayname" in values: + raise ValueError( + "Cannot specify both unset_displayname and set_displayname" + ) + if "unset_avatar_url" in values and "set_avatar_url" in values: + raise ValueError( + "Cannot specify both unset_avatar_url and set_avatar_url" + ) + if "unset_emails" in values and "set_emails" in values: + raise ValueError("Cannot specify both unset_emails and set_emails") + + return values + + async def _async_render_POST( + self, request: "SynapseRequest" + ) -> Tuple[int, JsonDict]: + self.assert_request_is_from_mas(request) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + + localpart = body.localpart + user_id = UserID(localpart, self.hostname) + + requester = create_requester(user_id=user_id) + existing_user = await self.store.get_user_by_id(user_id=str(user_id)) + if existing_user is None: + created = True + await self.registration_handler.register_user( + localpart=localpart, + default_display_name=body.set_displayname, + bind_emails=body.set_emails, + by_admin=True, + ) + else: + created = False + if body.unset_displayname: + await self.profile_handler.set_displayname( + target_user=user_id, + requester=requester, + new_displayname="", + by_admin=True, + ) + elif body.set_displayname is not None: + await self.profile_handler.set_displayname( + target_user=user_id, + requester=requester, + new_displayname=body.set_displayname, + by_admin=True, + ) + + new_email_list: Optional[set[str]] = None + if body.unset_emails: + new_email_list = set() + elif body.set_emails is not None: + new_email_list = set(body.set_emails) + + if new_email_list is not None: + medium = "email" + current_threepid_list = await self.store.user_get_threepids( + user_id=user_id.to_string() + ) + current_email_list = { + t.address for t in current_threepid_list if t.medium == medium + } + + to_delete = current_email_list - new_email_list + to_add = new_email_list - current_email_list + + for address in to_delete: + await self.identity_handler.try_unbind_threepid( + mxid=user_id.to_string(), + medium=medium, + address=address, + id_server=None, + ) + + await self.auth_handler.delete_local_threepid( + user_id=user_id.to_string(), + medium=medium, + address=address, + ) + + current_time = self.clock.time_msec() + for address in to_add: + await self.auth_handler.add_threepid( + user_id=user_id.to_string(), + medium=medium, + address=address, + validated_at=current_time, + ) + + if body.unset_avatar_url: + await self.profile_handler.set_avatar_url( + target_user=user_id, + requester=requester, + new_avatar_url="", + by_admin=True, + ) + elif body.set_avatar_url is not None: + await self.profile_handler.set_avatar_url( + target_user=user_id, + requester=requester, + new_avatar_url=body.set_avatar_url, + by_admin=True, + ) + + return HTTPStatus.CREATED if created else HTTPStatus.OK, {} + + +class MasIsLocalpartAvailableResource(MasBaseResource): + """ + Endpoint for MAS to check if a localpart is available for user registration. + + Takes a localpart parameter and validates its format and availability, + checking for conflicts with existing users or application service namespaces. + + GET /_synapse/mas/is_localpart_available?localpart=alice + """ + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.registration_handler = hs.get_registration_handler() + + async def _async_render_GET( + self, request: "SynapseRequest" + ) -> Tuple[int, JsonDict]: + self.assert_request_is_from_mas(request) + localpart = parse_string(request, "localpart") + if localpart is None: + raise SynapseError(400, "Missing localpart") + + await self.registration_handler.check_username(localpart) + + return HTTPStatus.OK, {} + + +class MasDeleteUserResource(MasBaseResource): + """ + Endpoint for MAS to delete/deactivate user accounts. + + Takes a localpart and an erase flag to determine whether to deactivate + the account and optionally erase user data for compliance purposes. + + POST /_synapse/mas/delete_user + {"localpart": "alice", "erase": true} + """ + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.deactivate_account_handler = hs.get_deactivate_account_handler() + + class PostBody(RequestBodyModel): + localpart: StrictStr + erase: StrictBool + + async def _async_render_POST( + self, request: "SynapseRequest" + ) -> Tuple[int, JsonDict]: + self.assert_request_is_from_mas(request) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + user_id = UserID(body.localpart, self.hostname) + + # Check the user exists + user = await self.store.get_user_by_id(user_id=str(user_id)) + if user is None: + raise NotFoundError("User not found") + + await self.deactivate_account_handler.deactivate_account( + user_id=user_id.to_string(), + erase_data=body.erase, + requester=create_requester(user_id=user_id), + ) + + return HTTPStatus.OK, {} + + +class MasReactivateUserResource(MasBaseResource): + """ + Endpoint for MAS to reactivate previously deactivated user accounts. + + Takes a localpart parameter to restore access to deactivated accounts. + + POST /_synapse/mas/reactivate_user + {"localpart": "alice"} + """ + + def __init__(self, hs: "HomeServer"): + MasBaseResource.__init__(self, hs) + + self.deactivate_account_handler = hs.get_deactivate_account_handler() + + class PostBody(RequestBodyModel): + localpart: StrictStr + + async def _async_render_POST( + self, request: "SynapseRequest" + ) -> Tuple[int, JsonDict]: + self.assert_request_is_from_mas(request) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + user_id = UserID(body.localpart, self.hostname) + + # Check the user exists + user = await self.store.get_user_by_id(user_id=str(user_id)) + if user is None: + raise NotFoundError("User not found") + + await self.deactivate_account_handler.activate_account(user_id=str(user_id)) + + return HTTPStatus.OK, {} + + +class MasSetDisplayNameResource(MasBaseResource): + """ + Endpoint for MAS to set a user's display name. + + Takes a localpart and display name to update the user's profile. + + POST /_synapse/mas/set_displayname + {"localpart": "alice", "displayname": "Alice"} + """ + + def __init__(self, hs: "HomeServer"): + MasBaseResource.__init__(self, hs) + + self.profile_handler = hs.get_profile_handler() + self.auth_handler = hs.get_auth_handler() + + class PostBody(RequestBodyModel): + localpart: StrictStr + displayname: StrictStr + + async def _async_render_POST( + self, request: "SynapseRequest" + ) -> Tuple[int, JsonDict]: + self.assert_request_is_from_mas(request) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + user_id = UserID(body.localpart, self.hostname) + + # Check the user exists + user = await self.store.get_user_by_id(user_id=str(user_id)) + if user is None: + raise NotFoundError("User not found") + + requester = create_requester(user_id=user_id) + + await self.profile_handler.set_displayname( + target_user=requester.user, + requester=requester, + new_displayname=body.displayname, + by_admin=True, + ) + + return HTTPStatus.OK, {} + + +class MasUnsetDisplayNameResource(MasBaseResource): + """ + Endpoint for MAS to clear a user's display name. + + Takes a localpart parameter to remove the display name for the specified user. + + POST /_synapse/mas/unset_displayname + {"localpart": "alice"} + """ + + def __init__(self, hs: "HomeServer"): + MasBaseResource.__init__(self, hs) + + self.profile_handler = hs.get_profile_handler() + self.auth_handler = hs.get_auth_handler() + + class PostBody(RequestBodyModel): + localpart: StrictStr + + async def _async_render_POST( + self, request: "SynapseRequest" + ) -> Tuple[int, JsonDict]: + self.assert_request_is_from_mas(request) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + user_id = UserID(body.localpart, self.hostname) + + # Check the user exists + user = await self.store.get_user_by_id(user_id=str(user_id)) + if user is None: + raise NotFoundError("User not found") + + requester = create_requester(user_id=user_id) + + await self.profile_handler.set_displayname( + target_user=requester.user, + requester=requester, + new_displayname="", + by_admin=True, + ) + + return HTTPStatus.OK, {} + + +class MasAllowCrossSigningResetResource(MasBaseResource): + """ + Endpoint for MAS to allow cross-signing key reset without user interaction. + + Takes a localpart parameter to temporarily allow cross-signing key replacement + without requiring User-Interactive Authentication (UIA). + + POST /_synapse/mas/allow_cross_signing_reset + {"localpart": "alice"} + """ + + REPLACEMENT_PERIOD_MS = 10 * 60 * 1000 # 10 minutes + + def __init__(self, hs: "HomeServer"): + MasBaseResource.__init__(self, hs) + + self.auth_handler = hs.get_auth_handler() + + class PostBody(RequestBodyModel): + localpart: StrictStr + + async def _async_render_POST( + self, request: "SynapseRequest" + ) -> Tuple[int, JsonDict]: + self.assert_request_is_from_mas(request) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + user_id = UserID(body.localpart, self.hostname) + + # Check the user exists + user = await self.store.get_user_by_id(user_id=str(user_id)) + if user is None: + raise NotFoundError("User not found") + + timestamp = ( + await self.store.allow_master_cross_signing_key_replacement_without_uia( + user_id=str(user_id), + duration_ms=self.REPLACEMENT_PERIOD_MS, + ) + ) + + if timestamp is None: + # If there are no cross-signing keys, this is a no-op, but we should log + logger.warning( + "User %s has no master cross-signing key", user_id.to_string() + ) + + return HTTPStatus.OK, {} diff --git a/tests/rest/synapse/mas/__init__.py b/tests/rest/synapse/mas/__init__.py new file mode 100644 index 00000000000..db2cfe109fa --- /dev/null +++ b/tests/rest/synapse/mas/__init__.py @@ -0,0 +1,12 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . diff --git a/tests/rest/synapse/mas/_base.py b/tests/rest/synapse/mas/_base.py new file mode 100644 index 00000000000..19d33807a67 --- /dev/null +++ b/tests/rest/synapse/mas/_base.py @@ -0,0 +1,43 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.types import JsonDict + +from tests import unittest + + +class BaseTestCase(unittest.HomeserverTestCase): + SHARED_SECRET = "shared_secret" + + def default_config(self) -> JsonDict: + config = super().default_config() + config["enable_registration"] = False + config["experimental_features"] = { + "msc3861": { + "enabled": True, + "issuer": "https://example.com", + "client_id": "dummy", + "client_auth_method": "client_secret_basic", + "client_secret": "dummy", + "admin_token": self.SHARED_SECRET, + } + } + return config + + def create_resource_dict(self) -> dict[str, Resource]: + base = super().create_resource_dict() + base.update(build_synapse_client_resource_tree(self.hs)) + return base diff --git a/tests/rest/synapse/mas/test_devices.py b/tests/rest/synapse/mas/test_devices.py new file mode 100644 index 00000000000..a7cd58d8ff1 --- /dev/null +++ b/tests/rest/synapse/mas/test_devices.py @@ -0,0 +1,693 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.types import UserID +from synapse.util import Clock + +from tests.unittest import skip_unless +from tests.utils import HAS_AUTHLIB + +from ._base import BaseTestCase + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MasUpsertDeviceResource(BaseTestCase): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Create a user for testing + self.alice_user_id = UserID("alice", "test") + self.get_success( + homeserver.get_registration_handler().register_user( + localpart=self.alice_user_id.localpart, + ) + ) + + def test_other_token(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/upsert_device", + shorthand=False, + access_token="other_token", + content={ + "localpart": "alice", + "device_id": "DEVICE1", + }, + ) + + self.assertEqual(channel.code, 403, channel.json_body) + self.assertEqual( + channel.json_body["error"], "This endpoint must only be called by MAS" + ) + + def test_upsert_device(self) -> None: + store = self.hs.get_datastores().main + + channel = self.make_request( + "POST", + "/_synapse/mas/upsert_device", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "device_id": "DEVICE1", + }, + ) + + # This created a new device, hence the 201 status code + self.assertEqual(channel.code, 201, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Verify the device exists + device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1")) + assert device is not None + self.assertEqual(device["device_id"], "DEVICE1") + self.assertIsNone(device["display_name"]) + + def test_update_existing_device(self) -> None: + store = self.hs.get_datastores().main + device_handler = self.hs.get_device_handler() + + # Create an initial device + self.get_success( + device_handler.upsert_device( + user_id=str(self.alice_user_id), + device_id="DEVICE1", + display_name="Old Name", + ) + ) + + channel = self.make_request( + "POST", + "/_synapse/mas/upsert_device", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "device_id": "DEVICE1", + "display_name": "New Name", + }, + ) + + # This updated an existing device, hence the 200 status code + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Verify the device was updated + device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1")) + assert device is not None + self.assertEqual(device["display_name"], "New Name") + + def test_upsert_device_with_display_name(self) -> None: + store = self.hs.get_datastores().main + + channel = self.make_request( + "POST", + "/_synapse/mas/upsert_device", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "device_id": "DEVICE1", + "display_name": "Alice's Phone", + }, + ) + + self.assertEqual(channel.code, 201, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Verify the device exists with correct display name + device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1")) + assert device is not None + self.assertEqual(device["display_name"], "Alice's Phone") + + def test_upsert_device_missing_localpart(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/upsert_device", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "device_id": "DEVICE1", + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_upsert_device_missing_device_id(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/upsert_device", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_upsert_device_nonexistent_user(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/upsert_device", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "nonexistent", + "device_id": "DEVICE1", + }, + ) + + # We get a 404 here as the user doesn't exist + self.assertEqual(channel.code, 404, channel.json_body) + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MasDeleteDeviceResource(BaseTestCase): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Create a user and device for testing + self.alice_user_id = UserID("alice", "test") + self.get_success( + homeserver.get_registration_handler().register_user( + localpart=self.alice_user_id.localpart, + ) + ) + + # Create a device + device_handler = homeserver.get_device_handler() + self.get_success( + device_handler.upsert_device( + user_id=str(self.alice_user_id), + device_id="DEVICE1", + display_name="Test Device", + ) + ) + + def test_other_token(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/delete_device", + shorthand=False, + access_token="other_token", + content={ + "localpart": "alice", + "device_id": "DEVICE1", + }, + ) + + self.assertEqual(channel.code, 403, channel.json_body) + self.assertEqual( + channel.json_body["error"], "This endpoint must only be called by MAS" + ) + + def test_delete_device(self) -> None: + store = self.hs.get_datastores().main + + # Verify device exists before deletion + device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1")) + assert device is not None + + channel = self.make_request( + "POST", + "/_synapse/mas/delete_device", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "device_id": "DEVICE1", + }, + ) + + self.assertEqual(channel.code, 204) + + # Verify the device no longer exists + device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1")) + self.assertIsNone(device) + + def test_delete_nonexistent_device(self) -> None: + # Deleting a non-existent device should be idempotent + channel = self.make_request( + "POST", + "/_synapse/mas/delete_device", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "device_id": "NONEXISTENT", + }, + ) + + self.assertEqual(channel.code, 204) + + def test_delete_device_missing_localpart(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/delete_device", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "device_id": "DEVICE1", + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_delete_device_missing_device_id(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/delete_device", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_delete_device_nonexistent_user(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/delete_device", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "nonexistent", + "device_id": "DEVICE1", + }, + ) + + # Should fail on a non-existent user + self.assertEqual(channel.code, 404, channel.json_body) + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MasUpdateDeviceDisplayNameResource(BaseTestCase): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Create a user and device for testing + self.alice_user_id = UserID("alice", "test") + self.get_success( + homeserver.get_registration_handler().register_user( + localpart=self.alice_user_id.localpart, + ) + ) + + # Create a device + device_handler = homeserver.get_device_handler() + self.get_success( + device_handler.upsert_device( + user_id=str(self.alice_user_id), + device_id="DEVICE1", + display_name="Old Name", + ) + ) + + def test_other_token(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/update_device_display_name", + shorthand=False, + access_token="other_token", + content={ + "localpart": "alice", + "device_id": "DEVICE1", + "display_name": "New Name", + }, + ) + + self.assertEqual(channel.code, 403, channel.json_body) + self.assertEqual( + channel.json_body["error"], "This endpoint must only be called by MAS" + ) + + def test_update_device_display_name(self) -> None: + store = self.hs.get_datastores().main + + # Verify initial display name + device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1")) + assert device is not None + self.assertEqual(device["display_name"], "Old Name") + + channel = self.make_request( + "POST", + "/_synapse/mas/update_device_display_name", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "device_id": "DEVICE1", + "display_name": "Updated Name", + }, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Verify the display name was updated + device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1")) + assert device is not None + self.assertEqual(device["display_name"], "Updated Name") + + def test_update_nonexistent_device(self) -> None: + # Updating a non-existent device should fail + channel = self.make_request( + "POST", + "/_synapse/mas/update_device_display_name", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "device_id": "NONEXISTENT", + "display_name": "New Name", + }, + ) + + self.assertEqual(channel.code, 404, channel.json_body) + + def test_update_device_display_name_missing_localpart(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/update_device_display_name", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "device_id": "DEVICE1", + "display_name": "New Name", + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_update_device_display_name_missing_device_id(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/update_device_display_name", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "display_name": "New Name", + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_update_device_display_name_missing_display_name(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/update_device_display_name", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "device_id": "DEVICE1", + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_update_device_display_name_nonexistent_user(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/update_device_display_name", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "nonexistent", + "device_id": "DEVICE1", + "display_name": "New Name", + }, + ) + + self.assertEqual(channel.code, 404, channel.json_body) + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MasSyncDevicesResource(BaseTestCase): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Create a user for testing + self.alice_user_id = UserID("alice", "test") + self.get_success( + homeserver.get_registration_handler().register_user( + localpart=self.alice_user_id.localpart, + ) + ) + + # Create some initial devices + device_handler = homeserver.get_device_handler() + for device_id in ["DEVICE1", "DEVICE2", "DEVICE3"]: + self.get_success( + device_handler.upsert_device( + user_id=str(self.alice_user_id), + device_id=device_id, + display_name=f"Device {device_id}", + ) + ) + + def test_other_token(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token="other_token", + content={ + "localpart": "alice", + "devices": ["DEVICE1", "DEVICE2"], + }, + ) + + self.assertEqual(channel.code, 403, channel.json_body) + self.assertEqual( + channel.json_body["error"], "This endpoint must only be called by MAS" + ) + + def test_sync_devices_no_changes(self) -> None: + # Sync with the same devices that already exist + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "devices": ["DEVICE1", "DEVICE2", "DEVICE3"], + }, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Verify all devices still exist + store = self.hs.get_datastores().main + devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id))) + self.assertEqual(set(devices.keys()), {"DEVICE1", "DEVICE2", "DEVICE3"}) + + def test_sync_devices_add_only(self) -> None: + # Sync with additional devices + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "devices": ["DEVICE1", "DEVICE2", "DEVICE3", "DEVICE4", "DEVICE5"], + }, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Verify new devices were added + store = self.hs.get_datastores().main + devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id))) + self.assertEqual( + set(devices.keys()), {"DEVICE1", "DEVICE2", "DEVICE3", "DEVICE4", "DEVICE5"} + ) + + def test_sync_devices_delete_only(self) -> None: + # Sync with fewer devices + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "devices": ["DEVICE1"], + }, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Verify devices were deleted + store = self.hs.get_datastores().main + devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id))) + self.assertEqual(set(devices.keys()), {"DEVICE1"}) + + def test_sync_devices_add_and_delete(self) -> None: + # Sync with a mix of additions and deletions + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "devices": ["DEVICE1", "DEVICE4", "DEVICE5"], + }, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Verify the correct devices exist + store = self.hs.get_datastores().main + devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id))) + self.assertEqual(set(devices.keys()), {"DEVICE1", "DEVICE4", "DEVICE5"}) + + def test_sync_devices_empty_list(self) -> None: + # Sync with empty device list (delete all devices) + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "devices": [], + }, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Verify all devices were deleted + store = self.hs.get_datastores().main + devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id))) + self.assertEqual(devices, {}) + + def test_sync_devices_for_new_user(self) -> None: + # Test syncing devices for a user that doesn't have any devices yet + bob_user_id = UserID("bob", "test") + self.get_success( + self.hs.get_registration_handler().register_user( + localpart=bob_user_id.localpart, + ) + ) + + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "bob", + "devices": ["DEVICE1", "DEVICE2"], + }, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Verify devices were created + store = self.hs.get_datastores().main + devices = self.get_success(store.get_devices_by_user(str(bob_user_id))) + self.assertEqual(set(devices.keys()), {"DEVICE1", "DEVICE2"}) + + def test_sync_devices_missing_localpart(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "devices": ["DEVICE1", "DEVICE2"], + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_sync_devices_missing_devices(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_sync_devices_invalid_devices_type(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "devices": "not_a_list", + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_sync_devices_nonexistent_user(self) -> None: + # Test syncing devices for a user that doesn't exist + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "nonexistent", + "devices": ["DEVICE1", "DEVICE2"], + }, + ) + + self.assertEqual(channel.code, 404, channel.json_body) + + def test_sync_devices_duplicate_device_ids(self) -> None: + # Test syncing with duplicate device IDs (sets should handle this) + channel = self.make_request( + "POST", + "/_synapse/mas/sync_devices", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "devices": ["DEVICE1", "DEVICE1", "DEVICE2"], + }, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Verify the correct devices exist (duplicates should be handled) + store = self.hs.get_datastores().main + devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id))) + self.assertEqual(sorted(devices.keys()), ["DEVICE1", "DEVICE2"]) diff --git a/tests/rest/synapse/mas/test_users.py b/tests/rest/synapse/mas/test_users.py new file mode 100644 index 00000000000..378f29fd4ce --- /dev/null +++ b/tests/rest/synapse/mas/test_users.py @@ -0,0 +1,1399 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . + +from urllib.parse import urlencode + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.appservice import ApplicationService +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID, create_requester +from synapse.util import Clock + +from tests.unittest import skip_unless +from tests.utils import HAS_AUTHLIB + +from ._base import BaseTestCase + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MasQueryUserResource(BaseTestCase): + def test_other_token(self) -> None: + channel = self.make_request( + "GET", + "/_synapse/mas/query_user?localpart=alice", + shorthand=False, + access_token="other_token", + ) + + self.assertEqual(channel.code, 403, channel.json_body) + self.assertEqual( + channel.json_body["error"], "This endpoint must only be called by MAS" + ) + + def test_query_user(self) -> None: + alice = UserID("alice", "test") + store = self.hs.get_datastores().main + self.get_success( + self.hs.get_registration_handler().register_user( + localpart=alice.localpart, + default_display_name="Alice", + ) + ) + self.get_success( + store.set_profile_avatar_url( + user_id=alice, + new_avatar_url="mxc://example.com/avatar", + ) + ) + + channel = self.make_request( + "GET", + "/_synapse/mas/query_user?localpart=alice", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual( + channel.json_body, + { + "user_id": "@alice:test", + "display_name": "Alice", + "avatar_url": "mxc://example.com/avatar", + "is_suspended": False, + "is_deactivated": False, + }, + ) + + self.get_success( + store.set_user_suspended_status(user_id=str(alice), suspended=True) + ) + + channel = self.make_request( + "GET", + "/_synapse/mas/query_user?localpart=alice", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual( + channel.json_body, + { + "user_id": "@alice:test", + "display_name": "Alice", + "avatar_url": "mxc://example.com/avatar", + "is_suspended": True, + "is_deactivated": False, + }, + ) + + # Deactivate the account, it should clear the display name and avatar + # and mark the user as deactivated + self.get_success( + self.hs.get_deactivate_account_handler().deactivate_account( + user_id=str(alice), + erase_data=True, + requester=create_requester(alice), + ) + ) + + channel = self.make_request( + "GET", + "/_synapse/mas/query_user?localpart=alice", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual( + channel.json_body, + { + "user_id": "@alice:test", + "display_name": None, + "avatar_url": None, + "is_suspended": True, + "is_deactivated": True, + }, + ) + + def test_query_unknown_user(self) -> None: + channel = self.make_request( + "GET", + "/_synapse/mas/query_user?localpart=alice", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 404, channel.json_body) + + def test_query_user_missing_localpart(self) -> None: + channel = self.make_request( + "GET", + "/_synapse/mas/query_user", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MasProvisionUserResource(BaseTestCase): + def test_other_token(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/provision_user", + shorthand=False, + access_token="other_token", + content={"localpart": "alice"}, + ) + + self.assertEqual(channel.code, 403, channel.json_body) + self.assertEqual( + channel.json_body["error"], "This endpoint must only be called by MAS" + ) + + def test_provision_user(self) -> None: + store = self.hs.get_datastores().main + + channel = self.make_request( + "POST", + "/_synapse/mas/provision_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "set_displayname": "Alice", + "set_emails": ["alice@example.com"], + "set_avatar_url": "mxc://example.com/avatar", + }, + ) + + # This created the user, hence the 201 status code + self.assertEqual(channel.code, 201, channel.json_body) + self.assertEqual(channel.json_body, {}) + + alice = UserID("alice", "test") + profile = self.get_success(store.get_profileinfo(alice)) + self.assertEqual(profile.display_name, "Alice") + self.assertEqual(profile.avatar_url, "mxc://example.com/avatar") + threepids = self.get_success(store.user_get_threepids(str(alice))) + self.assertEqual(len(threepids), 1) + self.assertEqual(threepids[0].medium, "email") + self.assertEqual(threepids[0].address, "alice@example.com") + + channel = self.make_request( + "POST", + "/_synapse/mas/provision_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "unset_displayname": True, + "unset_avatar_url": True, + "unset_emails": True, + }, + ) + + # This updated the user, hence the 200 status code + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Check that the profile and threepids were deleted + profile = self.get_success(store.get_profileinfo(alice)) + self.assertEqual(profile.display_name, None) + self.assertEqual(profile.avatar_url, None) + threepids = self.get_success(store.user_get_threepids(str(alice))) + self.assertEqual(threepids, []) + + def test_provision_user_missing_localpart(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/provision_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "set_displayname": "Alice", + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_provision_user_empty_localpart(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/provision_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "", + "set_displayname": "Alice", + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_provision_user_invalid_localpart(self) -> None: + # Test with characters that are invalid in localparts + invalid_localparts = [ + "@alice:test", # That's a MXID + "alice@domain.com", + "alice:test", + "alice space", + "alice#hash", + "a" * 1000, # Very long localpart + ] + + for localpart in invalid_localparts: + channel = self.make_request( + "POST", + "/_synapse/mas/provision_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": localpart, + "set_displayname": "Alice", + }, + ) + # Should be a validation error + self.assertEqual( + channel.code, 400, f"Should fail for localpart: {localpart}" + ) + + def test_provision_user_multiple_emails(self) -> None: + store = self.hs.get_datastores().main + + channel = self.make_request( + "POST", + "/_synapse/mas/provision_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "set_emails": ["alice@example.com", "alice.alt@example.com"], + }, + ) + + self.assertEqual(channel.code, 201, channel.json_body) + + alice = UserID("alice", "test") + threepids = self.get_success(store.user_get_threepids(str(alice))) + self.assertEqual(len(threepids), 2) + email_addresses = {tp.address for tp in threepids} + self.assertEqual( + email_addresses, {"alice@example.com", "alice.alt@example.com"} + ) + + def test_provision_user_duplicate_emails(self) -> None: + store = self.hs.get_datastores().main + + channel = self.make_request( + "POST", + "/_synapse/mas/provision_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "set_emails": ["alice@example.com", "alice@example.com"], + }, + ) + + self.assertEqual(channel.code, 201, channel.json_body) + + alice = UserID("alice", "test") + threepids = self.get_success(store.user_get_threepids(str(alice))) + # Should deduplicate + self.assertEqual(len(threepids), 1) + self.assertEqual(threepids[0].address, "alice@example.com") + + def test_provision_user_conflicting_operations(self) -> None: + # Test setting and unsetting the same field + channel = self.make_request( + "POST", + "/_synapse/mas/provision_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "set_displayname": "Alice", + "unset_displayname": True, + }, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_provision_user_invalid_json_types(self) -> None: + # Test with wrong data types + invalid_contents: list[JsonDict] = [ + {"localpart": "alice", "set_displayname": 123}, # Number instead of string + { + "localpart": "alice", + "set_emails": "not-an-array", + }, # String instead of array + { + "localpart": "alice", + "unset_displayname": "not-a-bool", + }, # String instead of bool + {"localpart": 123}, # Number instead of string for localpart + ] + + for content in invalid_contents: + channel = self.make_request( + "POST", + "/_synapse/mas/provision_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content=content, + ) + self.assertEqual(channel.code, 400, f"Should fail for content: {content}") + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MasIsLocalpartAvailableResource(BaseTestCase): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Provision a user + store = homeserver.get_datastores().main + self.get_success(store.register_user("@alice:test")) + + def test_other_token(self) -> None: + channel = self.make_request( + "GET", + "/_synapse/mas/is_localpart_available?localpart=alice", + shorthand=False, + access_token="other_token", + ) + + self.assertEqual(channel.code, 403, channel.json_body) + self.assertEqual( + channel.json_body["error"], "This endpoint must only be called by MAS" + ) + + def test_is_localpart_available(self) -> None: + # "alice" is not available + channel = self.make_request( + "GET", + "/_synapse/mas/is_localpart_available?localpart=alice", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE") + + # "bob" is available + channel = self.make_request( + "GET", + "/_synapse/mas/is_localpart_available?localpart=bob", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + def test_is_localpart_available_invalid_localparts(self) -> None: + # Numeric-only localparts are not allowed + channel = self.make_request( + "GET", + "/_synapse/mas/is_localpart_available?localpart=0", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_INVALID_USERNAME") + + # A super-long MXID is not allowed by the spec + super_long = "a" * 1000 + channel = self.make_request( + "GET", + f"/_synapse/mas/is_localpart_available?localpart={super_long}", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_INVALID_USERNAME") + + def test_is_localpart_available_appservice_exclusive(self) -> None: + # Insert an appservice which has exclusive namespaces + appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@as_user_.*:.+", "exclusive": True}]}, + sender=UserID.from_string("@as_main:test"), + ) + self.hs.get_datastores().main.services_cache = [appservice] + + channel = self.make_request( + "GET", + "/_synapse/mas/is_localpart_available?localpart=as_main", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_EXCLUSIVE") + + channel = self.make_request( + "GET", + "/_synapse/mas/is_localpart_available?localpart=as_user_alice", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_EXCLUSIVE") + + # Sanity-check that "bob" is available + channel = self.make_request( + "GET", + "/_synapse/mas/is_localpart_available?localpart=bob", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + def test_is_localpart_available_missing_localpart(self) -> None: + channel = self.make_request( + "GET", + "/_synapse/mas/is_localpart_available", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_is_localpart_available_empty_localpart(self) -> None: + channel = self.make_request( + "GET", + "/_synapse/mas/is_localpart_available?localpart=", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_is_localpart_available_invalid_characters(self) -> None: + # Test with characters that are invalid in localparts + invalid_localparts = [ + "alice@domain.com", # Contains @ + "alice:test", # Contains : + "alice space", # Contains space + "alice\\backslash", # Contains backslash + "alice#hash", # Contains hash + "alice$dollar", # Contains $ + "alice%percent", # Contains % + "alice&", # Contains & + "alice?question", # Contains ? + "alice[bracket", # Contains [ + "alice]bracket", # Contains ] + "alice{brace", # Contains { + "alice}brace", # Contains } + "alice|pipe", # Contains | + 'alice"quote', # Contains " + "alice'apostrophe", # Contains ' + "alicegreater", # Contains > + "alice\ttab", # Contains tab + "alice\nnewline", # Contains newline + ] + + for localpart in invalid_localparts: + channel = self.make_request( + "GET", + f"/_synapse/mas/is_localpart_available?{urlencode({'localpart': localpart})}", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + # Should return 400 for invalid characters + self.assertEqual( + channel.code, + 400, + f"Should reject localpart with invalid chars: {localpart}", + ) + self.assertEqual( + channel.json_body["errcode"], "M_INVALID_USERNAME", localpart + ) + + def test_is_localpart_available_case_sensitivity(self) -> None: + # Register a user with an uppercase localpart + self.get_success(self.hs.get_datastores().main.register_user("@BOB:test")) + + # It should report as not available, the search should be case-insensitive + channel = self.make_request( + "GET", + "/_synapse/mas/is_localpart_available?localpart=bob", + shorthand=False, + access_token=self.SHARED_SECRET, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE") + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MasDeleteUserResource(BaseTestCase): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Provision a user with a display name + self.get_success( + homeserver.get_registration_handler().register_user( + localpart="alice", + default_display_name="Alice", + ) + ) + + def test_other_token(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token="other_token", + content={"localpart": "alice", "erase": False}, + ) + + self.assertEqual(channel.code, 403, channel.json_body) + self.assertEqual( + channel.json_body["error"], "This endpoint must only be called by MAS" + ) + + def test_delete_user_no_erase(self) -> None: + alice = UserID("alice", "test") + store = self.hs.get_datastores().main + + # Delete the user + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice", "erase": False}, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Check that the user was deleted + self.assertTrue( + self.get_success(store.get_user_deactivated_status(user_id=str(alice))) + ) + # But not erased + self.assertFalse(self.get_success(store.is_user_erased(user_id=str(alice)))) + + def test_delete_user_erase(self) -> None: + alice = UserID("alice", "test") + store = self.hs.get_datastores().main + + # Delete the user + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice", "erase": True}, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Check that the user was deleted + self.assertTrue( + self.get_success(store.get_user_deactivated_status(user_id=str(alice))) + ) + # And erased + self.assertTrue(self.get_success(store.is_user_erased(user_id=str(alice)))) + + def test_delete_user_missing_localpart(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"erase": False}, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_delete_user_missing_erase(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice"}, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_delete_user_invalid_erase_type(self) -> None: + invalid_erase_values = [ + "true", # String instead of bool + 1, # Number instead of bool + "false", # String instead of bool + 0, # Number instead of bool + {}, # Object instead of bool + [], # Array instead of bool + ] + + for erase_value in invalid_erase_values: + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice", "erase": erase_value}, + ) + self.assertEqual( + channel.code, 400, f"Should fail for erase value: {erase_value}" + ) + + def test_delete_nonexistent_user(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "nonexistent", "erase": False}, + ) + + self.assertEqual(channel.code, 404) + + def test_delete_already_deleted_user(self) -> None: + # First deletion + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice", "erase": False}, + ) + self.assertEqual(channel.code, 200) + + # Second deletion should be idempotent + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice", "erase": False}, + ) + self.assertEqual(channel.code, 200) + + def test_delete_user_erase_already_deleted_user(self) -> None: + alice = UserID("alice", "test") + store = self.hs.get_datastores().main + + # First delete without erase + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice", "erase": False}, + ) + self.assertEqual(channel.code, 200) + + # Verify not erased initially + self.assertFalse(self.get_success(store.is_user_erased(user_id=str(alice)))) + + # Now delete with erase + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice", "erase": True}, + ) + self.assertEqual(channel.code, 200) + + # Should now be erased + self.assertTrue(self.get_success(store.is_user_erased(user_id=str(alice)))) + + def test_delete_user_empty_json(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={}, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_delete_user_extra_fields(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/delete_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "erase": False, + "extra_field": "should_be_ignored", + "another_field": 123, + }, + ) + + # Should succeed and ignore extra fields + self.assertEqual(channel.code, 200, channel.json_body) + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MasReactivateUserResource(BaseTestCase): + def test_other_token(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/reactivate_user", + shorthand=False, + access_token="other_token", + content={"localpart": "alice"}, + ) + + self.assertEqual(channel.code, 403, channel.json_body) + self.assertEqual( + channel.json_body["error"], "This endpoint must only be called by MAS" + ) + + def test_reactivate_user(self) -> None: + alice = UserID("alice", "test") + store = self.hs.get_datastores().main + self.get_success( + self.hs.get_registration_handler().register_user( + localpart=alice.localpart, + default_display_name="Alice", + ) + ) + self.get_success( + self.hs.get_deactivate_account_handler().deactivate_account( + user_id=str(alice), + erase_data=True, + requester=create_requester(alice), + ) + ) + + channel = self.make_request( + "POST", + "/_synapse/mas/reactivate_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice"}, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Check that the user was reactivated + self.assertFalse( + self.get_success(store.get_user_deactivated_status(user_id=str(alice))) + ) + + def test_reactivate_user_missing_localpart(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/reactivate_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={}, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_reactivate_nonexistent_user(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/reactivate_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "nonexistent"}, + ) + + self.assertEqual(channel.code, 404, channel.json_body) + + def test_reactivate_active_user(self) -> None: + # Create an active user + alice = UserID("alice", "test") + self.get_success( + self.hs.get_registration_handler().register_user( + localpart=alice.localpart, + default_display_name="Alice", + ) + ) + + channel = self.make_request( + "POST", + "/_synapse/mas/reactivate_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice"}, + ) + + # Should be idempotent + self.assertEqual(channel.code, 200, channel.json_body) + + def test_reactivate_erased_user(self) -> None: + alice = UserID("alice", "test") + store = self.hs.get_datastores().main + self.get_success( + self.hs.get_registration_handler().register_user( + localpart=alice.localpart, + default_display_name="Alice", + ) + ) + + # Deactivate with erase + self.get_success( + self.hs.get_deactivate_account_handler().deactivate_account( + user_id=str(alice), + erase_data=True, + requester=create_requester(alice), + ) + ) + + # Verify user is erased + self.assertTrue(self.get_success(store.is_user_erased(user_id=str(alice)))) + + channel = self.make_request( + "POST", + "/_synapse/mas/reactivate_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice"}, + ) + + # Should succeed even for erased users + self.assertEqual(channel.code, 200, channel.json_body) + # Shouldn't be erased anymore + self.assertFalse(self.get_success(store.is_user_erased(user_id=str(alice)))) + + def test_reactivate_user_extra_fields(self) -> None: + alice = UserID("alice", "test") + self.get_success( + self.hs.get_registration_handler().register_user( + localpart=alice.localpart, + ) + ) + self.get_success( + self.hs.get_deactivate_account_handler().deactivate_account( + user_id=str(alice), + erase_data=False, + requester=create_requester(alice), + ) + ) + + channel = self.make_request( + "POST", + "/_synapse/mas/reactivate_user", + shorthand=False, + access_token=self.SHARED_SECRET, + content={ + "localpart": "alice", + "extra_field": "should_be_ignored", + "another_field": 123, + }, + ) + + # Should succeed and ignore extra fields + self.assertEqual(channel.code, 200, channel.json_body) + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MasSetDisplayNameResource(BaseTestCase): + def test_other_token(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/set_displayname", + shorthand=False, + access_token="other_token", + content={"localpart": "alice", "displayname": "Bob"}, + ) + + self.assertEqual(channel.code, 403, channel.json_body) + self.assertEqual( + channel.json_body["error"], "This endpoint must only be called by MAS" + ) + + def test_set_display_name(self) -> None: + alice = UserID("alice", "test") + store = self.hs.get_datastores().main + self.get_success( + self.hs.get_registration_handler().register_user( + localpart=alice.localpart, + default_display_name="Alice", + ) + ) + profile = self.get_success(store.get_profileinfo(alice)) + self.assertEqual(profile.display_name, "Alice") + + channel = self.make_request( + "POST", + "/_synapse/mas/set_displayname", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice", "displayname": "Bob"}, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.json_body, {}) + + # Check that the profile was updated + profile = self.get_success(store.get_profileinfo(alice)) + self.assertEqual(profile.display_name, "Bob") + + def test_set_display_name_missing_localpart(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/set_displayname", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"displayname": "Bob"}, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_set_display_name_missing_displayname(self) -> None: + channel = self.make_request( + "POST", + "/_synapse/mas/set_displayname", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice"}, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_set_display_name_very_long(self) -> None: + alice = UserID("alice", "test") + self.get_success( + self.hs.get_registration_handler().register_user( + localpart=alice.localpart, + ) + ) + + long_name = "A" * 1000 + channel = self.make_request( + "POST", + "/_synapse/mas/set_displayname", + shorthand=False, + access_token=self.SHARED_SECRET, + content={"localpart": "alice", "displayname": long_name}, + ) + + self.assertEqual(channel.code, 400, channel.json_body) + + def test_set_display_name_special_characters(self) -> None: + alice = UserID("alice", "test") + self.get_success( + self.hs.get_registration_handler().register_user( + localpart=alice.localpart, + ) + ) + + special_names = [ + "Alice 👋", # Emoji + "Alice & Bob", # HTML entities + "Alice\nNewline", # Newline + "Alice\tTab", # Tab + 'Alice"Quote', # Quote + "Alice'Apostrophe", # Apostrophe + "Alice