Skip to content

Commit 28095bd

Browse files
improve typing and simplify patching
1 parent 317e46f commit 28095bd

File tree

1 file changed

+36
-21
lines changed

1 file changed

+36
-21
lines changed

sentry_sdk/integrations/starlette.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,26 @@
3636
from typing import TYPE_CHECKING
3737

3838
if TYPE_CHECKING:
39-
from typing import Any, Awaitable, Callable, Container, Dict, Optional, Tuple, Union
40-
39+
from typing import (
40+
Any,
41+
Awaitable,
42+
Callable,
43+
Container,
44+
Dict,
45+
Optional,
46+
Tuple,
47+
Union,
48+
Protocol,
49+
TypeVar,
50+
)
51+
from types import CoroutineType
4152
from sentry_sdk._types import Event, HttpStatusCodeRange
4253

4354
try:
4455
import starlette # type: ignore
4556
from starlette import __version__ as STARLETTE_VERSION
4657
from starlette.applications import Starlette # type: ignore
47-
from starlette.datastructures import UploadFile # type: ignore
58+
from starlette.datastructures import UploadFile, FormData # type: ignore
4859
from starlette.middleware import Middleware # type: ignore
4960
from starlette.middleware.authentication import ( # type: ignore
5061
AuthenticationMiddleware,
@@ -55,6 +66,16 @@
5566
except ImportError:
5667
raise DidNotEnable("Starlette is not installed")
5768

69+
if TYPE_CHECKING:
70+
from contextlib import AbstractAsyncContextManager
71+
72+
T_co = TypeVar("T_co", covariant=True)
73+
74+
class AwaitableOrContextManager(
75+
Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]
76+
): ...
77+
78+
5879
try:
5980
# Starlette 0.20
6081
from starlette.middleware.exceptions import ExceptionMiddleware # type: ignore
@@ -426,29 +447,23 @@ def _patch_request(request):
426447
_original_json = request.json
427448
_original_form = request.form
428449

429-
def restore_original_methods():
430-
# type: () -> None
431-
request.body = _original_body
432-
request.json = _original_json
433-
request.form = _original_form
434-
435-
async def sentry_body():
436-
# type: () -> bytes
450+
@functools.wraps(_original_body)
451+
def sentry_body():
452+
# type: () -> CoroutineType[Any, Any, bytes]
437453
request.scope.setdefault("state", {})["sentry_sdk.is_body_cached"] = True
438-
restore_original_methods()
439-
return await _original_body()
454+
return _original_body()
440455

441-
async def sentry_json():
442-
# type: () -> Any
456+
@functools.wraps(_original_json)
457+
def sentry_json():
458+
# type: () -> CoroutineType[Any, Any, Any]
443459
request.scope.setdefault("state", {})["sentry_sdk.is_body_cached"] = True
444-
restore_original_methods()
445-
return await _original_json()
460+
return _original_json()
446461

447-
async def sentry_form():
448-
# type: () -> Any
462+
@functools.wraps(_original_form)
463+
def sentry_form(*args, **kwargs):
464+
# type: (*Any, **Any) -> AwaitableOrContextManager[FormData]
449465
request.scope.setdefault("state", {})["sentry_sdk.is_body_cached"] = True
450-
restore_original_methods()
451-
return await _original_form()
466+
return _original_form(*args, **kwargs)
452467

453468
request.body = sentry_body
454469
request.json = sentry_json

0 commit comments

Comments
 (0)