diff --git a/sentry_sdk/integrations/asgi.py b/sentry_sdk/integrations/asgi.py
index dde8128a33..2526300c65 100644
--- a/sentry_sdk/integrations/asgi.py
+++ b/sentry_sdk/integrations/asgi.py
@@ -36,6 +36,7 @@
     _get_installed_modules,
 )
 from sentry_sdk.tracing import Transaction
+from sentry_sdk.integrations._wsgi_common import request_body_within_bounds
 
 from typing import TYPE_CHECKING
 
@@ -178,6 +179,28 @@ async def _run_asgi3(self, scope, receive, send):
         # type: (Any, Any, Any) -> Any
         return await self._run_app(scope, receive, send, asgi_version=3)
 
+    async def _eagerly_receive_body(self, receive):
+        # type: (Any) -> Any
+        """
+        Drain the request body from `receive` before the wrapped app runs.
+
+        Returns the concatenated body bytes together with every message that
+        was consumed, so the caller can replay them to the application.
+        """
+        chunks = []
+        buffered_messages = []
+        while True:
+            msg = await receive()
+            buffered_messages.append(msg)
+
+            if "body" in msg:
+                chunks.append(msg["body"])
+
+            if not msg.get("more_body", False):
+                break
+
+        return b"".join(chunks), buffered_messages
+
     async def _run_app(self, scope, receive, send, asgi_version):
         # type: (Any, Any, Any, int) -> Any
         is_recursive_asgi_middleware = _asgi_middleware_applied.get(False)
@@ -213,10 +236,11 @@ async def _run_app(self, scope, receive, send, asgi_version):
                     method = scope.get("method", "").upper()
 
                     transaction = None
+                    headers = _get_headers(scope)
                     if ty in ("http", "websocket"):
                         if ty == "websocket" or method in self.http_methods_to_capture:
                             transaction = continue_trace(
-                                _get_headers(scope),
+                                headers,
                                 op="{}.server".format(ty),
                                 name=transaction_name,
                                 source=transaction_source,
@@ -241,6 +265,38 @@ async def _run_app(self, scope, receive, send, asgi_version):
                         if transaction is not None
                         else nullcontext()
                     ):
+                        client = sentry_sdk.get_client()
+
+                        # A malformed Content-Length header must not crash the
+                        # middleware: treat it like a missing one.
+                        try:
+                            content_length = int(headers["content-length"])
+                        except (KeyError, ValueError):
+                            content_length = None
+
+                        read_request_body = (
+                            content_length is not None
+                            and request_body_within_bounds(client, content_length)
+                        )
+                        scope.setdefault("state", {})[
+                            "sentry_sdk.content-length"
+                        ] = None
+
+                        buffered_messages = []
+                        if read_request_body:
+                            body, buffered_messages = await self._eagerly_receive_body(
+                                receive
+                            )
+                            scope["state"]["sentry_sdk.raw_body"] = body
+
+                        # Replay wrapper: first return buffered messages,
+                        # then delegate to the original receive.
+                        async def _sentry_replay_receive():
+                            # type: () -> Any
+                            if buffered_messages:
+                                return buffered_messages.pop(0)
+                            return await receive()
+
                         try:
 
                             async def _sentry_wrapped_send(event):
@@ -255,10 +311,16 @@ async def _sentry_wrapped_send(event):
 
                                 return await send(event)
 
