| 
14 | 14 | # KIND, either express or implied.  See the License for the  | 
15 | 15 | # specific language governing permissions and limitations  | 
16 | 16 | # under the License.  | 
 | 17 | +from collections import defaultdict  | 
 | 18 | + | 
17 | 19 | import trio  | 
18 | 20 | 
 
  | 
19 | 21 | from selenium.webdriver.common.bidi import network  | 
 | 
23 | 25 | from selenium.webdriver.common.bidi.network import BeforeRequestSent  | 
24 | 26 | from selenium.webdriver.common.bidi.network import BeforeRequestSentParameters  | 
25 | 27 | from selenium.webdriver.common.bidi.network import ContinueRequestParameters  | 
26 |  | - | 
27 |  | - | 
28 |  | -def default_request_handler(params: BeforeRequestSentParameters):  | 
29 |  | -    return ContinueRequestParameters(request=params.request["request"])  | 
 | 28 | +from selenium.webdriver.common.bidi.session import session_subscribe  | 
 | 29 | +from selenium.webdriver.common.bidi.session import session_unsubscribe  | 
30 | 30 | 
 
  | 
31 | 31 | 
 
  | 
32 | 32 | class Network:  | 
33 | 33 |     def __init__(self, driver):  | 
34 |  | -        self.network = None  | 
35 | 34 |         self.driver = driver  | 
36 |  | -        self.scopes = {}  | 
 | 35 | +        self.listeners = {}  | 
 | 36 | +        self.intercepts = defaultdict(lambda: {"event_name": None, "handlers": []})  | 
 | 37 | +        self.bidi_network = None  | 
37 | 38 |         self.conn = None  | 
38 | 39 | 
 
  | 
39 |  | -    async def get(self, url, conn):  | 
40 |  | -        params = NavigateParameters(context=self.driver.current_window_handle, url=url, wait="complete")  | 
 | 40 | +        self.remove_request_handler = self.remove_intercept  | 
 | 41 | +        self.clear_request_handlers = self.clear_intercepts  | 
 | 42 | + | 
 | 43 | +    async def get(self, url, conn, wait="complete"):  | 
 | 44 | +        params = NavigateParameters(context=self.driver.current_window_handle, url=url, wait=wait)  | 
41 | 45 |         await conn.execute(Navigate(params).cmd())  | 
42 | 46 | 
 
  | 
43 |  | -    def create_callback(self, request_filter, handler):  | 
44 |  | -        async def callback(request):  | 
45 |  | -            if request_filter(request):  | 
46 |  | -                request = handler(request)  | 
47 |  | -            else:  | 
48 |  | -                request = default_request_handler(request)  | 
49 |  | -            await self.network.continue_request(request)  | 
 | 47 | +    async def add_listener(self, event, callback):  | 
 | 48 | +        event_name = event.event_class  | 
 | 49 | +        if event_name in self.listeners:  | 
 | 50 | +            return  | 
 | 51 | +        self.listeners[event_name] = self.conn.listen(event)  | 
 | 52 | +        try:  | 
 | 53 | +            async for event in self.listeners[event_name]:  | 
 | 54 | +                request_data = event.params  | 
 | 55 | +                if request_data.isBlocked:  | 
 | 56 | +                    await callback(request_data)  | 
 | 57 | +        except trio.ClosedResourceError:  | 
 | 58 | +            pass  | 
50 | 59 | 
 
  | 
51 |  | -        return callback  | 
52 | 60 | 
 
  | 
53 |  | -    async def add_listener(self, event, callback):  | 
54 |  | -        listener = self.conn.listen(event)  | 
55 |  | - | 
56 |  | -        async for event in listener:  | 
57 |  | -            request_data = BeforeRequestSentParameters.from_json(event.to_json()["params"])  | 
58 |  | -            if request_data.isBlocked:  | 
59 |  | -                await callback(request_data)  | 
60 |  | - | 
61 |  | -    async def add_request_handler(  | 
62 |  | -        self,  | 
63 |  | -        request_filter=lambda _: True,  | 
64 |  | -        handler=default_request_handler,  | 
65 |  | -        conn=None,  | 
66 |  | -        task_status=trio.TASK_STATUS_IGNORED,  | 
67 |  | -    ):  | 
 | 61 | +    async def add_handler(self, event, handler, urlPatterns=None, conn=None, task_status=trio.TASK_STATUS_IGNORED):  | 
