Skip to content

Commit 73aa38a

Browse files
committed
[py] Implement add/remove_request_handler
1 parent 33a2d91 commit 73aa38a

File tree

3 files changed

+191
-0
lines changed

3 files changed

+191
-0
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Licensed to the Software Freedom Conservancy (SFC) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The SFC licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from collections import defaultdict
18+
from contextlib import asynccontextmanager
19+
20+
import trio
21+
22+
from selenium.webdriver.common.bidi import network
23+
from selenium.webdriver.common.bidi.browsing_context import Navigate
24+
from selenium.webdriver.common.bidi.browsing_context import NavigateParameters
25+
from selenium.webdriver.common.bidi.cdp import open_cdp
26+
from selenium.webdriver.common.bidi.network import AddInterceptParameters
27+
from selenium.webdriver.common.bidi.network import BeforeRequestSent
28+
from selenium.webdriver.common.bidi.network import BeforeRequestSentParameters
29+
from selenium.webdriver.common.bidi.network import ContinueRequestParameters
30+
from selenium.webdriver.common.bidi.session import session_subscribe
31+
from selenium.webdriver.common.bidi.session import session_unsubscribe
32+
33+
34+
class Network:
35+
def __init__(self, driver):
36+
self.driver = driver
37+
self.listeners = {}
38+
self.intercepts = defaultdict(lambda: {"event_name": None, "handlers": []})
39+
self.bidi_network = None
40+
self.conn = None
41+
self.nursery = None
42+
43+
self.remove_request_handler = self.remove_intercept
44+
self.clear_request_handlers = self.clear_intercepts
45+
46+
@asynccontextmanager
47+
async def set_context(self):
48+
ws_url = self.driver.caps.get("webSocketUrl")
49+
async with open_cdp(ws_url) as conn:
50+
self.conn = conn
51+
self.bidi_network = network.Network(self.conn)
52+
async with trio.open_nursery() as nursery:
53+
self.nursery = nursery
54+
yield
55+
56+
async def get(self, url, wait="complete"):
57+
params = NavigateParameters(context=self.driver.current_window_handle, url=url, wait=wait)
58+
await self.conn.execute(Navigate(params).cmd())
59+
60+
async def add_listener(self, event, callback):
61+
event_name = event.event_class
62+
if event_name in self.listeners:
63+
return
64+
self.listeners[event_name] = self.conn.listen(event)
65+
try:
66+
async for event in self.listeners[event_name]:
67+
request_data = event.params
68+
if request_data.isBlocked:
69+
await callback(request_data)
70+
except trio.ClosedResourceError:
71+
pass
72+
73+
async def add_handler(self, event, handler, urlPatterns=None):
74+
event_name = event.event_class
75+
phase_name = event_name.split(".")[-1]
76+
77+
await self.conn.execute(session_subscribe(event_name))
78+
79+
params = AddInterceptParameters(phases=[phase_name], urlPatterns=urlPatterns)
80+
result = await self.bidi_network.add_intercept(params)
81+
intercept = result["intercept"]
82+
83+
self.intercepts[intercept]["event_name"] = event_name
84+
self.intercepts[intercept]["handlers"].append(handler)
85+
self.nursery.start_soon(self.add_listener, event, self.handle_events)
86+
return intercept
87+
88+
async def add_request_handler(self, handler, urlPatterns=None):
89+
intercept = await self.add_handler(BeforeRequestSent, handler, urlPatterns)
90+
return intercept
91+
92+
async def handle_events(self, event_params):
93+
if isinstance(event_params, BeforeRequestSentParameters):
94+
json = self.handle_requests(event_params)
95+
params = ContinueRequestParameters(**json)
96+
await self.bidi_network.continue_request(params)
97+
98+
def handle_requests(self, params):
99+
request = params.request
100+
for intercept in params.intercepts:
101+
for handler in self.intercepts[intercept]["handlers"]:
102+
request = handler(request)
103+
return request
104+
105+
async def remove_listener(self, event_name):
106+
listener = self.listeners.pop(event_name)
107+
listener.close()
108+
109+
async def remove_intercept(self, intercept):
110+
await self.bidi_network.remove_intercept(
111+
params=network.RemoveInterceptParameters(intercept),
112+
)
113+
event_name = self.intercepts.pop(intercept)["event_name"]
114+
remaining = [i for i in self.intercepts.values() if i["event_name"] == event_name]
115+
if len(remaining) == 0:
116+
await self.remove_listener(event_name)
117+
await self.conn.execute(session_unsubscribe(event_name))
118+
119+
async def clear_intercepts(self):
120+
for intercept in self.intercepts:
121+
await self.remove_intercept(intercept)
122+
123+
async def disable_cache(self):
124+
# Bidi 'network.setCacheBehavior' is not implemented in v130
125+
self.driver.execute_cdp_cmd("Network.setCacheDisabled", {"cacheDisabled": True})

py/selenium/webdriver/remote/webdriver.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from .file_detector import LocalFileDetector
6262
from .locator_converter import LocatorConverter
6363
from .mobile import Mobile
64+
from .network import Network
6465
from .remote_connection import RemoteConnection
6566
from .script_key import ScriptKey
6667
from .shadowroot import ShadowRoot
@@ -239,6 +240,7 @@ def __init__(
239240

240241
self._websocket_connection = None
241242
self._script = None
243+
self._network = None
242244

243245
def __repr__(self):
244246
return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>'
@@ -1090,6 +1092,13 @@ def script(self):
10901092

10911093
return self._script
10921094

1095+
@property
1096+
def network(self):
1097+
if not self._network:
1098+
self._network = Network(self)
1099+
1100+
return self._network
1101+
10931102
def _start_bidi(self):
10941103
if self.caps.get("webSocketUrl"):
10951104
ws_url = self.caps.get("webSocketUrl")
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Licensed to the Software Freedom Conservancy (SFC) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The SFC licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import pytest
18+
19+
from selenium.webdriver.common.bidi.network import UrlPatternString
20+
21+
22+
@pytest.mark.xfail_firefox
23+
@pytest.mark.xfail_safari
24+
@pytest.mark.xfail_edge
25+
async def test_request_handler(driver, pages):
26+
27+
url1 = pages.url("simpleTest.html")
28+
url2 = pages.url("clicks.html")
29+
url3 = pages.url("formPage.html")
30+
31+
pattern1 = [UrlPatternString(url1)]
32+
pattern2 = [UrlPatternString(url2)]
33+
34+
def request_handler(params):
35+
request = params["request"]
36+
json = {"request": request, "url": url3}
37+
return json
38+
39+
async with driver.network.set_context():
40+
# Multiple intercepts
41+
intercept1 = await driver.network.add_request_handler(request_handler, pattern1)
42+
intercept2 = await driver.network.add_request_handler(request_handler, pattern2)
43+
await driver.network.get(url1)
44+
assert driver.title == "We Leave From Here"
45+
await driver.network.get(url2)
46+
assert driver.title == "We Leave From Here"
47+
48+
# Removal of a single intercept
49+
await driver.network.remove_intercept(intercept2)
50+
await driver.network.get(url2)
51+
assert driver.title == "clicks"
52+
await driver.network.get(url1)
53+
assert driver.title == "We Leave From Here"
54+
55+
await driver.network.remove_intercept(intercept1)
56+
await driver.network.get(url1)
57+
assert driver.title == "Hello WebDriver"

0 commit comments

Comments
 (0)