|
4 | 4 | from contextlib import asynccontextmanager, contextmanager |
5 | 5 |
|
6 | 6 | import httpx |
7 | | -import websockets.exceptions |
8 | 7 | import websockets.sync.client as websockets_sync_client |
9 | 8 | from ...core.api_error import ApiError |
10 | 9 | from ...core.client_wrapper import AsyncClientWrapper, SyncClientWrapper |
11 | 10 | from ...core.request_options import RequestOptions |
| 11 | +from ...core.websocket_compat import InvalidWebSocketStatus, get_status_code |
12 | 12 | from ...core.serialization import convert_and_respect_annotation_metadata |
13 | 13 | from ...core.query_encoder import single_query_encoder |
14 | 14 | from ..types.connect_session_settings import ConnectSessionSettings |
@@ -141,8 +141,8 @@ def connect( |
141 | 141 | try: |
142 | 142 | with websockets_sync_client.connect(ws_url, additional_headers=headers) as protocol: |
143 | 143 | yield ChatSocketClient(websocket=protocol) |
144 | | - except websockets.exceptions.InvalidStatusCode as exc: |
145 | | - status_code: int = exc.status_code |
| 144 | + except InvalidWebSocketStatus as exc: |
| 145 | + status_code: int = get_status_code(exc) |
146 | 146 | if status_code == 401: |
147 | 147 | raise ApiError( |
148 | 148 | status_code=status_code, |
@@ -278,8 +278,8 @@ async def connect( |
278 | 278 | try: |
279 | 279 | async with websockets_client_connect(ws_url, extra_headers=headers) as protocol: |
280 | 280 | yield AsyncChatSocketClient(websocket=protocol) |
281 | | - except websockets.exceptions.InvalidStatusCode as exc: |
282 | | - status_code: int = exc.status_code |
| 281 | + except InvalidWebSocketStatus as exc: |
| 282 | + status_code: int = get_status_code(exc) |
283 | 283 | if status_code == 401: |
284 | 284 | raise ApiError( |
285 | 285 | status_code=status_code, |
|
0 commit comments