68 | 62 |         if not self.conn:  | 
69 | 63 |             self.conn = conn  | 
70 |  | -        with trio.CancelScope() as scope:  | 
71 |  | -            self.network = network.Network(conn)  | 
72 |  | -            params = AddInterceptParameters(["beforeRequestSent"])  | 
73 |  | -            result = await self.network.add_intercept(event=BeforeRequestSent, params=params)  | 
74 |  | -            intercept = result["intercept"]  | 
75 |  | -            self.scopes[intercept] = scope  | 
76 |  | -            task_status.started(intercept)  | 
77 |  | -            callback = self.create_callback(request,filter,handler)  | 
78 |  | -            await self.add_listener(event=BeforeRequestSent, callback=callback)  | 
79 |  | -            return intercept  | 
80 |  | - | 
81 |  | -    async def remove_request_handler(self, intercept):  | 
82 |  | -        await self.network.remove_intercept(  | 
83 |  | -            event=BeforeRequestSent,  | 
84 |  | -            params=network.RemoveInterceptParameters(self.intercept),  | 
 | 64 | +            self.bidi_network = network.Network(conn)  | 
 | 65 | + | 
 | 66 | +        event_name = event.event_class  | 
 | 67 | +        phase_name = event_name.split(".")[-1]  | 
 | 68 | + | 
 | 69 | +        await self.conn.execute(session_subscribe(event_name))  | 
 | 70 | + | 
 | 71 | +        params = AddInterceptParameters(phases=[phase_name], urlPatterns=urlPatterns)  | 
 | 72 | +        result = await self.bidi_network.add_intercept(params)  | 
 | 73 | +        intercept = result["intercept"]  | 
 | 74 | + | 
 | 75 | +        self.intercepts[intercept]["event_name"] = event_name  | 
 | 76 | +        self.intercepts[intercept]["handlers"].append(handler)  | 
 | 77 | +        task_status.started(intercept)  | 
 | 78 | +        await self.add_listener(event=event, callback=self.handle_events)  | 
 | 79 | + | 
 | 80 | +    async def add_request_handler(self, handler, urlPatterns=None, conn=None, task_status=trio.TASK_STATUS_IGNORED):  | 
 | 81 | +        intercept = await self.add_handler(BeforeRequestSent, handler, urlPatterns, conn, task_status)  | 
 | 82 | +        return intercept  | 
 | 83 | + | 
 | 84 | +    async def handle_events(self, event_params):  | 
 | 85 | +        if isinstance(event_params, BeforeRequestSentParameters):  | 
 | 86 | +            json = self.handle_requests(event_params)  | 
 | 87 | +            params = ContinueRequestParameters(**json)  | 
 | 88 | +            await self.bidi_network.continue_request(params)  | 
 | 89 | + | 
 | 90 | +    def handle_requests(self, params):  | 
 | 91 | +        request = params.request  | 
 | 92 | +        for intercept in params.intercepts:  | 
 | 93 | +            for handler in self.intercepts[intercept]["handlers"]:  | 
 | 94 | +                request = handler(request)  | 
 | 95 | +        return request  | 
 | 96 | + | 
 | 97 | +    async def remove_listener(self, event_name):  | 
 | 98 | +        listener = self.listeners.pop(event_name)  | 
 | 99 | +        listener.close()  | 
 | 100 | + | 
 | 101 | +    async def remove_intercept(self, intercept):  | 
 | 102 | +        await self.bidi_network.remove_intercept(  | 
 | 103 | +            params=network.RemoveInterceptParameters(intercept),  | 
85 | 104 |         )  | 
86 |  | -        self.scopes[intercept].cancel()  | 
87 |  | -        self.scopes.pop(intercept)  | 
 | 105 | +        event_name = self.intercepts.pop(intercept)["event_name"]  | 
 | 106 | +        remaining = [i for i in self.intercepts.values() if i["event_name"] == event_name]  | 
 | 107 | +        if len(remaining) == 0:  | 
 | 108 | +            await self.remove_listener(event_name)  | 
 | 109 | +            await self.conn.execute(session_unsubscribe(event_name))  | 
 | 110 | + | 
 | 111 | +    async def clear_intercepts(self):  | 
 | 112 | +        for intercept in self.intercepts:  | 
 | 113 | +            await self.remove_intercept(intercept)  | 
 | 114 | + | 
 | 115 | +    async def disable_cache(self):  | 
 | 116 | +        # Bidi 'network.setCacheBehavior' is not implemented in v130  | 
 | 117 | +        self.driver.execute_cdp_cmd("Network.setCacheDisabled", {"cacheDisabled": True})  | 
0 commit comments