55# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66from __future__ import annotations
77
8- from typing import AsyncGenerator , ClassVar , Literal , Mapping , Union
8+ from typing import ClassVar , Literal , Mapping
99from enum import Enum
1010from json .decoder import JSONDecodeError
1111from urllib .parse import quote as urllib_quote , urljoin as urllib_join
2828
2929if __optional_imports__ :
3030 # Safe to import, but it's not actually needed, so don't force-import the whole types module.
31- from mautrix .types import JSON
31+ from mautrix .types import JSON , DeviceID , UserID
3232
3333API_CALLS = Counter (
3434 name = "bridge_matrix_api_calls" ,
@@ -193,6 +193,13 @@ class HTTPAPI:
193193 default_retry_count : int
194194 """The default retry count to use if a custom value is not passed to :meth:`request`"""
195195
196+ as_user_id : UserID | None
197+ """An optional user ID to set as the user_id query parameter for appservice requests."""
198+ as_device_id : DeviceID | None
199+ """
200+ An optional device ID to set as the user_id query parameter for appservice requests (MSC3202).
201+ """
202+
196203 def __init__ (
197204 self ,
198205 base_url : URL | str ,
@@ -203,6 +210,8 @@ def __init__(
203210 txn_id : int = 0 ,
204211 log : TraceLogger | None = None ,
205212 loop : asyncio .AbstractEventLoop | None = None ,
213+ as_user_id : UserID | None = None ,
214+ as_device_id : UserID | None = None ,
206215 ) -> None :
207216 """
208217 Args:
@@ -212,13 +221,19 @@ def __init__(
212221 txn_id: The outgoing transaction ID to start with.
213222 log: The :class:`logging.Logger` instance to log requests with.
214223 default_retry_count: Default number of retries to do when encountering network errors.
224+ as_user_id: An optional user ID to set as the user_id query parameter for
225+ appservice requests.
226+ as_device_id: An optional device ID to set as the user_id query parameter for
227+ appservice requests (MSC3202).
215228 """
216229 self .base_url = URL (base_url )
217230 self .token = token
218231 self .log = log or logging .getLogger ("mau.http" )
219232 self .session = client_session or ClientSession (
220233 loop = loop , headers = {"User-Agent" : self .default_ua }
221234 )
235+ self .as_user_id = as_user_id
236+ self .as_device_id = as_device_id
222237 if txn_id is not None :
223238 self .txn_id = txn_id
224239 if default_retry_count is not None :
@@ -360,6 +375,11 @@ async def request(
360375 query_params = query_params or {}
361376 if isinstance (query_params , dict ):
362377 query_params = {k : v for k , v in query_params .items () if v is not None }
378+ if self .as_user_id :
379+ query_params ["user_id" ] = self .as_user_id
380+ if self .as_device_id :
381+ query_params ["org.matrix.msc3202.device_id" ] = self .as_device_id
382+ query_params ["device_id" ] = self .as_device_id
363383
364384 if method != Method .GET :
365385 content = content or {}
0 commit comments