Skip to content

Commit c674e06

Browse files
authored
Merge branch 'trunk' into rb-fix-lint
2 parents 69e5849 + a39a168 commit c674e06

File tree

12 files changed

+582
-42
lines changed

12 files changed

+582
-42
lines changed

py/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ def fin():
154154
if driver_instance is None:
155155
if driver_class == "Firefox":
156156
options = get_options(driver_class, request.config)
157+
# There are issues with window size/position when running Firefox
158+
# under Wayland, so we use XWayland instead.
159+
os.environ["MOZ_ENABLE_WAYLAND"] = "0"
157160
if driver_class == "Chrome":
158161
options = get_options(driver_class, request.config)
159162
if driver_class == "Edge":
@@ -166,6 +169,9 @@ def fin():
166169
options = get_options("Firefox", request.config) or webdriver.FirefoxOptions()
167170
options.set_capability("moz:firefoxOptions", {})
168171
options.enable_downloads = True
172+
# There are issues with window size/position when running Firefox
173+
# under Wayland, so we use XWayland instead.
174+
os.environ["MOZ_ENABLE_WAYLAND"] = "0"
169175
if driver_path is not None:
170176
kwargs["service"] = get_service(driver_class, driver_path)
171177
if options is not None:
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
# Licensed to the Software Freedom Conservancy (SFC) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The SFC licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
19+
class NetworkEvent:
20+
"""Represents a network event."""
21+
22+
def __init__(self, event_class, **kwargs):
23+
self.event_class = event_class
24+
self.params = kwargs
25+
26+
@classmethod
27+
def from_json(cls, json):
28+
return cls(event_class=json.get("event_class"), **json)
29+
30+
31+
class Network:
32+
EVENTS = {
33+
"before_request": "network.beforeRequestSent",
34+
"response_started": "network.responseStarted",
35+
"response_completed": "network.responseCompleted",
36+
"auth_required": "network.authRequired",
37+
"fetch_error": "network.fetchError",
38+
"continue_request": "network.continueRequest",
39+
"continue_auth": "network.continueWithAuth",
40+
}
41+
42+
PHASES = {
43+
"before_request": "beforeRequestSent",
44+
"response_started": "responseStarted",
45+
"auth_required": "authRequired",
46+
}
47+
48+
def __init__(self, conn):
49+
self.conn = conn
50+
self.intercepts = []
51+
self.callbacks = {}
52+
self.subscriptions = {}
53+
54+
def command_builder(self, method, params):
55+
"""Build a command iterator to send to the network.
56+
57+
Parameters:
58+
----------
59+
method (str): The method to execute.
60+
params (dict): The parameters to pass to the method.
61+
"""
62+
command = {"method": method, "params": params}
63+
cmd = yield command
64+
return cmd
65+
66+
def _add_intercept(self, phases=[], contexts=None, url_patterns=None):
67+
"""Add an intercept to the network.
68+
69+
Parameters:
70+
----------
71+
phases (list, optional): A list of phases to intercept.
72+
Default is empty list.
73+
contexts (list, optional): A list of contexts to intercept.
74+
Default is None.
75+
url_patterns (list, optional): A list of URL patterns to intercept.
76+
Default is None.
77+
78+
Returns:
79+
-------
80+
str : intercept id
81+
"""
82+
params = {}
83+
if contexts is not None:
84+
params["contexts"] = contexts
85+
if url_patterns is not None:
86+
params["urlPatterns"] = url_patterns
87+
if len(phases) > 0:
88+
params["phases"] = phases
89+
else:
90+
params["phases"] = ["beforeRequestSent"]
91+
cmd = self.command_builder("network.addIntercept", params)
92+
93+
result = self.conn.execute(cmd)
94+
self.intercepts.append(result["intercept"])
95+
return result
96+
97+
def _remove_intercept(self, intercept=None):
98+
"""Remove a specific intercept, or all intercepts.
99+
100+
Parameters:
101+
----------
102+
intercept (str, optional): The intercept to remove.
103+
Default is None.
104+
105+
Raises:
106+
------
107+
Exception: If intercept is not found.
108+
109+
Notes:
110+
-----
111+
If intercept is None, all intercepts will be removed.
112+
"""
113+
if intercept is None:
114+
intercepts_to_remove = self.intercepts.copy() # create a copy before iterating
115+
for intercept_id in intercepts_to_remove: # remove all intercepts
116+
self.conn.execute(self.command_builder("network.removeIntercept", {"intercept": intercept_id}))
117+
self.intercepts.remove(intercept_id)
118+
else:
119+
try:
120+
self.conn.execute(self.command_builder("network.removeIntercept", {"intercept": intercept}))
121+
self.intercepts.remove(intercept)
122+
except Exception as e:
123+
raise Exception(f"Exception: {e}")
124+
125+
def _on_request(self, event_name, callback):
126+
"""Set a callback function to subscribe to a network event.
127+
128+
Parameters:
129+
----------
130+
event_name (str): The event to subscribe to.
131+
callback (function): The callback function to execute on event.
132+
Takes Request object as argument.
133+
134+
Returns:
135+
-------
136+
int : callback id
137+
"""
138+
139+
event = NetworkEvent(event_name)
140+
141+
def _callback(event_data):
142+
request = Request(
143+
network=self,
144+
request_id=event_data.params["request"].get("request", None),
145+
body_size=event_data.params["request"].get("bodySize", None),
146+
cookies=event_data.params["request"].get("cookies", None),
147+
resource_type=event_data.params["request"].get("goog:resourceType", None),
148+
headers_size=event_data.params["request"].get("headersSize", None),
149+
timings=event_data.params["request"].get("timings", None),
150+
url=event_data.params["request"].get("url", None),
151+
)
152+
callback(request)
153+
154+
callback_id = self.conn.add_callback(event, _callback)
155+
156+
if event_name in self.callbacks:
157+
self.callbacks[event_name].append(callback_id)
158+
else:
159+
self.callbacks[event_name] = [callback_id]
160+
161+
return callback_id
162+
163+
def add_request_handler(self, event, callback, url_patterns=None, contexts=None):
164+
"""Add a request handler to the network.
165+
166+
Parameters:
167+
----------
168+
event (str): The event to subscribe to.
169+
url_patterns (list, optional): A list of URL patterns to intercept.
170+
Default is None.
171+
contexts (list, optional): A list of contexts to intercept.
172+
Default is None.
173+
callback (function): The callback function to execute on request interception
174+
Takes Request object as argument.
175+
176+
Returns:
177+
-------
178+
int : callback id
179+
"""
180+
181+
try:
182+
event_name = self.EVENTS[event]
183+
phase_name = self.PHASES[event]
184+
except KeyError:
185+
raise Exception(f"Event {event} not found")
186+
187+
result = self._add_intercept(phases=[phase_name], url_patterns=url_patterns, contexts=contexts)
188+
callback_id = self._on_request(event_name, callback)
189+
190+
if event_name in self.subscriptions:
191+
self.subscriptions[event_name].append(callback_id)
192+
else:
193+
params = {}
194+
params["events"] = [event_name]
195+
self.conn.execute(self.command_builder("session.subscribe", params))
196+
self.subscriptions[event_name] = [callback_id]
197+
198+
self.callbacks[callback_id] = result["intercept"]
199+
return callback_id
200+
201+
def remove_request_handler(self, event, callback_id):
202+
"""Remove a request handler from the network.
203+
204+
Parameters:
205+
----------
206+
event_name (str): The event to unsubscribe from.
207+
callback_id (int): The callback id to remove.
208+
"""
209+
try:
210+
event_name = self.EVENTS[event]
211+
except KeyError:
212+
raise Exception(f"Event {event} not found")
213+
214+
net_event = NetworkEvent(event_name)
215+
216+
self.conn.remove_callback(net_event, callback_id)
217+
self._remove_intercept(self.callbacks[callback_id])
218+
del self.callbacks[callback_id]
219+
self.subscriptions[event_name].remove(callback_id)
220+
if len(self.subscriptions[event_name]) == 0:
221+
params = {}
222+
params["events"] = [event_name]
223+
self.conn.execute(self.command_builder("session.unsubscribe", params))
224+
del self.subscriptions[event_name]
225+
226+
def clear_request_handlers(self):
227+
"""Clear all request handlers from the network."""
228+
229+
for event_name in self.subscriptions:
230+
net_event = NetworkEvent(event_name)
231+
for callback_id in self.subscriptions[event_name]:
232+
self.conn.remove_callback(net_event, callback_id)
233+
self._remove_intercept(self.callbacks[callback_id])
234+
del self.callbacks[callback_id]
235+
params = {}
236+
params["events"] = [event_name]
237+
self.conn.execute(self.command_builder("session.unsubscribe", params))
238+
self.subscriptions = {}
239+
240+
def add_auth_handler(self, username, password):
241+
"""Add an authentication handler to the network.
242+
243+
Parameters:
244+
----------
245+
username (str): The username to authenticate with.
246+
password (str): The password to authenticate with.
247+
248+
Returns:
249+
-------
250+
int : callback id
251+
"""
252+
event = "auth_required"
253+
254+
def _callback(request):
255+
request._continue_with_auth(username, password)
256+
257+
return self.add_request_handler(event, _callback)
258+
259+
def remove_auth_handler(self, callback_id):
260+
"""Remove an authentication handler from the network.
261+
262+
Parameters:
263+
----------
264+
callback_id (int): The callback id to remove.
265+
"""
266+
event = "auth_required"
267+
self.remove_request_handler(event, callback_id)
268+
269+
270+
class Request:
271+
"""Represents an intercepted network request."""
272+
273+
def __init__(
274+
self,
275+
network: Network,
276+
request_id,
277+
body_size=None,
278+
cookies=None,
279+
resource_type=None,
280+
headers=None,
281+
headers_size=None,
282+
method=None,
283+
timings=None,
284+
url=None,
285+
):
286+
self.network = network
287+
self.request_id = request_id
288+
self.body_size = body_size
289+
self.cookies = cookies
290+
self.resource_type = resource_type
291+
self.headers = headers
292+
self.headers_size = headers_size
293+
self.method = method
294+
self.timings = timings
295+
self.url = url
296+
297+
def command_builder(self, method, params):
298+
"""Build a command iterator to send to the network.
299+
300+
Parameters:
301+
----------
302+
method (str): The method to execute.
303+
params (dict): The parameters to pass to the method.
304+
"""
305+
command = {"method": method, "params": params}
306+
cmd = yield command
307+
return cmd
308+
309+
def fail_request(self):
310+
"""Fail this request."""
311+
312+
if not self.request_id:
313+
raise ValueError("Request not found.")
314+
315+
params = {"request": self.request_id}
316+
self.network.conn.execute(self.command_builder("network.failRequest", params))
317+
318+
def continue_request(self, body=None, method=None, headers=None, cookies=None, url=None):
319+
"""Continue after intercepting this request."""
320+
321+
if not self.request_id:
322+
raise ValueError("Request not found.")
323+
324+
params = {"request": self.request_id}
325+
if body is not None:
326+
params["body"] = body
327+
if method is not None:
328+
params["method"] = method
329+
if headers is not None:
330+
params["headers"] = headers
331+
if cookies is not None:
332+
params["cookies"] = cookies
333+
if url is not None:
334+
params["url"] = url
335+
336+
self.network.conn.execute(self.command_builder("network.continueRequest", params))
337+
338+
def _continue_with_auth(self, username=None, password=None):
339+
"""Continue with authentication.
340+
341+
Parameters:
342+
----------
343+
request (Request): The request to continue with.
344+
username (str): The username to authenticate with.
345+
password (str): The password to authenticate with.
346+
347+
Notes:
348+
-----
349+
If username or password is None, it attempts auth with no credentials
350+
"""
351+
352+
params = {}
353+
params["request"] = self.request_id
354+
355+
if not username or not password: # no credentials is valid option
356+
params["action"] = "default"
357+
else:
358+
params["action"] = "provideCredentials"
359+
params["credentials"] = {"type": "password", "username": username, "password": password}
360+
361+
self.network.conn.execute(self.command_builder("network.continueWithAuth", params))

0 commit comments

Comments
 (0)