| 
 | 1 | +# Licensed to the Software Freedom Conservancy (SFC) under one  | 
 | 2 | +# or more contributor license agreements.  See the NOTICE file  | 
 | 3 | +# distributed with this work for additional information  | 
 | 4 | +# regarding copyright ownership.  The SFC licenses this file  | 
 | 5 | +# to you under the Apache License, Version 2.0 (the  | 
 | 6 | +# "License"); you may not use this file except in compliance  | 
 | 7 | +# with the License.  You may obtain a copy of the License at  | 
 | 8 | +#  | 
 | 9 | +#   http://www.apache.org/licenses/LICENSE-2.0  | 
 | 10 | +#  | 
 | 11 | +# Unless required by applicable law or agreed to in writing,  | 
 | 12 | +# software distributed under the License is distributed on an  | 
 | 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY  | 
 | 14 | +# KIND, either express or implied.  See the License for the  | 
 | 15 | +# specific language governing permissions and limitations  | 
 | 16 | +# under the License.  | 
 | 17 | +from collections import defaultdict  | 
 | 18 | +from contextlib import asynccontextmanager  | 
 | 19 | + | 
 | 20 | +import trio  | 
 | 21 | + | 
 | 22 | +from selenium.webdriver.common.bidi import network  | 
 | 23 | +from selenium.webdriver.common.bidi.browsing_context import Navigate  | 
 | 24 | +from selenium.webdriver.common.bidi.browsing_context import NavigateParameters  | 
 | 25 | +from selenium.webdriver.common.bidi.cdp import open_cdp  | 
 | 26 | +from selenium.webdriver.common.bidi.network import AddInterceptParameters  | 
 | 27 | +from selenium.webdriver.common.bidi.network import BeforeRequestSent  | 
 | 28 | +from selenium.webdriver.common.bidi.network import BeforeRequestSentParameters  | 
 | 29 | +from selenium.webdriver.common.bidi.network import ContinueRequestParameters  | 
 | 30 | +from selenium.webdriver.common.bidi.session import session_subscribe  | 
 | 31 | +from selenium.webdriver.common.bidi.session import session_unsubscribe  | 
 | 32 | + | 
 | 33 | + | 
 | 34 | +class Network:  | 
 | 35 | +    def __init__(self, driver):  | 
 | 36 | +        self.driver = driver  | 
 | 37 | +        self.listeners = {}  | 
 | 38 | +        self.intercepts = defaultdict(lambda: {"event_name": None, "handlers": []})  | 
 | 39 | +        self.bidi_network = None  | 
 | 40 | +        self.conn = None  | 
 | 41 | +        self.nursery = None  | 
 | 42 | + | 
 | 43 | +        self.remove_request_handler = self.remove_intercept  | 
 | 44 | +        self.clear_request_handlers = self.clear_intercepts  | 
 | 45 | + | 
 | 46 | +    @asynccontextmanager  | 
 | 47 | +    async def set_context(self):  | 
 | 48 | +        ws_url = self.driver.caps.get("webSocketUrl")  | 
 | 49 | +        async with open_cdp(ws_url) as conn:  | 
 | 50 | +            self.conn = conn  | 
 | 51 | +            self.bidi_network = network.Network(self.conn)  | 
 | 52 | +            async with trio.open_nursery() as nursery:  | 
 | 53 | +                self.nursery = nursery  | 
 | 54 | +                yield  | 
 | 55 | + | 
 | 56 | +    async def get(self, url, wait="complete"):  | 
 | 57 | +        params = NavigateParameters(context=self.driver.current_window_handle, url=url, wait=wait)  | 
 | 58 | +        await self.conn.execute(Navigate(params).cmd())  | 
 | 59 | + | 
 | 60 | +    async def add_listener(self, event, callback):  | 
 | 61 | +        event_name = event.event_class  | 
 | 62 | +        if event_name in self.listeners:  | 
 | 63 | +            return  | 
 | 64 | +        self.listeners[event_name] = self.conn.listen(event)  | 
 | 65 | +        try:  | 
 | 66 | +            async for event in self.listeners[event_name]:  | 
 | 67 | +                request_data = event.params  | 
 | 68 | +                if request_data.isBlocked:  | 
 | 69 | +                    await callback(request_data)  | 
 | 70 | +        except trio.ClosedResourceError:  | 
 | 71 | +            pass  | 
 | 72 | + | 
 | 73 | +    async def add_handler(self, event, handler, urlPatterns=None):  | 
 | 74 | +        event_name = event.event_class  | 
 | 75 | +        phase_name = event_name.split(".")[-1]  | 
 | 76 | + | 
 | 77 | +        await self.conn.execute(session_subscribe(event_name))  | 
 | 78 | + | 
 | 79 | +        params = AddInterceptParameters(phases=[phase_name], urlPatterns=urlPatterns)  | 
 | 80 | +        result = await self.bidi_network.add_intercept(params)  | 
 | 81 | +        intercept = result["intercept"]  | 
 | 82 | + | 
 | 83 | +        self.intercepts[intercept]["event_name"] = event_name  | 
 | 84 | +        self.intercepts[intercept]["handlers"].append(handler)  | 
 | 85 | +        self.nursery.start_soon(self.add_listener, event, self.handle_events)  | 
 | 86 | +        return intercept  | 
 | 87 | + | 
 | 88 | +    async def add_request_handler(self, handler, urlPatterns=None):  | 
 | 89 | +        intercept = await self.add_handler(BeforeRequestSent, handler, urlPatterns)  | 
 | 90 | +        return intercept  | 
 | 91 | + | 
 | 92 | +    async def handle_events(self, event_params):  | 
 | 93 | +        if isinstance(event_params, BeforeRequestSentParameters):  | 
 | 94 | +            json = self.handle_requests(event_params)  | 
 | 95 | +            params = ContinueRequestParameters(**json)  | 
 | 96 | +            await self.bidi_network.continue_request(params)  | 
 | 97 | + | 
 | 98 | +    def handle_requests(self, params):  | 
 | 99 | +        request = params.request  | 
 | 100 | +        for intercept in params.intercepts:  | 
 | 101 | +            for handler in self.intercepts[intercept]["handlers"]:  | 
 | 102 | +                request = handler(request)  | 
 | 103 | +        return request  | 
 | 104 | + | 
 | 105 | +    async def remove_listener(self, event_name):  | 
 | 106 | +        listener = self.listeners.pop(event_name)  | 
 | 107 | +        listener.close()  | 
 | 108 | + | 
 | 109 | +    async def remove_intercept(self, intercept):  | 
 | 110 | +        await self.bidi_network.remove_intercept(  | 
 | 111 | +            params=network.RemoveInterceptParameters(intercept),  | 
 | 112 | +        )  | 
 | 113 | +        event_name = self.intercepts.pop(intercept)["event_name"]  | 
 | 114 | +        remaining = [i for i in self.intercepts.values() if i["event_name"] == event_name]  | 
 | 115 | +        if len(remaining) == 0:  | 
 | 116 | +            await self.remove_listener(event_name)  | 
 | 117 | +            await self.conn.execute(session_unsubscribe(event_name))  | 
 | 118 | + | 
 | 119 | +    async def clear_intercepts(self):  | 
 | 120 | +        for intercept in self.intercepts:  | 
 | 121 | +            await self.remove_intercept(intercept)  | 
 | 122 | + | 
 | 123 | +    async def disable_cache(self):  | 
 | 124 | +        # Bidi 'network.setCacheBehavior' is not implemented in v130  | 
 | 125 | +        self.driver.execute_cdp_cmd("Network.setCacheDisabled", {"cacheDisabled": True})  | 
0 commit comments