Skip to content

Commit 5d33195

Browse files
authored
chore: rework apiName api call handling in connection (#1361)
1 parent 724f4e0 commit 5d33195

File tree

11 files changed

+1698
-2858
lines changed

11 files changed

+1698
-2858
lines changed

playwright/_impl/_async_base.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
# limitations under the License.
1414

1515
import asyncio
16-
import traceback
1716
from types import TracebackType
18-
from typing import Any, Awaitable, Callable, Generic, Type, TypeVar
17+
from typing import Any, Callable, Generic, Type, TypeVar
1918

2019
from playwright._impl._impl_to_api_mapping import ImplToApiMapping, ImplWrapper
2120

@@ -62,12 +61,6 @@ def __init__(self, impl_obj: Any) -> None:
6261
def __str__(self) -> str:
6362
return self._impl_obj.__str__()
6463

65-
def _async(self, api_name: str, coro: Awaitable) -> Any:
66-
task = asyncio.current_task()
67-
setattr(task, "__pw_api_name__", api_name)
68-
setattr(task, "__pw_stack_trace__", traceback.extract_stack())
69-
return coro
70-
7164
def _wrap_handler(self, handler: Any) -> Callable[..., None]:
7265
if callable(handler):
7366
return mapping.wrap_handler(handler)

playwright/_impl/_connection.py

Lines changed: 82 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import contextvars
17+
import inspect
1618
import sys
1719
import traceback
1820
from pathlib import Path
19-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
21+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast
2022

2123
from greenlet import greenlet
2224
from pyee import AsyncIOEventEmitter, EventEmitter
2325

26+
import playwright
2427
from playwright._impl._helper import ParsedMessagePayload, parse_error
2528
from playwright._impl._transport import Transport
2629

@@ -36,10 +39,21 @@ def __init__(self, connection: "Connection", guid: str) -> None:
3639
self._object: Optional[ChannelOwner] = None
3740

3841
async def send(self, method: str, params: Dict = None) -> Any:
39-
return await self.inner_send(method, params, False)
42+
return await self._connection.wrap_api_call(
43+
lambda: self.inner_send(method, params, False)
44+
)
4045

4146
async def send_return_as_dict(self, method: str, params: Dict = None) -> Any:
42-
return await self.inner_send(method, params, True)
47+
return await self._connection.wrap_api_call(
48+
lambda: self.inner_send(method, params, True)
49+
)
50+
51+
def send_no_reply(self, method: str, params: Dict = None) -> None:
52+
self._connection.wrap_api_call(
53+
lambda: self._connection._send_message_to_server(
54+
self._guid, method, {} if params is None else params
55+
)
56+
)
4357

4458
async def inner_send(
4559
self, method: str, params: Optional[Dict], return_as_dict: bool
@@ -74,11 +88,6 @@ async def inner_send(
7488
key = next(iter(result))
7589
return result[key]
7690

77-
def send_no_reply(self, method: str, params: Dict = None) -> None:
78-
if params is None:
79-
params = {}
80-
self._connection._send_message_to_server(self._guid, method, params)
81-
8291

8392
class ChannelOwner(AsyncIOEventEmitter):
8493
def __init__(
@@ -122,7 +131,7 @@ def _dispose(self) -> None:
122131

123132
class ProtocolCallback:
124133
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
125-
self.stack_trace: traceback.StackSummary = traceback.StackSummary()
134+
self.stack_trace: traceback.StackSummary
126135
self.future = loop.create_future()
127136
# The outer task can get cancelled by the user, this forwards the cancellation to the inner task.
128137
current_task = asyncio.current_task()
@@ -181,6 +190,9 @@ def __init__(
181190
self._error: Optional[BaseException] = None
182191
self.is_remote = False
183192
self._init_task: Optional[asyncio.Task] = None
193+
self._api_zone: contextvars.ContextVar[Optional[Dict]] = contextvars.ContextVar(
194+
"ApiZone", default=None
195+
)
184196

185197
def mark_as_remote(self) -> None:
186198
self.is_remote = True
@@ -230,22 +242,17 @@ def _send_message_to_server(
230242
id = self._last_id
231243
callback = ProtocolCallback(self._loop)
232244
task = asyncio.current_task(self._loop)
233-
stack_trace: Optional[traceback.StackSummary] = getattr(
234-
task, "__pw_stack_trace__", None
245+
callback.stack_trace = cast(
246+
traceback.StackSummary,
247+
getattr(task, "__pw_stack_trace__", traceback.extract_stack()),
235248
)
236-
callback.stack_trace = stack_trace or traceback.extract_stack()
237249
self._callbacks[id] = callback
238-
metadata = {"stack": serialize_call_stack(callback.stack_trace)}
239-
api_name = getattr(task, "__pw_api_name__", None)
240-
if api_name:
241-
metadata["apiName"] = api_name
242-
243250
message = {
244251
"id": id,
245252
"guid": guid,
246253
"method": method,
247254
"params": self._replace_channels_with_guids(params),
248-
"metadata": metadata,
255+
"metadata": self._api_zone.get(),
249256
}
250257
self._transport.send(message)
251258
self._callbacks[id] = callback
@@ -337,6 +344,27 @@ def _replace_guids_with_channels(self, payload: Any) -> Any:
337344
return result
338345
return payload
339346

347+
def wrap_api_call(self, cb: Callable[[], Any], is_internal: bool = False) -> Any:
348+
if self._api_zone.get():
349+
return cb()
350+
task = asyncio.current_task(self._loop)
351+
st: List[inspect.FrameInfo] = getattr(task, "__pw_stack__", inspect.stack())
352+
metadata = _extract_metadata_from_stack(st, is_internal)
353+
if metadata:
354+
self._api_zone.set(metadata)
355+
result = cb()
356+
357+
async def _() -> None:
358+
try:
359+
return await result
360+
finally:
361+
self._api_zone.set(None)
362+
363+
if asyncio.iscoroutine(result):
364+
return _()
365+
self._api_zone.set(None)
366+
return result
367+
340368

341369
def from_channel(channel: Channel) -> Any:
342370
return channel._object
@@ -346,13 +374,40 @@ def from_nullable_channel(channel: Optional[Channel]) -> Optional[Any]:
346374
return channel._object if channel else None
347375

348376

349-
def serialize_call_stack(stack_trace: traceback.StackSummary) -> List[Dict]:
377+
def _extract_metadata_from_stack(
378+
st: List[inspect.FrameInfo], is_internal: bool
379+
) -> Optional[Dict]:
380+
playwright_module_path = str(Path(playwright.__file__).parents[0])
381+
last_internal_api_name = ""
382+
api_name = ""
350383
stack: List[Dict] = []
351-
for frame in stack_trace:
352-
if "_generated.py" in frame.filename:
353-
break
354-
stack.append(
355-
{"file": frame.filename, "line": frame.lineno, "function": frame.name}
356-
)
357-
stack.reverse()
358-
return stack
384+
for frame in st:
385+
is_playwright_internal = frame.filename.startswith(playwright_module_path)
386+
387+
method_name = ""
388+
if "self" in frame[0].f_locals:
389+
method_name = frame[0].f_locals["self"].__class__.__name__ + "."
390+
method_name += frame[0].f_code.co_name
391+
392+
if not is_playwright_internal:
393+
stack.append(
394+
{
395+
"file": frame.filename,
396+
"line": frame.lineno,
397+
"function": method_name,
398+
}
399+
)
400+
if is_playwright_internal:
401+
last_internal_api_name = method_name
402+
elif last_internal_api_name:
403+
api_name = last_internal_api_name
404+
last_internal_api_name = ""
405+
if not api_name:
406+
api_name = last_internal_api_name
407+
if api_name:
408+
return {
409+
"apiName": api_name,
410+
"stack": stack,
411+
"isInternal": is_internal,
412+
}
413+
return None

playwright/_impl/_sync_base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import inspect
1617
import traceback
1718
from types import TracebackType
1819
from typing import Any, Awaitable, Callable, Dict, Generic, List, Type, TypeVar, cast
@@ -74,11 +75,11 @@ def __init__(self, impl_obj: Any) -> None:
7475
def __str__(self) -> str:
7576
return self._impl_obj.__str__()
7677

77-
def _sync(self, api_name: str, coro: Awaitable) -> Any:
78+
def _sync(self, coro: Awaitable) -> Any:
7879
__tracebackhide__ = True
7980
g_self = greenlet.getcurrent()
8081
task = self._loop.create_task(coro)
81-
setattr(task, "__pw_api_name__", api_name)
82+
setattr(task, "__pw_stack__", inspect.stack())
8283
setattr(task, "__pw_stack_trace__", traceback.extract_stack())
8384

8485
task.add_done_callback(lambda _: g_self.switch())
@@ -147,7 +148,7 @@ def __exit__(
147148
self,
148149
exc_type: Type[BaseException],
149150
exc_val: BaseException,
150-
traceback: TracebackType,
151+
_traceback: TracebackType,
151152
) -> None:
152153
self.close()
153154

playwright/_impl/_wait_helper.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,19 @@ def _wait_for_event_info_before(self, wait_id: str, event: str) -> None:
4848
)
4949

5050
def _wait_for_event_info_after(self, wait_id: str, error: Exception = None) -> None:
51-
try:
52-
info = {
53-
"waitId": wait_id,
54-
"phase": "after",
55-
}
56-
if error:
57-
info["error"] = str(error)
58-
self._channel.send_no_reply(
51+
self._channel._connection.wrap_api_call(
52+
lambda: self._channel.send_no_reply(
5953
"waitForEventInfo",
6054
{
61-
"info": info,
55+
"info": {
56+
"waitId": wait_id,
57+
"phase": "after",
58+
**({"error": str(error)} if error else {}),
59+
},
6260
},
63-
)
64-
except Exception:
65-
pass
61+
),
62+
True,
63+
)
6664

6765
def reject_on_event(
6866
self,
@@ -129,15 +127,18 @@ def result(self) -> asyncio.Future:
129127
def log(self, message: str) -> None:
130128
self._logs.append(message)
131129
try:
132-
self._channel.send_no_reply(
133-
"waitForEventInfo",
134-
{
135-
"info": {
136-
"waitId": self._wait_id,
137-
"phase": "log",
138-
"message": message,
130+
self._channel._connection.wrap_api_call(
131+
lambda: self._channel.send_no_reply(
132+
"waitForEventInfo",
133+
{
134+
"info": {
135+
"waitId": self._wait_id,
136+
"phase": "log",
137+
"message": message,
138+
},
139139
},
140-
},
140+
),
141+
True,
141142
)
142143
except Exception:
143144
pass

0 commit comments

Comments
 (0)