1515# specific language governing permissions and limitations
1616# under the License.
1717from collections import defaultdict
18+ from contextlib import asynccontextmanager
1819
1920import trio
2021
2122from selenium .webdriver .common .bidi import network
2223from selenium .webdriver .common .bidi .browsing_context import Navigate
2324from selenium .webdriver .common .bidi .browsing_context import NavigateParameters
25+ from selenium .webdriver .common .bidi .cdp import open_cdp
2426from selenium .webdriver .common .bidi .network import AddInterceptParameters
2527from selenium .webdriver .common .bidi .network import BeforeRequestSent
2628from selenium .webdriver .common .bidi .network import BeforeRequestSentParameters
@@ -36,13 +38,24 @@ def __init__(self, driver):
3638 self .intercepts = defaultdict (lambda : {"event_name" : None , "handlers" : []})
3739 self .bidi_network = None
3840 self .conn = None
41+ self .nursery = None
3942
4043 self .remove_request_handler = self .remove_intercept
4144 self .clear_request_handlers = self .clear_intercepts
4245
43- async def get (self , url , conn , wait = "complete" ):
46+ @asynccontextmanager
47+ async def set_context (self ):
48+ ws_url = self .driver .caps .get ("webSocketUrl" )
49+ async with open_cdp (ws_url ) as conn :
50+ self .conn = conn
51+ self .bidi_network = network .Network (conn )
52+ async with trio .open_nursery () as nursery :
53+ self .nursery = nursery
54+ yield
55+
56+ async def get (self , url , wait = "complete" ):
4457 params = NavigateParameters (context = self .driver .current_window_handle , url = url , wait = wait )
45- await conn .execute (Navigate (params ).cmd ())
58+ await self . conn .execute (Navigate (params ).cmd ())
4659
4760 async def add_listener (self , event , callback ):
4861 event_name = event .event_class
@@ -57,11 +70,7 @@ async def add_listener(self, event, callback):
5770 except trio .ClosedResourceError :
5871 pass
5972
60- async def add_handler (self , event , handler , urlPatterns = None , conn = None , task_status = trio .TASK_STATUS_IGNORED ):
61- if not self .conn :
62- self .conn = conn
63- self .bidi_network = network .Network (conn )
64-
73+ async def add_handler (self , event , handler , urlPatterns = None ):
6574 event_name = event .event_class
6675 phase_name = event_name .split ("." )[- 1 ]
6776
@@ -73,11 +82,11 @@ async def add_handler(self, event, handler, urlPatterns=None, conn=None, task_st
7382
7483 self .intercepts [intercept ]["event_name" ] = event_name
7584 self .intercepts [intercept ]["handlers" ].append (handler )
76- task_status . started ( intercept )
77- await self . add_listener ( event = event , callback = self . handle_events )
85+ self . nursery . start_soon ( self . add_listener , event , self . handle_events )
86+ return intercept
7887
79- async def add_request_handler (self , handler , urlPatterns = None , conn = None , task_status = trio . TASK_STATUS_IGNORED ):
80- intercept = await self .add_handler (BeforeRequestSent , handler , urlPatterns , conn , task_status )
88+ async def add_request_handler (self , handler , urlPatterns = None ):
89+ intercept = await self .add_handler (BeforeRequestSent , handler , urlPatterns )
8190 return intercept
8291
8392 async def handle_events (self , event_params ):
0 commit comments