Skip to content

Commit 8ff8d07

Browse files
committed
Add appservice user/device masquerading support to base HTTPAPI
1 parent 92e6091 commit 8ff8d07

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

mautrix/api.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
from __future__ import annotations
77

8-
from typing import AsyncGenerator, ClassVar, Literal, Mapping, Union
8+
from typing import ClassVar, Literal, Mapping
99
from enum import Enum
1010
from json.decoder import JSONDecodeError
1111
from urllib.parse import quote as urllib_quote, urljoin as urllib_join
@@ -28,7 +28,7 @@
2828

2929
if __optional_imports__:
3030
# Safe to import, but it's not actually needed, so don't force-import the whole types module.
31-
from mautrix.types import JSON
31+
from mautrix.types import JSON, DeviceID, UserID
3232

3333
API_CALLS = Counter(
3434
name="bridge_matrix_api_calls",
@@ -193,6 +193,13 @@ class HTTPAPI:
193193
default_retry_count: int
194194
"""The default retry count to use if a custom value is not passed to :meth:`request`"""
195195

196+
as_user_id: UserID | None
197+
"""An optional user ID to set as the user_id query parameter for appservice requests."""
198+
as_device_id: DeviceID | None
199+
"""
200+
An optional device ID to set as the user_id query parameter for appservice requests (MSC3202).
201+
"""
202+
196203
def __init__(
197204
self,
198205
base_url: URL | str,
@@ -203,6 +210,8 @@ def __init__(
203210
txn_id: int = 0,
204211
log: TraceLogger | None = None,
205212
loop: asyncio.AbstractEventLoop | None = None,
213+
as_user_id: UserID | None = None,
214+
as_device_id: UserID | None = None,
206215
) -> None:
207216
"""
208217
Args:
@@ -212,13 +221,19 @@ def __init__(
212221
txn_id: The outgoing transaction ID to start with.
213222
log: The :class:`logging.Logger` instance to log requests with.
214223
default_retry_count: Default number of retries to do when encountering network errors.
224+
as_user_id: An optional user ID to set as the user_id query parameter for
225+
appservice requests.
226+
as_device_id: An optional device ID to set as the user_id query parameter for
227+
appservice requests (MSC3202).
215228
"""
216229
self.base_url = URL(base_url)
217230
self.token = token
218231
self.log = log or logging.getLogger("mau.http")
219232
self.session = client_session or ClientSession(
220233
loop=loop, headers={"User-Agent": self.default_ua}
221234
)
235+
self.as_user_id = as_user_id
236+
self.as_device_id = as_device_id
222237
if txn_id is not None:
223238
self.txn_id = txn_id
224239
if default_retry_count is not None:
@@ -360,6 +375,11 @@ async def request(
360375
query_params = query_params or {}
361376
if isinstance(query_params, dict):
362377
query_params = {k: v for k, v in query_params.items() if v is not None}
378+
if self.as_user_id:
379+
query_params["user_id"] = self.as_user_id
380+
if self.as_device_id:
381+
query_params["org.matrix.msc3202.device_id"] = self.as_device_id
382+
query_params["device_id"] = self.as_device_id
363383

364384
if method != Method.GET:
365385
content = content or {}

0 commit comments

Comments
 (0)