|
6 | 6 |
|
7 | 7 | from collections import OrderedDict |
8 | 8 | from pydantic import BaseModel |
9 | | -from typing import Callable, Optional, Dict, Union, TypeVar, Type |
| 9 | +from typing import Callable, Optional, Dict, Generic, Union, TypeVar, Type |
10 | 10 | from types import SimpleNamespace |
11 | 11 | from urllib.parse import urlencode |
12 | 12 |
|
@@ -193,29 +193,35 @@ async def receive_loop(self, connection: WebSocketConnection): |
193 | 193 | raise ValueError(f"Error received from server: {data['error']}") |
194 | 194 |
|
195 | 195 | stream = data.get("stream") |
196 | | - callbacks = ( |
197 | | - connection.stream_callback_map.get(stream) if stream else None |
198 | | - ) |
199 | | - response_model = ( |
200 | | - connection.response_types.get(stream) if stream else None |
201 | | - ) |
| 196 | + subscription_id = data.get("subscriptionId") |
| 197 | + |
| 198 | + key = stream or subscription_id |
| 199 | + callbacks = connection.stream_callback_map.get(key) if key is not None else None |
| 200 | + |
202 | 201 | if callbacks: |
203 | 202 | try: |
204 | | - for callback in callbacks: |
205 | | - if response_model is None: |
206 | | - callback(data) |
207 | | - else: |
208 | | - data = data["data"] |
209 | | - if isinstance(data, list): |
210 | | - callback( |
211 | | - [response_model.model_validate_json(json.dumps(item)) for item in data] |
212 | | - ) |
| 203 | + if stream: |
| 204 | + response_model = connection.response_types.get(stream) |
| 205 | + payload = data["data"] if response_model else data |
| 206 | + |
| 207 | + for callback in callbacks: |
| 208 | + if response_model: |
| 209 | + if isinstance(payload, list): |
| 210 | + parsed = [ |
| 211 | + response_model.model_validate_json(json.dumps(item)) |
| 212 | + for item in payload |
| 213 | + ] |
| 214 | + callback(parsed) |
| 215 | + else: |
| 216 | + callback(response_model.model_validate_json(json.dumps(payload))) |
213 | 217 | else: |
214 | | - callback(response_model.model_validate_json(json.dumps(data))) |
| 218 | + callback(payload) |
| 219 | + else: |
| 220 | + payload = data["event"] |
| 221 | + for callback in callbacks: |
| 222 | + callback(payload) |
215 | 223 | except Exception as e: |
216 | | - raise ValueError( |
217 | | - f"Error in callback for stream {stream}: {e}" |
218 | | - ) |
| 224 | + raise ValueError(f"Error in callback for key {key}: {e}") |
219 | 225 | else: |
220 | 226 | logging.info(f"Received message: {data}") |
221 | 227 | elif msg.type == aiohttp.WSMsgType.PING: |
@@ -723,22 +729,111 @@ async def ping_ws_api(self, connection: WebSocketConnection): |
723 | 729 |
|
724 | 730 | await super().ping(connection) |
725 | 731 |
|
| 732 | + async def subscribe_user_data(self, id: str): |
| 733 | + if self.configuration.mode == WebsocketMode.SINGLE: |
| 734 | + connection = self.connections[0] |
| 735 | + else: |
| 736 | + connection = self.connections[ |
| 737 | + self.round_robin_index % len(self.connections) |
| 738 | + ] |
| 739 | + self.round_robin_index = (self.round_robin_index + 1) % len( |
| 740 | + self.connections |
| 741 | + ) |
| 742 | + global_stream_connections.stream_connections_map[id] = connection |
| 743 | + connection.stream_callback_map.update({id: []}) |
| 744 | + |
| 745 | + def on(self, event: str, callback: Callable[[T], None], id: str) -> None: |
| 746 | + """Set the callback function for incoming messages on a specific ID. |
| 747 | +
|
| 748 | + Args: |
| 749 | + event (str): Event type. |
| 750 | + callback (Callable): Callback function. |
| 751 | + id (str): User Data ID. |
| 752 | + """ |
| 753 | + |
| 754 | + if event != "message": |
| 755 | + raise ValueError(f"Unsupported event: {event}") |
| 756 | + |
| 757 | + connection = ( |
| 758 | + global_stream_connections.stream_connections_map[id] |
| 759 | + if id in global_stream_connections.stream_connections_map |
| 760 | + else None |
| 761 | + ) |
| 762 | + |
| 763 | + if connection: |
| 764 | + connection.stream_callback_map[id].append(callback) |
| 765 | + else: |
| 766 | + logging.warning(f"Stream {id} not connected.") |
| 767 | + |
| 768 | + async def unsubscribe(self, id: str): |
| 769 | + """Unsubscribe from a user data ID. |
| 770 | +
|
| 771 | + Args: |
| 772 | + id (str): user data ID to unsubscribe from. |
| 773 | + """ |
| 774 | + |
| 775 | + if self.connections is None or len(self.connections) == 0: |
| 776 | + logging.warning("No user data connections available for unsubscription.") |
| 777 | + return |
| 778 | + |
| 779 | + if id not in global_stream_connections.stream_connections_map: |
| 780 | + logging.warning(f"Stream {id} is not subscribed.") |
| 781 | + return |
| 782 | + |
| 783 | + connection = ( |
| 784 | + global_stream_connections.stream_connections_map[id] |
| 785 | + if id in global_stream_connections.stream_connections_map |
| 786 | + else None |
| 787 | + ) |
| 788 | + if connection: |
| 789 | + global_stream_connections.stream_connections_map.pop(id, None) |
| 790 | + logging.info(f"Unsubscribed from stream: {id}") |
| 791 | + else: |
| 792 | + raise ValueError(f"Subscription id {id} not connected.") |
| 793 | + |
| 794 | + |
| 795 | +class RequestStreamHandle(Generic[T]): |
| 796 | + """A wrapper for Request Stream Method. |
| 797 | +
|
| 798 | + :param websocket_base: WebSocket base. |
| 799 | + :param stream: Stream name. |
| 800 | + :param response_model: The Pydantic model to validate the response data. |
| 801 | + """ |
| 802 | + |
| 803 | + def __init__( |
| 804 | + self, |
| 805 | + websocket_base: WebSocketStreamBase | WebSocketAPIBase, |
| 806 | + stream: str, |
| 807 | + response_model: Type[T] = None, |
| 808 | + ): |
| 809 | + self._websocket_base = websocket_base |
| 810 | + self._stream = stream |
| 811 | + self._response_model = response_model |
| 812 | + |
| 813 | + async def unsubscribe(self) -> None: |
| 814 | + if isinstance(self._websocket_base, WebSocketStreamBase): |
| 815 | + await self._websocket_base.unsubscribe(streams=self._stream) |
| 816 | + else: |
| 817 | + await self._websocket_base.unsubscribe(id=self._stream) |
| 818 | + |
| 819 | + def on(self, event: str, callback: Callable[[T], None]) -> None: |
| 820 | + self._websocket_base.on(event, callback, self._stream) |
| 821 | + |
726 | 822 |
|
727 | 823 | async def RequestStream( |
728 | | - websocket_base: WebSocketStreamBase, stream: str, response_model: Type[T] = None |
729 | | -) -> SimpleNamespace: |
| 824 | + websocket_base: WebSocketStreamBase | WebSocketAPIBase, stream: str, response_model: Type[T] = None |
| 825 | +) -> RequestStreamHandle[T]: |
730 | 826 | """Decorator to create a request stream for a specific stream. |
731 | 827 |
|
732 | 828 | Args: |
733 | | - websocket_base (WebSocketStreamBase): WebSocket stream base. |
| 829 | + websocket_base (WebSocketStreamBase | WebSocketAPIBase): WebSocket base. |
734 | 830 | stream (str): Stream name. |
| 831 | + response_model (Type[T], optional): Response model for the stream. |
735 | 832 | """ |
736 | | - await websocket_base.subscribe(streams=[stream], response_model=response_model) |
737 | | - |
738 | | - def on(event: str, callback: Callable[[T], None]): |
739 | | - websocket_base.on(event, callback, stream) |
740 | 833 |
|
741 | | - async def unsubscribe(): |
742 | | - await websocket_base.unsubscribe(streams=stream) |
| 834 | + if isinstance(websocket_base, WebSocketStreamBase): |
| 835 | + await websocket_base.subscribe(streams=[stream], response_model=response_model) |
| 836 | + else: |
| 837 | + await websocket_base.subscribe_user_data(id=stream) |
743 | 838 |
|
744 | | - return SimpleNamespace(on=on, unsubscribe=unsubscribe) |
| 839 | + return RequestStreamHandle(websocket_base, stream, response_model) |
0 commit comments