Skip to content

Commit 9f535ce

Browse files
committed
Add support for multiple intercepts
1 parent 7d825b5 commit 9f535ce

File tree

3 files changed

+121
-74
lines changed

3 files changed

+121
-74
lines changed

py/selenium/webdriver/common/bidi/network.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,20 +201,29 @@ class RemoveIntercept(BidiCommand):
201201
method: typing.Literal["network.removeIntercept"] = "network.removeIntercept"
202202

203203

204+
@dataclass
205+
class SetCacheBehaviorParameters(BidiObject):
206+
cacheBehavior: typing.Literal["default", "bypass"]
207+
contexts: typing.Optional[typing.List[browsing_context.BrowsingContext]] = None
208+
209+
210+
@dataclass
211+
class SetCacheBehavior(BidiCommand):
212+
params: SetCacheBehaviorParameters
213+
method: typing.Literal["network.setCacheBehavior"] = "network.setCacheBehavior"
214+
215+
204216
class Network:
205217
def __init__(self, conn):
206218
self.conn = conn
207-
self.callbacks = {}
208219

209-
async def add_intercept(self, event, params: AddInterceptParameters):
210-
await self.conn.execute(session_subscribe(event.event_class))
220+
async def add_intercept(self, params: AddInterceptParameters):
211221
result = await self.conn.execute(AddIntercept(params).cmd())
212222
return result
213223

214224
async def continue_request(self, params: ContinueRequestParameters):
215225
result = await self.conn.execute(ContinueRequest(params).cmd())
216226
return result
217227

218-
async def remove_intercept(self, event, params: RemoveInterceptParameters):
219-
await self.conn.execute(session_unsubscribe(event.event_class))
228+
async def remove_intercept(self, params: RemoveInterceptParameters):
220229
await self.conn.execute(RemoveIntercept(params).cmd())

py/selenium/webdriver/remote/network.py

Lines changed: 78 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
from collections import defaultdict
18+
1719
import trio
1820

1921
from selenium.webdriver.common.bidi import network
@@ -23,65 +25,93 @@
2325
from selenium.webdriver.common.bidi.network import BeforeRequestSent
2426
from selenium.webdriver.common.bidi.network import BeforeRequestSentParameters
2527
from selenium.webdriver.common.bidi.network import ContinueRequestParameters
26-
27-
28-
def default_request_handler(params: BeforeRequestSentParameters):
29-
return ContinueRequestParameters(request=params.request["request"])
28+
from selenium.webdriver.common.bidi.session import session_subscribe
29+
from selenium.webdriver.common.bidi.session import session_unsubscribe
3030

3131

3232
class Network:
3333
def __init__(self, driver):
34-
self.network = None
3534
self.driver = driver
36-
self.scopes = {}
35+
self.listeners = {}
36+
self.intercepts = defaultdict(lambda: {"event_name": None, "handlers": []})
37+
self.bidi_network = None
3738
self.conn = None
3839

39-
async def get(self, url, conn):
40-
params = NavigateParameters(context=self.driver.current_window_handle, url=url, wait="complete")
40+
self.remove_request_handler = self.remove_intercept
41+
self.clear_request_handlers = self.clear_intercepts
42+
43+
async def get(self, url, conn, wait="complete"):
44+
params = NavigateParameters(context=self.driver.current_window_handle, url=url, wait=wait)
4145
await conn.execute(Navigate(params).cmd())
4246

43-
def create_callback(self, request_filter, handler):
44-
async def callback(request):
45-
if request_filter(request):
46-
request = handler(request)
47-
else:
48-
request = default_request_handler(request)
49-
await self.network.continue_request(request)
47+
async def add_listener(self, event, callback):
48+
event_name = event.event_class
49+
if event_name in self.listeners:
50+
return
51+
self.listeners[event_name] = self.conn.listen(event)
52+
try:
53+
async for event in self.listeners[event_name]:
54+
request_data = event.params
55+
if request_data.isBlocked:
56+
await callback(request_data)
57+
except trio.ClosedResourceError:
58+
pass
5059

51-
return callback
5260

