|
11 | 11 | class BaseRpcApi: |
12 | 12 | """Base class for all RPC API subclients.""" |
13 | 13 |
|
14 | | - def __init__(self, rpc_client: RabbitMQRPCClient, namespace: RPCNamespace): |
| 14 | + def __init__( |
| 15 | + self, |
| 16 | + rpc_client: RabbitMQRPCClient, |
| 17 | + namespace: RPCNamespace, |
| 18 | + rpc_request_kwargs: dict[str, Any] | None = None, |
| 19 | + ): |
15 | 20 | self._rpc_client = rpc_client |
16 | 21 | self._namespace = namespace |
| 22 | + self._rpc_request_kwargs = rpc_request_kwargs or {} |
17 | 23 |
|
18 | 24 | async def _request( |
19 | 25 | self, |
20 | 26 | method_name: RPCMethodName, |
21 | 27 | *, |
22 | 28 | product_name: ProductName, |
23 | 29 | user_id: UserID, |
24 | | - **kwargs: Any |
| 30 | + **optional_kwargs: Any |
25 | 31 | ) -> Any: |
26 | | - return await self._rpc_client.request( |
27 | | - self._namespace, |
| 32 | + assert self._rpc_request_kwargs.keys().isdisjoint(optional_kwargs.keys()), ( |
| 33 | + "Conflict between request extras and kwargs" |
| 34 | + "Please rename the conflicting keys." |
| 35 | + ) |
| 36 | + |
| 37 | + return await self._request_without_authentication( |
28 | 38 | method_name, |
29 | 39 | product_name=product_name, |
30 | 40 | user_id=user_id, |
31 | | - **kwargs |
| 41 | + **optional_kwargs, |
| 42 | + **self._rpc_request_kwargs, |
32 | 43 | ) |
33 | 44 |
|
34 | 45 | async def _request_without_authentication( |
35 | 46 | self, method_name: RPCMethodName, *, product_name: ProductName, **kwargs: Any |
36 | 47 | ) -> Any: |
| 48 | + assert self._rpc_request_kwargs.keys().isdisjoint(kwargs.keys()), ( |
| 49 | + "Conflict between request extras and kwargs" |
| 50 | + "Please rename the conflicting keys." |
| 51 | + ) |
| 52 | + |
37 | 53 | return await self._rpc_client.request( |
38 | | - self._namespace, method_name, product_name=product_name, **kwargs |
| 54 | + self._namespace, |
| 55 | + method_name, |
| 56 | + product_name=product_name, |
| 57 | + **kwargs, |
| 58 | + **self._rpc_request_kwargs, |
39 | 59 | ) |
0 commit comments