1515# specific language governing permissions and limitations
1616# under the License.
1717from collections import defaultdict
18- from contextlib import asynccontextmanager
19-
20- import trio
2118
2219from selenium .webdriver .common .bidi import network
2320from selenium .webdriver .common .bidi .browsing_context import Navigate
2421from selenium .webdriver .common .bidi .browsing_context import NavigateParameters
25- from selenium .webdriver .common .bidi .cdp import open_cdp
2622from selenium .webdriver .common .bidi .network import AddInterceptParameters
2723from selenium .webdriver .common .bidi .network import BeforeRequestSent
2824from selenium .webdriver .common .bidi .network import BeforeRequestSentParameters
3228
3329
3430class Network :
35- def __init__ (self , driver ):
31+ def __init__ (self , conn , driver ):
32+ self .intercepts = defaultdict (lambda : {"event" : None , "handlers" : []})
33+ self .callback_ids = {}
3634 self .driver = driver
37- self .listeners = {}
38- self .intercepts = defaultdict (lambda : {"event_name" : None , "handlers" : []})
39- self .bidi_network = None
40- self .conn = None
41- self .nursery = None
35+ self .conn = conn
36+ self .bidi_network = network .Network (self .conn )
4237
4338 self .remove_request_handler = self .remove_intercept
4439 self .clear_request_handlers = self .clear_intercepts
4540
46- @asynccontextmanager
47- async def set_context (self ):
48- ws_url = self .driver .caps .get ("webSocketUrl" )
49- async with open_cdp (ws_url ) as conn :
50- self .conn = conn
51- self .bidi_network = network .Network (self .conn )
52- async with trio .open_nursery () as nursery :
53- self .nursery = nursery
54- yield
55-
56- async def get (self , url , wait = "complete" ):
41+ def get (self , url , wait = "none" ):
5742 params = NavigateParameters (context = self .driver .current_window_handle , url = url , wait = wait )
58- await self .conn .execute (Navigate (params ).cmd ())
43+ self .conn .execute (Navigate (params ).cmd ())
5944
60- async def add_listener (self , event , callback ):
61- event_name = event .event_class
62- if event_name in self .listeners :
63- return
64- self .listeners [event_name ] = self .conn .listen (event )
65- try :
66- async for event in self .listeners [event_name ]:
67- request_data = event .params
68- if request_data .isBlocked :
69- await callback (request_data )
70- except trio .ClosedResourceError :
71- pass
72-
73- async def add_handler (self , event , handler , urlPatterns = None ):
45+ def add_handler (self , event , handler , urlPatterns = None ):
7446 event_name = event .event_class
7547 phase_name = event_name .split ("." )[- 1 ]
7648
77- await self .conn .execute (session_subscribe (event_name ))
49+ self .conn .execute (session_subscribe (event_name ))
7850
7951 params = AddInterceptParameters (phases = [phase_name ], urlPatterns = urlPatterns )
80- result = await self .bidi_network .add_intercept (params )
52+ result = self .bidi_network .add_intercept (params )
8153 intercept = result ["intercept" ]
8254
83- self .intercepts [intercept ]["event_name " ] = event_name
55+ self .intercepts [intercept ]["event " ] = event
8456 self .intercepts [intercept ]["handlers" ].append (handler )
85- self .nursery .start_soon (self .add_listener , event , self .handle_events )
57+ if not self .callback_ids .get (event_name ):
58+ self .callback_ids [event_name ] = self .conn .add_callback (event , self .handle_events )
8659 return intercept
8760
88- async def add_request_handler (self , handler , urlPatterns = None ):
89- intercept = await self .add_handler (BeforeRequestSent , handler , urlPatterns )
61+ def add_request_handler (self , handler , urlPatterns = None ):
62+ intercept = self .add_handler (BeforeRequestSent , handler , urlPatterns )
9063 return intercept
9164
92- async def handle_events (self , event_params ):
93- if isinstance (event_params , BeforeRequestSentParameters ):
65+ def handle_events (self , event ):
66+ event_params = event .params
67+ if isinstance (event_params , BeforeRequestSentParameters ) and event_params .isBlocked :
9468 json = self .handle_requests (event_params )
9569 params = ContinueRequestParameters (** json )
96- await self .bidi_network .continue_request (params )
70+ self .bidi_network .continue_request (params )
9771
9872 def handle_requests (self , params ):
9973 request = params .request
@@ -102,24 +76,19 @@ def handle_requests(self, params):
10276 request = handler (request )
10377 return request
10478
105- async def remove_listener (self , event_name ):
106- listener = self .listeners .pop (event_name )
107- listener .close ()
108-
109- async def remove_intercept (self , intercept ):
110- await self .bidi_network .remove_intercept (
79+ def remove_intercept (self , intercept ):
80+ self .bidi_network .remove_intercept (
11181 params = network .RemoveInterceptParameters (intercept ),
11282 )
113- event_name = self .intercepts .pop (intercept )["event_name" ]
114- remaining = [i for i in self .intercepts .values () if i ["event_name" ] == event_name ]
83+
84+ event = self .intercepts .pop (intercept )["event" ]
85+ event_name = event .event_class
86+
87+ remaining = [i for i in self .intercepts .values () if i ["event" ].event_class == event_name ]
11588 if len (remaining ) == 0 :
116- await self .remove_listener ( event_name )
117- await self .conn .execute ( session_unsubscribe ( event_name ) )
89+ self .conn . execute ( session_unsubscribe ( event_name ) )
90+ self .conn .remove_callback ( event , self . callback_ids [ event_name ] )
11891
119- async def clear_intercepts (self ):
92+ def clear_intercepts (self ):
12093 for intercept in self .intercepts :
121- await self .remove_intercept (intercept )
122-
123- async def disable_cache (self ):
124- # Bidi 'network.setCacheBehavior' is not implemented in v130
125- self .driver .execute_cdp_cmd ("Network.setCacheDisabled" , {"cacheDisabled" : True })
94+ self .remove_intercept (intercept )
0 commit comments