diff --git a/py/selenium/webdriver/common/bidi/bidi.py b/py/selenium/webdriver/common/bidi/bidi.py new file mode 100644 index 0000000000000..219108e088e38 --- /dev/null +++ b/py/selenium/webdriver/common/bidi/bidi.py @@ -0,0 +1,59 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from dataclasses import dataclass +from dataclasses import fields +from dataclasses import is_dataclass +from typing import get_type_hints + + +@dataclass +class BidiObject: + def to_json(self): + json = {} + for field in fields(self): + value = getattr(self, field.name) + if value is None: + continue + if is_dataclass(value): + value = value.to_json() + elif isinstance(value, list): + value = [v.to_json() if hasattr(v, "to_json") else v for v in value] + elif isinstance(value, dict): + value = {k: v.to_json() if hasattr(v, "to_json") else v for k, v in value.items()} + key = field.name[:-1] if field.name.endswith("_") else field.name + json[key] = value + return json + + @classmethod + def from_json(cls, json): + return cls(**json) + + +@dataclass +class BidiEvent(BidiObject): + @classmethod + def from_json(cls, json): + params = get_type_hints(cls)["params"].from_json(json) + return cls(params) + + +@dataclass +class BidiCommand(BidiObject): + def cmd(self): + result = yield self.to_json() + return result diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py new file mode 100644 index 0000000000000..ad94e6146166d --- /dev/null +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -0,0 +1,48 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import typing +from dataclasses import dataclass + +from .bidi import BidiCommand +from .bidi import BidiObject + +ReadinessState = typing.Literal["none", "interactive", "complete"] + + +@dataclass +class NavigateParameters(BidiObject): + context: str + url: str + wait: typing.Optional[ReadinessState] = None + + +@dataclass +class Navigate(BidiCommand): + params: NavigateParameters + method: typing.Literal["browsingContext.navigate"] = "browsingContext.navigate" + + +Navigation = str + +class BrowsingContext: + def __init__(self, conn): + self.conn = conn + + def navigate(self, params: NavigateParameters): + result = self.conn.execute(NavigateParameters(params).cmd()) + return result diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py new file mode 100644 index 0000000000000..15b0950544b42 --- /dev/null +++ b/py/selenium/webdriver/common/bidi/network.py @@ -0,0 +1,213 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import typing +from dataclasses import dataclass + +from selenium.webdriver.common.bidi.cdp import import_devtools + +from . import browsing_context +from . import script +from .bidi import BidiCommand +from .bidi import BidiEvent +from .bidi import BidiObject + +devtools = import_devtools("") +event_class = devtools.util.event_class + +InterceptPhase = typing.Literal["beforeRequestSent", "responseStarted", "authRequired"] + + +@dataclass +class UrlPatternPattern(BidiObject): + type_: typing.Literal["pattern"] = "pattern" + protocol: typing.Optional[str] = None + hostname: typing.Optional[str] = None + port: typing.Optional[str] = None + pathname: typing.Optional[str] = None + search: typing.Optional[str] = None + + +@dataclass +class UrlPatternString(BidiObject): + pattern: str + type_: typing.Literal["string"] = "string" + + +UrlPattern = typing.Union[UrlPatternPattern, UrlPatternString] + + +@dataclass +class AddInterceptParameters(BidiObject): + phases: typing.List[InterceptPhase] + contexts: typing.Optional[typing.List[browsing_context.BrowsingContext]] = None + urlPatterns: typing.Optional[typing.List[UrlPattern]] = None + + +@dataclass +class AddIntercept(BidiCommand): + params: AddInterceptParameters + method: typing.Literal["network.addIntercept"] = "network.addIntercept" + + +Request = str + + +@dataclass +class StringValue(BidiObject): + value: str + type_: typing.Literal["string"] = "string" + + +@dataclass +class Base64Value(BidiObject): + value: str + type_: typing.Literal["base64"] = "base64" + + +BytesValue = typing.Union[StringValue, Base64Value] + + +@dataclass +class CookieHeader(BidiObject): + name: str + value: BytesValue + + +@dataclass +class Header(BidiObject): + name: str + value: BytesValue + + +@dataclass +class ContinueRequestParameters(BidiObject): + request: Request + body: typing.Optional[BytesValue] = None + cookies: typing.Optional[typing.List[CookieHeader]] = None + headers: typing.Optional[typing.List[Header]] = None + method: typing.Optional[str] = None + url: typing.Optional[str] = None + + +@dataclass +class ContinueRequest(BidiCommand): + params: ContinueRequestParameters + method: typing.Literal["network.continueRequest"] = "network.continueRequest" + + +Intercept = str + + +@dataclass +class RemoveInterceptParameters(BidiObject): + intercept: Intercept + + +@dataclass +class RemoveIntercept(BidiCommand): + params: RemoveInterceptParameters + method: typing.Literal["network.removeIntercept"] = "network.removeIntercept" + + +SameSite = typing.Literal["strict", "lax", "none"] + + +@dataclass +class Cookie(BidiObject): + name: str + value: BytesValue + domain: str + path: str + size: int + httpOnly: bool + secure: bool + sameSite: SameSite + expiry: typing.Optional[int] = None + + +@dataclass +class FetchTimingInfo(BidiObject): + timeOrigin: float + requestTime: float + redirectStart: float + redirectEnd: float + fetchStart: float + dnsStart: float + dnsEnd: float + connectStart: float + connectEnd: float + tlsStart: float + requestStart: float + responseStart: float + responseEnd: float + + +@dataclass +class RequestData(BidiObject): + request: Request + url: str + method: str + headersSize: int + timings: FetchTimingInfo + headers: typing.Optional[typing.List[Header]] = None + cookies: typing.Optional[typing.List[Cookie]] = None + bodySize: typing.Optional[int] = None + + +@dataclass +class Initiator(BidiObject): + type_: typing.Literal["parser", "script", "preflight", "other"] + columnNumber: typing.Optional[int] = None + lineNumber: typing.Optional[int] = None + stackTrace: typing.Optional[script.StackTrace] = None + request: typing.Optional[Request] = None + + +@dataclass +class BeforeRequestSentParameters(BidiObject): + isBlocked: bool + redirectCount: int + request: RequestData + timestamp: int + initiator: Initiator + context: typing.Optional[browsing_context.BrowsingContext] = None + navigation: typing.Optional[browsing_context.Navigation] = None + intercepts: typing.Optional[typing.List[Intercept]] = None + + +@dataclass +@event_class("network.beforeRequestSent") +class BeforeRequestSent(BidiEvent): + params: BeforeRequestSentParameters + method: typing.Literal["network.beforeRequestSent"] = "network.beforeRequestSent" + + +class Network: + def __init__(self, conn): + self.conn = conn + + def add_intercept(self, params: AddInterceptParameters): + result = self.conn.execute(AddIntercept(params).cmd()) + return result + + def continue_request(self, params: ContinueRequestParameters): + result = self.conn.execute(ContinueRequest(params).cmd()) + return result + + def remove_intercept(self, params: RemoveInterceptParameters): + self.conn.execute(RemoveIntercept(params).cmd()) diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 1dc8d101d670e..733b98cf3ff25 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -17,7 +17,9 @@ from dataclasses import dataclass from typing import List +from typing import Optional +from .bidi import BidiObject from .session import session_subscribe from .session import session_unsubscribe @@ -108,3 +110,16 @@ def from_json(cls, json): stacktrace=json["stackTrace"], type_=json["type"], ) + + +@dataclass +class StackFrame(BidiObject): + columnNumber: int + functionName: str + lineNumber: int + url: str + + +@dataclass +class StackTrace(BidiObject): + callFrames: Optional[List[StackFrame]] = None diff --git a/py/selenium/webdriver/remote/network.py b/py/selenium/webdriver/remote/network.py new file mode 100644 index 0000000000000..83a1a305ba02e --- /dev/null +++ b/py/selenium/webdriver/remote/network.py @@ -0,0 +1,95 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from collections import defaultdict + +from selenium.webdriver.common.bidi import network +from selenium.webdriver.common.bidi.browsing_context import Navigate +from selenium.webdriver.common.bidi.browsing_context import NavigateParameters +from selenium.webdriver.common.bidi.network import AddInterceptParameters +from selenium.webdriver.common.bidi.network import BeforeRequestSent +from selenium.webdriver.common.bidi.network import BeforeRequestSentParameters +from selenium.webdriver.common.bidi.network import ContinueRequestParameters +from selenium.webdriver.common.bidi.session import session_subscribe +from selenium.webdriver.common.bidi.session import session_unsubscribe + + +class Network: + def __init__(self, conn, driver): + self.intercepts = defaultdict(lambda: {"event": None, "handlers": []}) + self.callback_ids = {} + self.driver = driver + self.conn = conn + self.bidi_network = network.Network(self.conn) + + self.remove_request_handler = self.remove_intercept + self.clear_request_handlers = self.clear_intercepts + + def get(self, url, wait="none"): + params = NavigateParameters(context=self.driver.current_window_handle, url=url, wait=wait) + self.conn.execute(Navigate(params).cmd()) + + def add_handler(self, event, handler, urlPatterns=None): + event_name = event.event_class + phase_name = event_name.split(".")[-1] + + self.conn.execute(session_subscribe(event_name)) + + params = AddInterceptParameters(phases=[phase_name], urlPatterns=urlPatterns) + result = self.bidi_network.add_intercept(params) + intercept = result["intercept"] + + self.intercepts[intercept]["event"] = event + self.intercepts[intercept]["handlers"].append(handler) + if not self.callback_ids.get(event_name): + self.callback_ids[event_name] = self.conn.add_callback(event, self.handle_events) + return intercept + + def add_request_handler(self, handler, urlPatterns=None): + intercept = self.add_handler(BeforeRequestSent, handler, urlPatterns) + return intercept + + def handle_events(self, event): + event_params = event.params + if isinstance(event_params, BeforeRequestSentParameters) and event_params.isBlocked: + json = self.handle_requests(event_params) + params = ContinueRequestParameters(**json) + self.bidi_network.continue_request(params) + + def handle_requests(self, params): + request = params.request + for intercept in params.intercepts: + for handler in self.intercepts[intercept]["handlers"]: + request = handler(request) + return request + + def remove_intercept(self, intercept): + self.bidi_network.remove_intercept( + params=network.RemoveInterceptParameters(intercept), + ) + + event = self.intercepts.pop(intercept)["event"] + event_name = event.event_class + + remaining = [i for i in self.intercepts.values() if i["event"].event_class == event_name] + if len(remaining) == 0: + self.conn.execute(session_unsubscribe(event_name)) + callback_id = self.callback_ids.pop(event_name) + self.conn.remove_callback(event, callback_id) + + def clear_intercepts(self): + for intercept in self.intercepts: + self.remove_intercept(intercept) diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index dbcbd9dd37c0c..5e5f4dd470aa0 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -64,6 +64,7 @@ from .file_detector import LocalFileDetector from .locator_converter import LocatorConverter from .mobile import Mobile +from .network import Network from .remote_connection import RemoteConnection from .script_key import ScriptKey from .shadowroot import ShadowRoot @@ -252,6 +253,7 @@ def __init__( self._websocket_connection = None self._script = None + self._network = None def __repr__(self): return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>' @@ -1249,6 +1251,16 @@ def script(self): return self._script + @property + def network(self): + if not self._websocket_connection: + self._start_bidi() + + if not self._network: + self._network = Network(self._websocket_connection, self) + + return self._network + def _start_bidi(self): if self.caps.get("webSocketUrl"): ws_url = self.caps.get("webSocketUrl") diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 3afbba46d5e1e..34772fdb995f8 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -60,8 +60,9 @@ def execute(self, command): logger.debug(f"-> {data}"[: self._max_log_message_size]) self._ws.send(data) - self._wait_until(lambda: self._id in self._messages) - response = self._messages.pop(self._id) + current_id = self._id + self._wait_until(lambda: current_id in self._messages) + response = self._messages.pop(current_id) if "error" in response: raise Exception(response["error"]) @@ -131,7 +132,7 @@ def _process_message(self, message): if "method" in message: params = message["params"] for callback in self.callbacks.get(message["method"], []): - callback(params) + Thread(target=callback, args=(params,)).start() def _wait_until(self, condition): timeout = self._response_wait_timeout diff --git a/py/test/selenium/webdriver/common/bidi_network_tests.py b/py/test/selenium/webdriver/common/bidi_network_tests.py new file mode 100644 index 0000000000000..edc6e3ddae31a --- /dev/null +++ b/py/test/selenium/webdriver/common/bidi_network_tests.py @@ -0,0 +1,54 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from selenium.webdriver.common.bidi.network import UrlPatternString + + +@pytest.mark.xfail_safari +def test_request_handler(driver, pages): + + url1 = pages.url("simpleTest.html") + url2 = pages.url("clicks.html") + url3 = pages.url("formPage.html") + + pattern1 = [UrlPatternString(url1)] + pattern2 = [UrlPatternString(url2)] + + def request_handler(params): + request = params["request"] + json = {"request": request, "url": url3} + return json + + # Multiple intercepts + intercept1 = driver.network.add_request_handler(request_handler, pattern1) + intercept2 = driver.network.add_request_handler(request_handler, pattern2) + driver.network.get(url1) + assert driver.title == "We Leave From Here" + driver.network.get(url2) + assert driver.title == "We Leave From Here" + + # Removal of a single intercept + driver.network.remove_intercept(intercept2) + driver.network.get(url2) + assert driver.title == "clicks" + driver.network.get(url1) + assert driver.title == "We Leave From Here" + + driver.network.remove_intercept(intercept1) + driver.network.get(url1) + assert driver.title == "Hello WebDriver"