Skip to content

Commit be6a568

Browse files
committed
[py] Implement add/remove_request_handler
1 parent aa883ce commit be6a568

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

py/selenium/webdriver/common/bidi/network.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,14 @@
1919
import typing
2020
from dataclasses import dataclass, fields, is_dataclass
2121

22+
from selenium.webdriver.common.bidi.cdp import import_devtools
23+
from selenium.webdriver.common.bidi.session import session_subscribe, session_unsubscribe
24+
2225
from . import script
2326

27+
devtools = import_devtools("")
28+
event_class = devtools.util.event_class
29+
2430

2531
@dataclass
2632
class StringValue:
@@ -500,3 +506,32 @@ def from_json(cls, json):
500506
def cmd(self):
501507
result = yield self.to_json()
502508
return result
509+
510+
511+
class Network:
512+
def __init__(self, conn):
513+
514+
self.conn = conn
515+
self.callbacks = {}
516+
517+
async def add_intercept(self, event, params: AddInterceptParameters):
518+
await self.conn.execute(session_subscribe(event.event_class))
519+
result = await self.conn.execute(AddIntercept(params).cmd())
520+
return result
521+
522+
async def add_listener(self, event, callback):
523+
listener = self.conn.listen(event)
524+
525+
async for event in listener:
526+
request_data = BeforeRequestSentParameters.from_json(
527+
event.to_json()["params"]
528+
)
529+
await callback(request_data)
530+
531+
async def continue_request(self, params: ContinueRequestParameters):
532+
result = await self.conn.execute(ContinueRequest(params).cmd())
533+
return result
534+
535+
async def remove_intercept(self, event, params: RemoveInterceptParameters):
536+
await self.conn.execute(session_unsubscribe(event.event_class))
537+
await self.conn.execute(RemoveIntercept(params).cmd())
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Licensed to the Software Freedom Conservancy (SFC) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The SFC licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from dataclasses import fields
18+
19+
import trio
20+
21+
from selenium.webdriver.common.bidi import network
22+
from selenium.webdriver.common.bidi.browsing_context import Navigate
23+
from selenium.webdriver.common.bidi.browsing_context import NavigateParameters
24+
from selenium.webdriver.common.bidi.network import AddInterceptParameters
25+
from selenium.webdriver.common.bidi.network import BeforeRequestSent
26+
from selenium.webdriver.common.bidi.network import BeforeRequestSentParameters
27+
from selenium.webdriver.common.bidi.network import ContinueRequestParameters
28+
29+
30+
def default_request_handler(params: BeforeRequestSentParameters):
31+
return ContinueRequestParameters(request=params.request["request"])
32+
33+
34+
class Network:
35+
def __init__(self, driver):
36+
self.network = None
37+
self.driver = driver
38+
self.intercept = None
39+
self.scope = None
40+
41+
async def add_request_handler(
42+
self, request_filter=lambda _: True, handler=default_request_handler, conn=None
43+
):
44+
with trio.CancelScope() as scope:
45+
self.scope = scope
46+
self.network = network.Network(conn)
47+
params = AddInterceptParameters(["beforeRequestSent"])
48+
callback = self._callback(request_filter, handler)
49+
result = await self.network.add_intercept(
50+
event=BeforeRequestSent, params=params
51+
)
52+
intercept = result["intercept"]
53+
self.intercept = intercept
54+
await self.network.add_listener(event=BeforeRequestSent, callback=callback)
55+
return intercept
56+
57+
async def get(self, url, conn):
58+
params = NavigateParameters(context=self.driver.current_window_handle, url=url)
59+
await conn.execute(Navigate(params).cmd())
60+
61+
async def remove_request_handler(self):
62+
await self.network.remove_intercept(
63+
event=BeforeRequestSent,
64+
params=network.RemoveInterceptParameters(self.intercept),
65+
)
66+
self.scope.cancel()
67+
68+
def _callback(self, request_filter, handler):
69+
async def callback(request):
70+
if request_filter(request):
71+
request = handler(request)
72+
else:
73+
request = default_request_handler(request)
74+
await self.network.continue_request(request)
75+
76+
return callback

py/selenium/webdriver/remote/webdriver.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from .file_detector import FileDetector
6060
from .file_detector import LocalFileDetector
6161
from .mobile import Mobile
62+
from .network import Network
6263
from .remote_connection import RemoteConnection
6364
from .script_key import ScriptKey
6465
from .shadowroot import ShadowRoot
@@ -213,6 +214,7 @@ def __init__(
213214

214215
self._websocket_connection = None
215216
self._script = None
217+
self._network = None
216218

217219
def __repr__(self):
218220
return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>'
@@ -1080,6 +1082,13 @@ def script(self):
10801082

10811083
return self._script
10821084

1085+
@property
1086+
def network(self):
1087+
if not self._network:
1088+
self._network = Network(self)
1089+
1090+
return self._network
1091+
10831092
def _start_bidi(self):
10841093
if self.caps.get("webSocketUrl"):
10851094
ws_url = self.caps.get("webSocketUrl")

0 commit comments

Comments
 (0)