diff --git a/cdp_use/client.py b/cdp_use/client.py index 3068c2c..ca5b6e4 100644 --- a/cdp_use/client.py +++ b/cdp_use/client.py @@ -3,7 +3,7 @@ import logging import re import time -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Callable, List, Optional import websockets @@ -223,6 +223,8 @@ def __init__(self, url: str, additional_headers: Optional[Dict[str, str]] = None self.pending_requests: Dict[int, asyncio.Future] = {} self._message_handler_task = None # self.event_handlers: Dict[str, Callable] = {} + # WebSocket routing/interception + self._web_socket_routes: List[Dict[str, Any]] = [] # Initialize the type-safe CDP library from cdp_use.cdp.library import CDPLibrary @@ -259,7 +261,7 @@ async def start(self): } if self.additional_headers: connect_kwargs["additional_headers"] = self.additional_headers - self.ws = await websockets.connect(self.url, **connect_kwargs) + self.ws = await self._connect_with_redirects(self.url, self.additional_headers) self._message_handler_task = asyncio.create_task(self._handle_messages()) async def stop(self): @@ -376,3 +378,32 @@ async def send_raw( # Wait for the response return await future + +async def _connect_with_redirects(self, url, headers=None, max_redirects=3): + """ + Connect to WebSocket, following HTTP redirects (301, 302, 307, 308). + """ + connect_kwargs = {"max_size": 100 * 1024 * 1024} + if headers: + connect_kwargs["additional_headers"] = headers + + for _ in range(max_redirects): + try: + ws = await websockets.connect(url, **connect_kwargs) + return ws + except websockets.InvalidHandshake as e: + response = getattr(e, 'response', None) + status = getattr(response, 'status_code', None) + if response and status in (301, 302, 307, 308): + location = response.headers.get('Location') or response.headers.get('location') + if not location: + raise RuntimeError(f"Redirect status {status} but no Location header") + url = location + continue + if response and status in (401, 403): + reason = getattr(response, 'reason_phrase', '') + raise RuntimeError(f"WebSocket authentication failed: {status} {reason}") + raise + raise RuntimeError("Too many redirects or failed to connect") + +