+                            # Replay the buffered messages only when the body
+                            # was actually consumed above.
+                            app_receive = (
+                                _sentry_replay_receive if read_request_body else receive
+                            )
+
                             if asgi_version == 2:
                                 return await self.app(scope)(
-                                    receive, _sentry_wrapped_send
+                                    app_receive, _sentry_wrapped_send
                                 )
                             else:
                                 return await self.app(
-                                    scope, receive, _sentry_wrapped_send
+                                    scope, app_receive, _sentry_wrapped_send
diff --git a/sentry_sdk/integrations/fastapi.py b/sentry_sdk/integrations/fastapi.py
index 1473cbcab7..fcef780e90 100644
--- a/sentry_sdk/integrations/fastapi.py
+++ b/sentry_sdk/integrations/fastapi.py
@@ -18,6 +18,7 @@
     from sentry_sdk.integrations.starlette import (
         StarletteIntegration,
         StarletteRequestExtractor,
+        _patch_request,
     )
 except DidNotEnable:
     raise DidNotEnable("Starlette is not installed")
@@ -103,13 +104,12 @@ async def _sentry_app(*args, **kwargs):
                 return await old_app(*args, **kwargs)
 
             request = args[0]
+            _patch_request(request)
 
             _set_transaction_name_and_source(
                 sentry_sdk.get_current_scope(), integration.transaction_style, request
             )
             sentry_scope = sentry_sdk.get_isolation_scope()
-            extractor = StarletteRequestExtractor(request)
-            info = await extractor.extract_request_info()
 
             def _make_request_event_processor(req, integration):
                 # type: (Any, Any) -> Callable[[Event, Dict[str, Any]], Event]
@@ -117,6 +117,9 @@ def event_processor(event, hint):
                     # type: (Event, Dict[str, Any]) -> Event
 
                     # Extract information from request
+                    extractor = StarletteRequestExtractor(req)
+                    info = extractor.extract_request_info(req.scope)
+
                     request_info = event.get("request", {})
                     if info:
                         if "cookies" in info and should_send_default_pii():
diff --git a/sentry_sdk/integrations/starlette.py b/sentry_sdk/integrations/starlette.py
index c7ce40618b..4e81b31335 100644
--- a/sentry_sdk/integrations/starlette.py
+++ b/sentry_sdk/integrations/starlette.py
@@ -4,6 +4,11 @@
 from collections.abc import Set
 from copy import deepcopy
 from json import JSONDecodeError
+import json
+from urllib.parse import parse_qsl, unquote_plus
+from dataclasses import dataclass, field
+from enum import Enum
+from tempfile import SpooledTemporaryFile
 
 import sentry_sdk
 from sentry_sdk.consts import OP
@@ -39,12 +44,11 @@
     from typing import Any, Awaitable, Callable, Container, Dict, Optional, Tuple, Union
 
     from sentry_sdk._types import Event, HttpStatusCodeRange
-
 try:
     import starlette  # type: ignore
     from starlette import __version__ as STARLETTE_VERSION
     from starlette.applications import Starlette  # type: ignore
-    from starlette.datastructures import UploadFile  # type: ignore
+    from starlette.datastructures import UploadFile, FormData, Headers  # type: ignore
     from starlette.middleware import Middleware  # type: ignore
     from starlette.middleware.authentication import (  # type: ignore
         AuthenticationMiddleware,
@@ -67,12 +71,22 @@
     try:
         # python-multipart 0.0.13 and later
         import python_multipart as multipart  # type: ignore
+        from python_multipart.multipart import (
+            parse_options_header,
+            MultipartCallbacks,
+            QuerystringCallbacks,
+        )
     except ImportError:
         # python-multipart 0.0.12 and earlier
         import multipart  # type: ignore
+        from multipart.multipart import (
+            parse_options_header,
+            MultipartCallbacks,
+            QuerystringCallbacks,
+        )
 except ImportError:
     multipart = None
-
+    parse_options_header = None
 
 _DEFAULT_TRANSACTION_NAME = "generic Starlette request"
 
@@ -422,6 +436,44 @@ def _is_async_callable(obj):
     )
 
 
+def _patch_request(request):
+    # type: (Any) -> None
+    """
+    Record the result of the first `body()`, `json()` or `form()` call.
+
+    The request's methods are replaced by wrappers that store the awaited
+    result under `request.scope["state"]["sentry_sdk.<name>"]`, put the
+    original methods back and return the result unchanged.
+    """
+    originals = {
+        "body": request.body,
+        "json": request.json,
+        "form": request.form,
+    }
+
+    def restore_original_methods():
+        # type: () -> None
+        for name, method in originals.items():
+            setattr(request, name, method)
+
+    def make_wrapper(name):
+        # type: (str) -> Any
+        original = originals[name]
+
+        async def sentry_wrapper(*args, **kwargs):
+            # type: (*Any, **Any) -> Any
+            result = await original(*args, **kwargs)
+            # Do not assume the scope already carries a "state" dict.
+            request.scope.setdefault("state", {})["sentry_sdk." + name] = result
+            restore_original_methods()
+            return result
+
+        return sentry_wrapper
+
+    for name in originals:
+        setattr(request, name, make_wrapper(name))
+
+
 def patch_request_response():
     # type: () -> None
     old_request_response = starlette.routing.request_response
@@ -442,6 +494,7 @@ async def _sentry_async_func(*args, **kwargs):
                     return await old_func(*args, **kwargs)
 
                 request = args[0]
+                _patch_request(request)
 
                 _set_transaction_name_and_source(
                     sentry_sdk.get_current_scope(),
@@ -450,8 +503,6 @@ async def _sentry_async_func(*args, **kwargs):
                 )
 
                 sentry_scope = sentry_sdk.get_isolation_scope()
-                extractor = StarletteRequestExtractor(request)
-                info = await extractor.extract_request_info()
 
                 def _make_request_event_processor(req, integration):
                     # type: (Any, Any) -> Callable[[Event, dict[str, Any]], Event]
@@ -459,6 +510,9 @@ def event_processor(event, hint):
                         # type: (Event, Dict[str, Any]) -> Event
 
                         # Add info from request to event
+                        extractor = StarletteRequestExtractor(req)
+                        info = extractor.extract_request_info(req.scope)
+
                         request_info = event.get("request", {})
                         if info:
                             if "cookies" in info:
@@ -580,6 +634,291 @@ def add_sentry_trace_meta(request):
     Jinja2Templates.__init__ = _sentry_jinja2templates_init
 
 
+def _is_form_data_encoded(ct):
+    # type: (Optional[str]) -> bool
+    mt = (ct or "").split(";", 1)[0].strip().lower()
+    return mt == "multipart/form-data"
+
+
+def _is_form_urlencoded(ct):
+    # type: (Optional[str]) -> bool
+    mt = (ct or "").split(";", 1)[0].strip().lower()
+    return mt == "application/x-www-form-urlencoded"
+
+
+# Adapted from Starlette to work in a synchronous context
+# https://github.com/Kludex/starlette/blob/main/starlette/formparsers.py
+
+
+class FormMessage(Enum):
+    FIELD_START = 1
+    FIELD_NAME = 2
+    FIELD_DATA = 3
+    FIELD_END = 4
+    END = 5
+
+
+@dataclass
+class MultipartPart:
+    # Annotations are strings so they are not evaluated at runtime; the
+    # `X | None` and `list[...]` forms fail on Python < 3.10 / < 3.9.
+    content_disposition: "bytes | None" = None
+    field_name: str = ""
+    data: bytearray = field(default_factory=bytearray)
+    file: "UploadFile | None" = None
+    item_headers: "list[tuple[bytes, bytes]]" = field(default_factory=list)
+
+
+def _user_safe_decode(src: "bytes | bytearray", codec: str) -> str:
+    """Decode `src` with `codec`, falling back to latin-1 on failure."""
+    try:
+        return src.decode(codec)
+    except (UnicodeDecodeError, LookupError):
+        return src.decode("latin-1")
+
+
+class MultiPartException(Exception):
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
+        self.message = message
+
+
+class FormParser:
+    def __init__(self, headers: Headers, content: bytes) -> None:
+        assert (
+            multipart is not None
+        ), "The `python-multipart` library must be installed to use form parsing."
+        self.headers = headers
+        self.content = content
+        self.messages: list[tuple[FormMessage, bytes]] = []
+
+    def on_field_start(self) -> None:
+        message = (FormMessage.FIELD_START, b"")
+        self.messages.append(message)
+
+    def on_field_name(self, data: bytes, start: int, end: int) -> None:
+        message = (FormMessage.FIELD_NAME, data[start:end])
+        self.messages.append(message)
+
+    def on_field_data(self, data: bytes, start: int, end: int) -> None:
+        message = (FormMessage.FIELD_DATA, data[start:end])
+        self.messages.append(message)
+
+    def on_field_end(self) -> None:
+        message = (FormMessage.FIELD_END, b"")
+        self.messages.append(message)
+
+    def on_end(self) -> None:
+        message = (FormMessage.END, b"")
+        self.messages.append(message)
+
+    def parse(self) -> FormData:
+        # Callbacks dictionary.
+        callbacks: QuerystringCallbacks = {
+            "on_field_start": self.on_field_start,
+            "on_field_name": self.on_field_name,
+            "on_field_data": self.on_field_data,
+            "on_field_end": self.on_field_end,
+            "on_end": self.on_end,
+        }
+
+        # Create the parser.
+        parser = multipart.QuerystringParser(callbacks)
+        field_name = b""
+        field_value = b""
+
+        items: list[tuple[str, str | UploadFile]] = []
+
+        parser.write(self.content)
+        # The whole body is written at once, so signal end-of-input right
+        # away; otherwise the trailing field never receives its end callback.
+        parser.finalize()
+
+        messages = list(self.messages)
+        self.messages.clear()
+        for message_type, message_bytes in messages:
+            if message_type == FormMessage.FIELD_START:
+                field_name = b""
+                field_value = b""
+            elif message_type == FormMessage.FIELD_NAME:
+                field_name += message_bytes
+            elif message_type == FormMessage.FIELD_DATA:
+                field_value += message_bytes
+            elif message_type == FormMessage.FIELD_END:
+                name = unquote_plus(field_name.decode("latin-1"))
+                value = unquote_plus(field_value.decode("latin-1"))
+                items.append((name, value))
+
+        return FormData(items)
+
+
+class MultiPartParser:
+    spool_max_size = 1024 * 1024  # 1MB
+    """The maximum size of the spooled temporary file used to store file data."""
+    max_part_size = 1024 * 1024  # 1MB
+    """The maximum size of a part in the multipart request."""
+
+    def __init__(
+        self,
+        headers: Headers,
+        content: bytes,
+        *,
+        max_files: "int | float" = 1000,
+        max_fields: "int | float" = 1000,
+        max_part_size: int = 1024 * 1024,  # 1MB
+    ) -> None:
+        assert (
+            multipart is not None
+        ), "The `python-multipart` library must be installed to use form parsing."
+        self.headers = headers
+        self.content = content
+        self.max_files = max_files
+        self.max_fields = max_fields
+        self.items: list[tuple[str, str | UploadFile]] = []
+        self._current_files = 0
+        self._current_fields = 0
+        self._current_partial_header_name: bytes = b""
+        self._current_partial_header_value: bytes = b""
+        self._current_part = MultipartPart()
+        self._charset = ""
+        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
+        self._file_parts_to_finish: list[MultipartPart] = []
+        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
+        self.max_part_size = max_part_size
+
+    def on_part_begin(self) -> None:
+        self._current_part = MultipartPart()
+
+    def on_part_data(self, data: bytes, start: int, end: int) -> None:
+        message_bytes = data[start:end]
+        if self._current_part.file is None:
+            if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
+                raise MultiPartException(
+                    f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB."
+                )
+            self._current_part.data.extend(message_bytes)
+        else:
+            self._file_parts_to_write.append((self._current_part, message_bytes))
+
+    def on_part_end(self) -> None:
+        if self._current_part.file is None:
+            self.items.append(
+                (
+                    self._current_part.field_name,
+                    _user_safe_decode(self._current_part.data, self._charset),
+                )
+            )
+        else:
+            self._file_parts_to_finish.append(self._current_part)
+            # The file can be added to the items right now even though it's not
+            # finished yet, because it will be finished in the `parse()` method, before
+            # self.items is used in the return value.
+            self.items.append((self._current_part.field_name, self._current_part.file))
+
+    def on_header_field(self, data: bytes, start: int, end: int) -> None:
+        self._current_partial_header_name += data[start:end]
+
+    def on_header_value(self, data: bytes, start: int, end: int) -> None:
+        self._current_partial_header_value += data[start:end]
+
+    def on_header_end(self) -> None:
+        field = self._current_partial_header_name.lower()
+        if field == b"content-disposition":
+            self._current_part.content_disposition = self._current_partial_header_value
+        self._current_part.item_headers.append(
+            (field, self._current_partial_header_value)
+        )
+        self._current_partial_header_name = b""
+        self._current_partial_header_value = b""
+
+    def on_headers_finished(self) -> None:
+        disposition, options = parse_options_header(
+            self._current_part.content_disposition
+        )
+        try:
+            self._current_part.field_name = _user_safe_decode(
+                options[b"name"], self._charset
+            )
+        except KeyError:
+            raise MultiPartException(
+                'The Content-Disposition header field "name" must be provided.'
+            )
+        if b"filename" in options:
+            self._current_files += 1
+            if self._current_files > self.max_files:
+                raise MultiPartException(
+                    f"Too many files. Maximum number of files is {self.max_files}."
+                )
+            filename = _user_safe_decode(options[b"filename"], self._charset)
+            tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
+            self._files_to_close_on_error.append(tempfile)
+            self._current_part.file = UploadFile(
+                file=tempfile,  # type: ignore[arg-type]
+                size=0,
+                filename=filename,
+                headers=Headers(raw=self._current_part.item_headers),
+            )
+        else:
+            self._current_fields += 1
+            if self._current_fields > self.max_fields:
+                raise MultiPartException(
+                    f"Too many fields. Maximum number of fields is {self.max_fields}."
+                )
+            self._current_part.file = None
+
+    def on_end(self) -> None:
+        pass
+
+    def parse(self) -> FormData:
+        # Parse the Content-Type header to get the multipart boundary.
+        _, params = parse_options_header(self.headers["Content-Type"])
+        charset = params.get(b"charset", "utf-8")
+        if isinstance(charset, bytes):
+            charset = charset.decode("latin-1")
+        self._charset = charset
+        try:
+            boundary = params[b"boundary"]
+        except KeyError:
+            raise MultiPartException("Missing boundary in multipart.")
+
+        # Callbacks dictionary.
+        callbacks: MultipartCallbacks = {
+            "on_part_begin": self.on_part_begin,
+            "on_part_data": self.on_part_data,
+            "on_part_end": self.on_part_end,
+            "on_header_field": self.on_header_field,
+            "on_header_value": self.on_header_value,
+            "on_header_end": self.on_header_end,
+            "on_headers_finished": self.on_headers_finished,
+            "on_end": self.on_end,
+        }
+
+        # Create the parser.
+        parser = multipart.MultipartParser(boundary, callbacks)
+        try:
+            # Feed the parser with data from the request.
+            parser.write(self.content)
+            # Unlike Starlette's async parser, the whole body is already in
+            # memory here, so the buffered file chunks are written directly
+            # to the spooled temporary files.
+            for part, data in self._file_parts_to_write:
+                assert part.file  # for type checkers
+                part.file.file.write(data)
+            for part in self._file_parts_to_finish:
+                assert part.file  # for type checkers
+                part.file.file.seek(0)
+            self._file_parts_to_write.clear()
+            self._file_parts_to_finish.clear()
+        except MultiPartException:
+            # Close all the files if there was an error.
+            for file in self._files_to_close_on_error:
+                file.close()
+            raise
+
+        parser.finalize()
+        return FormData(self.items)
+
+
 class StarletteRequestExtractor:
     """
     Extracts useful information from the Starlette request
@@ -600,8 +939,14 @@ def extract_cookies_from_request(self):
 
         return cookies
 
-    async def extract_request_info(self):
-        # type: (StarletteRequestExtractor) -> Optional[Dict[str, Any]]
+    def extract_request_info(self, scope):
+        # type: (StarletteRequestExtractor, StarletteScope) -> Optional[Dict[str, Any]]
+        """
+        Build the event's request info from data cached in `scope["state"]`.
+
+        This method is synchronous, so the request body is never awaited
+        here; only values recorded earlier in the request are used.
+        """
         client = sentry_sdk.get_client()
 
         request_info = {}  # type: Dict[str, Any]
@@ -612,7 +957,7 @@ async def extract_request_info(self):
                 request_info["cookies"] = self.cookies()
 
             # If there is no body, just return the cookies
-            content_length = await self.content_length()
+            content_length = self.content_length()
             if not content_length:
                 return request_info
 
@@ -623,17 +968,18 @@ async def extract_request_info(self):
                 request_info["data"] = AnnotatedValue.removed_because_over_size_limit()
                 return request_info
 
-            # Add JSON body, if it is a JSON request
-            json = await self.json()
-            if json:
-                request_info["data"] = json
+            if "state" not in scope:
+                return request_info
+
+            state = scope["state"]
+
+            if "sentry_sdk.json" in state:
+                request_info["data"] = deepcopy(state["sentry_sdk.json"])
                 return request_info
 
-            # Add form as key/value pairs, if request has form data
-            form = await self.form()
-            if form:
+            if "sentry_sdk.form" in state:
                 form_data = {}
-                for key, val in form.items():
+                for key, val in state["sentry_sdk.form"].items():
                     is_file = isinstance(val, UploadFile)
                     form_data[key] = (
                         val
@@ -644,11 +990,45 @@ async def extract_request_info(self):
                 request_info["data"] = form_data
                 return request_info
 
-            # Raw data, do not add body just an annotation
-            request_info["data"] = AnnotatedValue.removed_because_raw_data()
+            raw_body = state.get("sentry_sdk.raw_body")
+            if raw_body is None:
+                return request_info
+
+            content_type = self.request.headers.get("content-type")
+
+            if _is_json_content_type(content_type):
+                try:
+                    request_info["data"] = json.loads(raw_body.decode("utf-8"))
+                except (JSONDecodeError, UnicodeDecodeError):
+                    # Malformed body: report the request without data.
+                    pass
+                return request_info
+
+            if _is_form_data_encoded(content_type):
+                if multipart is None:
+                    # python-multipart is missing, so the body cannot be
+                    # parsed; annotate it instead of attaching raw bytes.
+                    request_info["data"] = AnnotatedValue.removed_because_raw_data()
+                    return request_info
+
+                try:
+                    request_info["data"] = MultiPartParser(
+                        self.request.headers, raw_body
+                    ).parse()
+                except MultiPartException:
+                    request_info["data"] = AnnotatedValue.removed_because_raw_data()
+                return request_info
+
+            if _is_form_urlencoded(content_type):
+                # Decode first so the event carries text pairs, not bytes.
+                request_info["data"] = dict(
+                    parse_qsl(raw_body.decode("utf-8", errors="replace"))
+                )
+                return request_info
+
             return request_info
 
-    async def content_length(self):
+    def content_length(self):
         # type: (StarletteRequestExtractor) -> Optional[int]
         if "content-length" in self.request.headers:
             return int(self.request.headers["content-length"])
@@ -659,32 +1039,10 @@ def cookies(self):
         # type: (StarletteRequestExtractor) -> Dict[str, Any]
         return self.request.cookies
 
-    async def form(self):
-        # type: (StarletteRequestExtractor) -> Any
-        if multipart is None:
-            return None
-
-        # Parse the body first to get it cached, as Starlette does not cache form() as it
-        # does with body() and json() https://github.com/encode/starlette/discussions/1933
-        # Calling `.form()` without calling `.body()` first will
-        # potentially break the users project.
-        await self.request.body()
-
-        return await self.request.form()
-
     def is_json(self):
         # type: (StarletteRequestExtractor) -> bool
         return _is_json_content_type(self.request.headers.get("content-type"))
 
-    async def json(self):
-        # type: (StarletteRequestExtractor) -> Optional[Dict[str, Any]]
-        if not self.is_json():
-            return None
-        try:
-            return await self.request.json()
-        except JSONDecodeError:
-            return None
-
 
 def _transaction_name_from_router(scope):
     # type: (StarletteScope) -> Optional[str]
diff --git a/tests/integrations/starlette/test_starlette.py b/tests/integrations/starlette/test_starlette.py
index bc445bf8f2..362e285e4b 100644
--- a/tests/integrations/starlette/test_starlette.py
+++ b/tests/integrations/starlette/test_starlette.py
@@ -291,7 +291,7 @@ async def test_starletterequestextractor_content_length(sentry_init):
     starlette_request = starlette.requests.Request(scope)
     extractor = StarletteRequestExtractor(starlette_request)
 
-    assert await extractor.content_length() == len(json.dumps(BODY_JSON))
+    assert extractor.content_length() == len(json.dumps(BODY_JSON))
 
 
 @pytest.mark.asyncio
@@ -305,48 +305,6 @@ async def test_starletterequestextractor_cookies(sentry_init):
     }
 
 
-@pytest.mark.asyncio
-async def test_starletterequestextractor_json(sentry_init):
-    starlette_request = starlette.requests.Request(SCOPE)
-
-    # Mocking async `_receive()` that works in Python 3.7+
-    side_effect = [_mock_receive(msg) for msg in JSON_RECEIVE_MESSAGES]
-    starlette_request._receive = mock.Mock(side_effect=side_effect)
-
-    extractor = StarletteRequestExtractor(starlette_request)
-
-    assert extractor.is_json()
-    assert await extractor.json() == BODY_JSON
-
-
-@pytest.mark.asyncio
-async def test_starletterequestextractor_form(sentry_init):
-    scope = SCOPE.copy()
-    scope["headers"] = [
-        [b"content-type", b"multipart/form-data; boundary=fd721ef49ea403a6"],
-    ]
-    # TODO add test for content-type: "application/x-www-form-urlencoded"
-
-    starlette_request = starlette.requests.Request(scope)
-
-    # Mocking async `_receive()` that works in Python 3.7+
-    side_effect = [_mock_receive(msg) for msg in FORM_RECEIVE_MESSAGES]
-    starlette_request._receive = mock.Mock(side_effect=side_effect)
-
-    extractor = StarletteRequestExtractor(starlette_request)
-
-    form_data = await extractor.form()
-    assert form_data.keys() == PARSED_FORM.keys()
-    assert form_data["username"] == PARSED_FORM["username"]
-    assert form_data["password"] == PARSED_FORM["password"]
-    assert form_data["photo"].filename == PARSED_FORM["photo"].filename
-
-    # Make sure we still can read the body
-    # after alreading it with extractor.form() above.
-    body = await extractor.request.body()
-    assert body
-
-
 @pytest.mark.asyncio
 async def test_starletterequestextractor_body_consumed_twice(
     sentry_init, capture_events
@@ -405,7 +363,7 @@ async def test_starletterequestextractor_extract_request_info_too_big(sentry_ini
 
     extractor = StarletteRequestExtractor(starlette_request)
 
-    request_info = await extractor.extract_request_info()
+    request_info = extractor.extract_request_info(scope)
 
     assert request_info
     assert request_info["cookies"] == {
@@ -437,7 +395,9 @@ async def test_starletterequestextractor_extract_request_info(sentry_init):
 
     extractor = StarletteRequestExtractor(starlette_request)
 
-    request_info = await extractor.extract_request_info()
+    scope["state"] = {}
+    scope["state"]["sentry_sdk.json"] = BODY_JSON
+    request_info = extractor.extract_request_info(scope)
 
     assert request_info
     assert request_info["cookies"] == {
@@ -447,6 +407,68 @@ async def test_starletterequestextractor_extract_request_info(sentry_init):
     assert request_info["data"] == BODY_JSON
 
 
+@pytest.mark.asyncio
+async def test_starletterequestextractor_extract_request_info_json_bytes(sentry_init):
+    sentry_init(
+        send_default_pii=True,
+        integrations=[StarletteIntegration()],
+    )
+    scope = SCOPE.copy()
+    scope["headers"] = [
+        [b"content-type", b"application/json"],
+        [b"content-length", str(len(json.dumps(BODY_JSON))).encode()],
+        [b"cookie", b"yummy_cookie=choco; tasty_cookie=strawberry"],
+    ]
+
+    starlette_request = starlette.requests.Request(scope)
+
+    # Mocking async `_receive()` that works in Python 3.7+
+    side_effect = [_mock_receive(msg) for msg in JSON_RECEIVE_MESSAGES]
+    starlette_request._receive = mock.Mock(side_effect=side_effect)
+
+    extractor = StarletteRequestExtractor(starlette_request)
+
+    scope["state"] = {}
+    scope["state"]["sentry_sdk.raw_body"] = json.dumps(BODY_JSON).encode("utf-8")
+    request_info = extractor.extract_request_info(scope)
+
+    assert request_info["data"] == BODY_JSON
+
+
+@pytest.mark.asyncio
+async def test_starletterequestextractor_form_bytes(sentry_init):
+    sentry_init(
+        send_default_pii=True,
+        integrations=[StarletteIntegration()],
+        max_request_body_size="always",
+    )
+    scope = SCOPE.copy()
+    scope["headers"] = [
+        [b"content-type", b"multipart/form-data; boundary=fd721ef49ea403a6"],
+        [b"content-length", str(len(BODY_FORM)).encode()],
+    ]
+    # TODO add test for content-type: "application/x-www-form-urlencoded"
+
+    starlette_request = starlette.requests.Request(scope)
+
+    # Mocking async `_receive()` that works in Python 3.7+
+    side_effect = [_mock_receive(msg) for msg in FORM_RECEIVE_MESSAGES]
+    starlette_request._receive = mock.Mock(side_effect=side_effect)
+
+    extractor = StarletteRequestExtractor(starlette_request)
+
+    scope["state"] = {}
+    scope["state"]["sentry_sdk.raw_body"] = BODY_FORM.encode("utf-8")
+    request_info = extractor.extract_request_info(scope)
+
+    form_data = request_info["data"]
+
+    assert form_data.keys() == PARSED_FORM.keys()
+    assert form_data["username"] == PARSED_FORM["username"]
+    assert form_data["password"] == PARSED_FORM["password"]
+    assert form_data["photo"].filename == PARSED_FORM["photo"].filename
+
+
 @pytest.mark.asyncio
 async def test_starletterequestextractor_extract_request_info_no_pii(sentry_init):
     sentry_init(
@@ -468,7 +490,9 @@ async def test_starletterequestextractor_extract_request_info_no_pii(sentry_init
 
     extractor = StarletteRequestExtractor(starlette_request)
 
-    request_info = await extractor.extract_request_info()
+    scope["state"] = {}
+    scope["state"]["sentry_sdk.json"] = BODY_JSON
+    request_info = extractor.extract_request_info(scope)
 
     assert request_info
     assert "cookies" not in request_info
@@ -1358,24 +1382,32 @@ async def _error(_):
 
 @pytest.mark.asyncio
 async def test_starletterequestextractor_malformed_json_error_handling(sentry_init):
-    scope = SCOPE.copy()
-    scope["headers"] = [
-        [b"content-type", b"application/json"],
-    ]
-    starlette_request = starlette.requests.Request(scope)
-
+    sentry_init(
+        send_default_pii=True,
+        integrations=[StarletteIntegration()],
+    )
     malformed_json = "{invalid json"
     malformed_messages = [
         {"type": "http.request", "body": malformed_json.encode("utf-8")},
         {"type": "http.disconnect"},
     ]
 
+    scope = SCOPE.copy()
+    scope["headers"] = [
+        [b"content-type", b"application/json"],
+        [b"content-length", str(len(malformed_json.encode("utf-8"))).encode()],
+        [b"cookie", b"yummy_cookie=choco; tasty_cookie=strawberry"],
+    ]
+
+    starlette_request = starlette.requests.Request(scope)
+
     side_effect = [_mock_receive(msg) for msg in malformed_messages]
     starlette_request._receive = mock.Mock(side_effect=side_effect)
 
     extractor = StarletteRequestExtractor(starlette_request)
 
-    assert extractor.is_json()
+    scope["state"] = {}
+    scope["state"]["sentry_sdk.raw_body"] = malformed_json.encode("utf-8")
+    request_info = extractor.extract_request_info(scope)
 
-    result = await extractor.json()
-    assert result is None
+    assert request_info and "data" not in request_info