|
| 1 | +from typing import TYPE_CHECKING |
| 2 | + |
| 3 | +from pydantic import BaseModel # type: ignore |
| 4 | +from sentry_sdk.consts import OP |
| 5 | +from sentry_sdk.hub import Hub, _should_send_default_pii |
| 6 | +from sentry_sdk.integrations import DidNotEnable, Integration |
| 7 | +from sentry_sdk.integrations.asgi import SentryAsgiMiddleware |
| 8 | +from sentry_sdk.tracing import SOURCE_FOR_STYLE, TRANSACTION_SOURCE_ROUTE |
| 9 | +from sentry_sdk.utils import event_from_exception, transaction_from_function |
| 10 | + |
| 11 | +try: |
| 12 | + from starlite import Request, Starlite, State # type: ignore |
| 13 | + from starlite.handlers.base import BaseRouteHandler # type: ignore |
| 14 | + from starlite.middleware import DefineMiddleware # type: ignore |
| 15 | + from starlite.plugins.base import get_plugin_for_value # type: ignore |
| 16 | + from starlite.routes.http import HTTPRoute # type: ignore |
| 17 | + from starlite.utils import ConnectionDataExtractor, is_async_callable, Ref # type: ignore |
| 18 | + |
| 19 | + if TYPE_CHECKING: |
| 20 | + from typing import Any, Dict, List, Optional, Union |
| 21 | + from starlite.types import ( # type: ignore |
| 22 | + ASGIApp, |
| 23 | + HTTPReceiveMessage, |
| 24 | + HTTPScope, |
| 25 | + Message, |
| 26 | + Middleware, |
| 27 | + Receive, |
| 28 | + Scope, |
| 29 | + Send, |
| 30 | + WebSocketReceiveMessage, |
| 31 | + ) |
| 32 | + from starlite import MiddlewareProtocol |
| 33 | + from sentry_sdk._types import Event |
| 34 | +except ImportError: |
| 35 | + raise DidNotEnable("Starlite is not installed") |
| 36 | + |
| 37 | + |
| 38 | +_DEFAULT_TRANSACTION_NAME = "generic Starlite request" |
| 39 | + |
| 40 | + |
| 41 | +class SentryStarliteASGIMiddleware(SentryAsgiMiddleware): |
| 42 | + def __init__(self, app: "ASGIApp"): |
| 43 | + super().__init__( |
| 44 | + app=app, |
| 45 | + unsafe_context_data=False, |
| 46 | + transaction_style="endpoint", |
| 47 | + mechanism_type="asgi", |
| 48 | + ) |
| 49 | + |
| 50 | + |
| 51 | +class StarliteIntegration(Integration): |
| 52 | + identifier = "starlite" |
| 53 | + |
| 54 | + @staticmethod |
| 55 | + def setup_once() -> None: |
| 56 | + patch_app_init() |
| 57 | + patch_middlewares() |
| 58 | + patch_http_route_handle() |
| 59 | + |
| 60 | + |
| 61 | +def patch_app_init() -> None: |
| 62 | + """ |
| 63 | + Replaces the Starlite class's `__init__` function in order to inject `after_exception` handlers and set the |
| 64 | + `SentryStarliteASGIMiddleware` as the outmost middleware in the stack. |
| 65 | + See: |
| 66 | + - https://starlite-api.github.io/starlite/usage/0-the-starlite-app/5-application-hooks/#after-exception |
| 67 | + - https://starlite-api.github.io/starlite/usage/7-middleware/0-middleware-intro/ |
| 68 | + """ |
| 69 | + old__init__ = Starlite.__init__ |
| 70 | + |
| 71 | + def injection_wrapper(self: "Starlite", *args: "Any", **kwargs: "Any") -> None: |
| 72 | + |
| 73 | + after_exception = kwargs.pop("after_exception", []) |
| 74 | + kwargs.update( |
| 75 | + after_exception=[ |
| 76 | + exception_handler, |
| 77 | + *( |
| 78 | + after_exception |
| 79 | + if isinstance(after_exception, list) |
| 80 | + else [after_exception] |
| 81 | + ), |
| 82 | + ] |
| 83 | + ) |
| 84 | + |
| 85 | + SentryStarliteASGIMiddleware.__call__ = SentryStarliteASGIMiddleware._run_asgi3 |
| 86 | + middleware = kwargs.pop("middleware", None) or [] |
| 87 | + kwargs["middleware"] = [SentryStarliteASGIMiddleware, *middleware] |
| 88 | + old__init__(self, *args, **kwargs) |
| 89 | + |
| 90 | + Starlite.__init__ = injection_wrapper |
| 91 | + |
| 92 | + |
| 93 | +def patch_middlewares() -> None: |
| 94 | + old__resolve_middleware_stack = BaseRouteHandler.resolve_middleware |
| 95 | + |
| 96 | + def resolve_middleware_wrapper(self: "Any") -> "List[Middleware]": |
| 97 | + return [ |
| 98 | + enable_span_for_middleware(middleware) |
| 99 | + for middleware in old__resolve_middleware_stack(self) |
| 100 | + ] |
| 101 | + |
| 102 | + BaseRouteHandler.resolve_middleware = resolve_middleware_wrapper |
| 103 | + |
| 104 | + |
| 105 | +def enable_span_for_middleware(middleware: "Middleware") -> "Middleware": |
| 106 | + if ( |
| 107 | + not hasattr(middleware, "__call__") # noqa: B004 |
| 108 | + or middleware is SentryStarliteASGIMiddleware |
| 109 | + ): |
| 110 | + return middleware |
| 111 | + |
| 112 | + if isinstance(middleware, DefineMiddleware): |
| 113 | + old_call: "ASGIApp" = middleware.middleware.__call__ |
| 114 | + else: |
| 115 | + old_call = middleware.__call__ |
| 116 | + |
| 117 | + async def _create_span_call( |
| 118 | + self: "MiddlewareProtocol", scope: "Scope", receive: "Receive", send: "Send" |
| 119 | + ) -> None: |
| 120 | + hub = Hub.current |
| 121 | + integration = hub.get_integration(StarliteIntegration) |
| 122 | + if integration is not None: |
| 123 | + middleware_name = self.__class__.__name__ |
| 124 | + with hub.start_span( |
| 125 | + op=OP.MIDDLEWARE_STARLITE, description=middleware_name |
| 126 | + ) as middleware_span: |
| 127 | + middleware_span.set_tag("starlite.middleware_name", middleware_name) |
| 128 | + |
| 129 | + # Creating spans for the "receive" callback |
| 130 | + async def _sentry_receive( |
| 131 | + *args: "Any", **kwargs: "Any" |
| 132 | + ) -> "Union[HTTPReceiveMessage, WebSocketReceiveMessage]": |
| 133 | + hub = Hub.current |
| 134 | + with hub.start_span( |
| 135 | + op=OP.MIDDLEWARE_STARLITE_RECEIVE, |
| 136 | + description=getattr(receive, "__qualname__", str(receive)), |
| 137 | + ) as span: |
| 138 | + span.set_tag("starlite.middleware_name", middleware_name) |
| 139 | + return await receive(*args, **kwargs) |
| 140 | + |
| 141 | + receive_name = getattr(receive, "__name__", str(receive)) |
| 142 | + receive_patched = receive_name == "_sentry_receive" |
| 143 | + new_receive = _sentry_receive if not receive_patched else receive |
| 144 | + |
| 145 | + # Creating spans for the "send" callback |
| 146 | + async def _sentry_send(message: "Message") -> None: |
| 147 | + hub = Hub.current |
| 148 | + with hub.start_span( |
| 149 | + op=OP.MIDDLEWARE_STARLITE_SEND, |
| 150 | + description=getattr(send, "__qualname__", str(send)), |
| 151 | + ) as span: |
| 152 | + span.set_tag("starlite.middleware_name", middleware_name) |
| 153 | + return await send(message) |
| 154 | + |
| 155 | + send_name = getattr(send, "__name__", str(send)) |
| 156 | + send_patched = send_name == "_sentry_send" |
| 157 | + new_send = _sentry_send if not send_patched else send |
| 158 | + |
| 159 | + return await old_call(self, scope, new_receive, new_send) |
| 160 | + else: |
| 161 | + return await old_call(self, scope, receive, send) |
| 162 | + |
| 163 | + not_yet_patched = old_call.__name__ not in ["_create_span_call"] |
| 164 | + |
| 165 | + if not_yet_patched: |
| 166 | + if isinstance(middleware, DefineMiddleware): |
| 167 | + middleware.middleware.__call__ = _create_span_call |
| 168 | + else: |
| 169 | + middleware.__call__ = _create_span_call |
| 170 | + |
| 171 | + return middleware |
| 172 | + |
| 173 | + |
| 174 | +def patch_http_route_handle() -> None: |
| 175 | + old_handle = HTTPRoute.handle |
| 176 | + |
| 177 | + async def handle_wrapper( |
| 178 | + self: "HTTPRoute", scope: "HTTPScope", receive: "Receive", send: "Send" |
| 179 | + ) -> None: |
| 180 | + hub = Hub.current |
| 181 | + integration: StarliteIntegration = hub.get_integration(StarliteIntegration) |
| 182 | + if integration is None: |
| 183 | + return await old_handle(self, scope, receive, send) |
| 184 | + |
| 185 | + with hub.configure_scope() as sentry_scope: |
| 186 | + request: "Request[Any, Any]" = scope["app"].request_class( |
| 187 | + scope=scope, receive=receive, send=send |
| 188 | + ) |
| 189 | + extracted_request_data = ConnectionDataExtractor( |
| 190 | + parse_body=True, parse_query=True |
| 191 | + )(request) |
| 192 | + body = extracted_request_data.pop("body") |
| 193 | + |
| 194 | + request_data = await body |
| 195 | + |
| 196 | + def event_processor(event: "Event", _: "Dict[str, Any]") -> "Event": |
| 197 | + route_handler = scope.get("route_handler") |
| 198 | + |
| 199 | + request_info = event.get("request", {}) |
| 200 | + request_info["content_length"] = len(scope.get("_body", b"")) |
| 201 | + if _should_send_default_pii(): |
| 202 | + request_info["cookies"] = extracted_request_data["cookies"] |
| 203 | + if request_data is not None: |
| 204 | + request_info["data"] = request_data |
| 205 | + |
| 206 | + func = None |
| 207 | + if route_handler.name is not None: |
| 208 | + tx_name = route_handler.name |
| 209 | + elif isinstance(route_handler.fn, Ref): |
| 210 | + func = route_handler.fn.value |
| 211 | + else: |
| 212 | + func = route_handler.fn |
| 213 | + if func is not None: |
| 214 | + tx_name = transaction_from_function(func) |
| 215 | + |
| 216 | + tx_info = {"source": SOURCE_FOR_STYLE["endpoint"]} |
| 217 | + |
| 218 | + if not tx_name: |
| 219 | + tx_name = _DEFAULT_TRANSACTION_NAME |
| 220 | + tx_info = {"source": TRANSACTION_SOURCE_ROUTE} |
| 221 | + |
| 222 | + event.update( |
| 223 | + request=request_info, transaction=tx_name, transaction_info=tx_info |
| 224 | + ) |
| 225 | + return event |
| 226 | + |
| 227 | + sentry_scope._name = StarliteIntegration.identifier |
| 228 | + sentry_scope.add_event_processor(event_processor) |
| 229 | + |
| 230 | + return await old_handle(self, scope, receive, send) |
| 231 | + |
| 232 | + HTTPRoute.handle = handle_wrapper |
| 233 | + |
| 234 | + |
| 235 | +def retrieve_user_from_scope(scope: "Scope") -> "Optional[Dict[str, Any]]": |
| 236 | + scope_user = scope.get("user", {}) |
| 237 | + if not scope_user: |
| 238 | + return None |
| 239 | + if isinstance(scope_user, dict): |
| 240 | + return scope_user |
| 241 | + if isinstance(scope_user, BaseModel): |
| 242 | + return scope_user.dict() |
| 243 | + if hasattr(scope_user, "asdict"): # dataclasses |
| 244 | + return scope_user.asdict() |
| 245 | + |
| 246 | + plugin = get_plugin_for_value(scope_user) |
| 247 | + if plugin and not is_async_callable(plugin.to_dict): |
| 248 | + return plugin.to_dict(scope_user) |
| 249 | + |
| 250 | + return None |
| 251 | + |
| 252 | + |
| 253 | +def exception_handler(exc: Exception, scope: "Scope", _: "State") -> None: |
| 254 | + hub = Hub.current |
| 255 | + if hub.get_integration(StarliteIntegration) is None: |
| 256 | + return |
| 257 | + |
| 258 | + user_info: "Optional[Dict[str, Any]]" = None |
| 259 | + if _should_send_default_pii(): |
| 260 | + user_info = retrieve_user_from_scope(scope) |
| 261 | + if user_info and isinstance(user_info, dict): |
| 262 | + with hub.configure_scope() as sentry_scope: |
| 263 | + sentry_scope.set_user(user_info) |
| 264 | + |
| 265 | + event, hint = event_from_exception( |
| 266 | + exc, |
| 267 | + client_options=hub.client.options if hub.client else None, |
| 268 | + mechanism={"type": StarliteIntegration.identifier, "handled": False}, |
| 269 | + ) |
| 270 | + |
| 271 | + hub.capture_event(event, hint=hint) |
0 commit comments