|
36 | 36 | from typing import TYPE_CHECKING |
37 | 37 |
|
38 | 38 | if TYPE_CHECKING: |
39 | | - from typing import Any, Awaitable, Callable, Container, Dict, Optional, Tuple, Union |
40 | | - |
| 39 | + from typing import ( |
| 40 | + Any, |
| 41 | + Awaitable, |
| 42 | + Callable, |
| 43 | + Container, |
| 44 | + Dict, |
| 45 | + Optional, |
| 46 | + Tuple, |
| 47 | + Union, |
| 48 | + Protocol, |
| 49 | + TypeVar, |
| 50 | + ) |
| 51 | + from types import CoroutineType |
41 | 52 | from sentry_sdk._types import Event, HttpStatusCodeRange |
42 | 53 |
|
43 | 54 | try: |
44 | 55 | import starlette # type: ignore |
45 | 56 | from starlette import __version__ as STARLETTE_VERSION |
46 | 57 | from starlette.applications import Starlette # type: ignore |
47 | | - from starlette.datastructures import UploadFile # type: ignore |
| 58 | + from starlette.datastructures import UploadFile, FormData # type: ignore |
48 | 59 | from starlette.middleware import Middleware # type: ignore |
49 | 60 | from starlette.middleware.authentication import ( # type: ignore |
50 | 61 | AuthenticationMiddleware, |
|
55 | 66 | except ImportError: |
56 | 67 | raise DidNotEnable("Starlette is not installed") |
57 | 68 |
|
| 69 | +if TYPE_CHECKING: |
| 70 | + from contextlib import AbstractAsyncContextManager |
| 71 | + |
| 72 | + T_co = TypeVar("T_co", covariant=True) |
| 73 | + |
| 74 | + class AwaitableOrContextManager( |
| 75 | + Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co] |
| 76 | + ): ... |
| 77 | + |
| 78 | + |
58 | 79 | try: |
59 | 80 | # Starlette 0.20 |
60 | 81 | from starlette.middleware.exceptions import ExceptionMiddleware # type: ignore |
@@ -426,29 +447,23 @@ def _patch_request(request): |
426 | 447 | _original_json = request.json |
427 | 448 | _original_form = request.form |
428 | 449 |
|
429 | | - def restore_original_methods(): |
430 | | - # type: () -> None |
431 | | - request.body = _original_body |
432 | | - request.json = _original_json |
433 | | - request.form = _original_form |
434 | | - |
435 | | - async def sentry_body(): |
436 | | - # type: () -> bytes |
| 450 | + @functools.wraps(_original_body) |
| 451 | + def sentry_body(): |
| 452 | + # type: () -> CoroutineType[Any, Any, bytes] |
437 | 453 | request.scope.setdefault("state", {})["sentry_sdk.is_body_cached"] = True |
438 | | - restore_original_methods() |
439 | | - return await _original_body() |
| 454 | + return _original_body() |
440 | 455 |
|
441 | | - async def sentry_json(): |
442 | | - # type: () -> Any |
| 456 | + @functools.wraps(_original_json) |
| 457 | + def sentry_json(): |
| 458 | + # type: () -> CoroutineType[Any, Any, Any] |
443 | 459 | request.scope.setdefault("state", {})["sentry_sdk.is_body_cached"] = True |
444 | | - restore_original_methods() |
445 | | - return await _original_json() |
| 460 | + return _original_json() |
446 | 461 |
|
447 | | - async def sentry_form(): |
448 | | - # type: () -> Any |
| 462 | + @functools.wraps(_original_form) |
| 463 | + def sentry_form(*args, **kwargs): |
| 464 | + # type: (*Any, **Any) -> AwaitableOrContextManager[FormData] |
449 | 465 | request.scope.setdefault("state", {})["sentry_sdk.is_body_cached"] = True |
450 | | - restore_original_methods() |
451 | | - return await _original_form() |
| 466 | + return _original_form(*args, **kwargs) |
452 | 467 |
|
453 | 468 | request.body = sentry_body |
454 | 469 | request.json = sentry_json |
|
0 commit comments