53-
async def add_listener(self, event, callback):
54-
listener = self.conn.listen(event)
55-
56-
async for event in listener:
57-
request_data = BeforeRequestSentParameters.from_json(event.to_json()["params"])
58-
if request_data.isBlocked:
59-
await callback(request_data)
60-
61-
async def add_request_handler(
62-
self,
63-
request_filter=lambda _: True,
64-
handler=default_request_handler,
65-
conn=None,
66-
task_status=trio.TASK_STATUS_IGNORED,
67-
):
61+
async def add_handler(self, event, handler, urlPatterns=None, conn=None, task_status=trio.TASK_STATUS_IGNORED):
6862
if not self.conn:
6963
self.conn = conn
70-
with trio.CancelScope() as scope:
71-
self.network = network.Network(conn)
72-
params = AddInterceptParameters(["beforeRequestSent"])
73-
result = await self.network.add_intercept(event=BeforeRequestSent, params=params)
74-
intercept = result["intercept"]
75-
self.scopes[intercept] = scope
76-
task_status.started(intercept)
77-
callback = self.create_callback(request,filter,handler)
78-
await self.add_listener(event=BeforeRequestSent, callback=callback)
79-
return intercept
80-
81-
async def remove_request_handler(self, intercept):
82-
await self.network.remove_intercept(
83-
event=BeforeRequestSent,
84-
params=network.RemoveInterceptParameters(self.intercept),
64+
self.bidi_network = network.Network(conn)
65+
66+
event_name = event.event_class
67+
phase_name = event_name.split(".")[-1]
68+
69+
await self.conn.execute(session_subscribe(event_name))
70+
71+
params = AddInterceptParameters(phases=[phase_name], urlPatterns=urlPatterns)
72+
result = await self.bidi_network.add_intercept(params)
73+
intercept = result["intercept"]
74+
75+
self.intercepts[intercept]["event_name"] = event_name
76+
self.intercepts[intercept]["handlers"].append(handler)
77+
task_status.started(intercept)
78+
await self.add_listener(event=event, callback=self.handle_events)
79+
80+
async def add_request_handler(self, handler, urlPatterns=None, conn=None, task_status=trio.TASK_STATUS_IGNORED):
81+
intercept = await self.add_handler(BeforeRequestSent, handler, urlPatterns, conn, task_status)
82+
return intercept
83+
84+
async def handle_events(self, event_params):
85+
if isinstance(event_params, BeforeRequestSentParameters):
86+
json = self.handle_requests(event_params)
87+
params = ContinueRequestParameters(**json)
88+
await self.bidi_network.continue_request(params)
89+
90+
def handle_requests(self, params):
91+
request = params.request
92+
for intercept in params.intercepts:
93+
for handler in self.intercepts[intercept]["handlers"]:
94+
request = handler(request)
95+
return request
96+
97+
async def remove_listener(self, event_name):
98+
listener = self.listeners.pop(event_name)
99+
listener.close()
100+
101+
async def remove_intercept(self, intercept):
102+
await self.bidi_network.remove_intercept(
103+
params=network.RemoveInterceptParameters(intercept),
85104
)
86-
self.scopes[intercept].cancel()
87-
self.scopes.pop(intercept)
105+
event_name = self.intercepts.pop(intercept)["event_name"]
106+
remaining = [i for i in self.intercepts.values() if i["event_name"] == event_name]
107+
if len(remaining) == 0:
108+
await self.remove_listener(event_name)
109+
await self.conn.execute(session_unsubscribe(event_name))
110+
111+
async def clear_intercepts(self):
112+
for intercept in self.intercepts:
113+
await self.remove_intercept(intercept)
114+
115+
async def disable_cache(self):
116+
# Bidi 'network.setCacheBehavior' is not implemented in v130
117+
self.driver.execute_cdp_cmd("Network.setCacheDisabled", {"cacheDisabled": True})

py/test/selenium/webdriver/common/bidi_network_tests.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,44 @@
2020
from selenium.webdriver.common.bidi.cdp import open_cdp
2121
from selenium.webdriver.common.bidi.network import BeforeRequestSentParameters
2222
from selenium.webdriver.common.bidi.network import ContinueRequestParameters
23+
from selenium.webdriver.common.bidi.network import UrlPatternString
24+
from selenium.webdriver.common.bidi.browsing_context import BrowsingContext
2325

2426

2527
@pytest.mark.xfail_firefox
2628
@pytest.mark.xfail_safari
2729
@pytest.mark.xfail_edge
28-
async def test_add_request_handler(driver, pages):
30+
async def test_request_handler(driver, pages):
2931

30-
target = pages.url("simpleTest.html")
32+
url1 = pages.url("simpleTest.html")
33+
url2 = pages.url("clicks.html")
34+
url3 = pages.url("formPage.html")
3135

32-
def request_filter(params: BeforeRequestSentParameters):
33-
return params.request["url"] == target
36+
pattern1 = [UrlPatternString(url1)]
37+
pattern2 = [UrlPatternString(url2)]
3438

35-
def request_handler(params: BeforeRequestSentParameters):
36-
request = params.request["request"]
37-
json = {"request": request, "url": pages.url("formPage.html")}
38-
return ContinueRequestParameters(**json)
39+
def request_handler(params):
40+
request = params["request"]
41+
json = {"request": request, "url": url3}
42+
return json
3943

4044
ws_url = driver.caps.get("webSocketUrl")
4145
async with open_cdp(ws_url) as conn:
4246
async with trio.open_nursery() as nursery:
43-
nursery.start_soon(
44-
driver.network.add_request_handler,
45-
request_filter,
46-
request_handler,
47-
conn,
48-
)
49-
await trio.sleep(1)
50-
await driver.network.get(target, conn)
51-
assert "We Leave From Here" == driver.title
52-
await trio.sleep(1)
53-
await driver.network.remove_request_handler()
54-
await driver.network.get(target, conn)
55-
assert "Hello WebDriver" == driver.title
47+
# Multiple intercepts
48+
intercept1 = await nursery.start(driver.network.add_request_handler, request_handler, pattern1, conn)
49+
intercept2 = await nursery.start(driver.network.add_request_handler, request_handler, pattern2, conn)
50+
await driver.network.get(url1, conn)
51+
assert driver.title == "We Leave From Here"
52+
await driver.network.get(url2, conn)
53+
assert driver.title == "We Leave From Here"
54+
55+
# Removal of a single intercept
56+
await driver.network.remove_intercept(intercept2)
57+
await driver.network.get(url2, conn)
58+
assert driver.title == "clicks"
59+
await driver.network.get(url1, conn)
60+
assert driver.title == "We Leave From Here"
61+
62+
await driver.network.remove_intercept(intercept1)
63+
assert driver.title == "We Leave From Here"

0 commit comments

Comments
 (0)