diff --git a/sentry_sdk/_werkzeug.py b/sentry_sdk/_werkzeug.py index 0fa3d611f1..48d6313bce 100644 --- a/sentry_sdk/_werkzeug.py +++ b/sentry_sdk/_werkzeug.py @@ -41,58 +41,89 @@ # -# `get_headers` comes from `werkzeug.datastructures.EnvironHeaders` -# https://github.com/pallets/werkzeug/blob/0.14.1/werkzeug/datastructures.py#L1361 +# `_get_headers` comes from `werkzeug.datastructures.headers.EnvironHeaders.__iter__` +# https://github.com/pallets/werkzeug/blob/3.1.3/src/werkzeug/datastructures/headers.py#L644 # # We need this function because Django does not give us a "pure" http header # dict. So we might as well use it for all WSGI integrations. # def _get_headers(environ): # type: (Dict[str, str]) -> Iterator[Tuple[str, str]] - """ - Returns only proper HTTP headers. - """ for key, value in environ.items(): - key = str(key) - if key.startswith("HTTP_") and key not in ( + if key.startswith("HTTP_") and key not in { "HTTP_CONTENT_TYPE", "HTTP_CONTENT_LENGTH", - ): + }: yield key[5:].replace("_", "-").title(), value - elif key in ("CONTENT_TYPE", "CONTENT_LENGTH"): + elif key in {"CONTENT_TYPE", "CONTENT_LENGTH"} and value: yield key.replace("_", "-").title(), value # # `get_host` comes from `werkzeug.wsgi.get_host` -# https://github.com/pallets/werkzeug/blob/1.0.1/src/werkzeug/wsgi.py#L145 +# https://github.com/pallets/werkzeug/blob/3.1.3/src/werkzeug/wsgi.py#L86 # def get_host(environ, use_x_forwarded_for=False): # type: (Dict[str, str], bool) -> str """ Return the host for the given WSGI environment.
""" - if use_x_forwarded_for and "HTTP_X_FORWARDED_HOST" in environ: - rv = environ["HTTP_X_FORWARDED_HOST"] - if environ["wsgi.url_scheme"] == "http" and rv.endswith(":80"): - rv = rv[:-3] - elif environ["wsgi.url_scheme"] == "https" and rv.endswith(":443"): - rv = rv[:-4] - elif environ.get("HTTP_HOST"): - rv = environ["HTTP_HOST"] - if environ["wsgi.url_scheme"] == "http" and rv.endswith(":80"): - rv = rv[:-3] - elif environ["wsgi.url_scheme"] == "https" and rv.endswith(":443"): - rv = rv[:-4] - elif environ.get("SERVER_NAME"): - rv = environ["SERVER_NAME"] - if (environ["wsgi.url_scheme"], environ["SERVER_PORT"]) not in ( - ("https", "443"), - ("http", "80"), - ): - rv += ":" + environ["SERVER_PORT"] - else: - # In spite of the WSGI spec, SERVER_NAME might not be present. - rv = "unknown" - - return rv + return _get_host( + environ["wsgi.url_scheme"], + ( + environ["HTTP_X_FORWARDED_HOST"] + if use_x_forwarded_for and environ.get("HTTP_X_FORWARDED_HOST") + else environ.get("HTTP_HOST") + ), + _get_server(environ), + ) + + +# `_get_host` comes from `werkzeug.sansio.utils.get_host` +# https://github.com/pallets/werkzeug/blob/3.1.3/src/werkzeug/sansio/utils.py#L49 +def _get_host( + scheme, + host_header, + server=None, +): + # type: (str, str | None, Tuple[str, int | None] | None) -> str + """ + Return the host for the given parameters. + """ + host = "" + + if host_header is not None: + host = host_header + elif server is not None: + host = server[0] + + # If SERVER_NAME is IPv6, wrap it in [] to match Host header. + # Check for : because domain or IPv4 can't have that.
+ if ":" in host and host[0] != "[": + host = f"[{host}]" + + if server[1] is not None: + host = f"{host}:{server[1]}" # noqa: E231 + + if scheme in {"http", "ws"} and host.endswith(":80"): + host = host[:-3] + elif scheme in {"https", "wss"} and host.endswith(":443"): + host = host[:-4] + + return host + + +def _get_server(environ): + # type: (Dict[str, str]) -> Tuple[str, int | None] | None + name = environ.get("SERVER_NAME") + + if name is None: + return None + + try: + port = int(environ.get("SERVER_PORT", None)) # type: ignore[arg-type] + except (TypeError, ValueError): + # unix socket + port = None + + return name, port diff --git a/tests/integrations/wsgi/test_wsgi.py b/tests/integrations/wsgi/test_wsgi.py index a741d1c57b..a281469905 100644 --- a/tests/integrations/wsgi/test_wsgi.py +++ b/tests/integrations/wsgi/test_wsgi.py @@ -7,6 +7,7 @@ import sentry_sdk from sentry_sdk import capture_message from sentry_sdk.integrations.wsgi import SentryWsgiMiddleware +from sentry_sdk._werkzeug import get_host, _get_headers @pytest.fixture @@ -39,6 +40,65 @@ def next(self): return type(self).__next__(self) +@pytest.mark.parametrize( + ("environ", "expect"), + ( + pytest.param({"HTTP_HOST": "spam"}, "spam", id="host"), + pytest.param({"HTTP_HOST": "spam:80"}, "spam", id="host, strip http port"), + pytest.param( + {"wsgi.url_scheme": "https", "HTTP_HOST": "spam:443"}, + "spam", + id="host, strip https port", + ), + pytest.param({"HTTP_HOST": "spam:8080"}, "spam:8080", id="host, custom port"), + pytest.param( + {"HTTP_HOST": "spam", "SERVER_NAME": "eggs", "SERVER_PORT": "80"}, + "spam", + id="prefer host", + ), + pytest.param( + {"SERVER_NAME": "eggs", "SERVER_PORT": "80"}, + "eggs", + id="name, ignore http port", + ), + pytest.param( + {"wsgi.url_scheme": "https", "SERVER_NAME": "eggs", "SERVER_PORT": "443"}, + "eggs", + id="name, ignore https port", + ), + pytest.param( + {"SERVER_NAME": "eggs", "SERVER_PORT": "8080"}, + "eggs:8080", + id="name, custom port", + ), + 
pytest.param( + {"HTTP_HOST": "ham", "HTTP_X_FORWARDED_HOST": "eggs"}, + "ham", + id="ignore x-forwarded-host", + ), + pytest.param( + { + "SERVER_NAME": "2001:0db8:85a3:0042:1000:8a2e:0370:7334", + "SERVER_PORT": "8080", + }, + "[2001:0db8:85a3:0042:1000:8a2e:0370:7334]:8080", + id="IPv6, custom port", + ), + pytest.param( + {"SERVER_NAME": "eggs"}, + "eggs", + id="name, no port", + ), + ), +) +# +# https://github.com/pallets/werkzeug/blob/main/tests/test_wsgi.py#L60 +# +def test_get_host(environ, expect): + environ.setdefault("wsgi.url_scheme", "http") + assert get_host(environ) == expect + + def test_basic(sentry_init, crashing_app, capture_events): sentry_init(send_default_pii=True) app = SentryWsgiMiddleware(crashing_app) @@ -61,6 +121,21 @@ def test_basic(sentry_init, crashing_app, capture_events): } +@pytest.mark.parametrize( + ("environ", "expect"), + ( + pytest.param( + {"CONTENT_TYPE": "text/html", "CONTENT_LENGTH": "0"}, + [("Content-Length", "0"), ("Content-Type", "text/html")], + id="headers", + ), + ), +) +def test_headers(environ, expect): + environ.setdefault("wsgi.url_scheme", "http") + assert sorted(_get_headers(environ)) == expect + + @pytest.mark.parametrize("path_info", ("bark/", "/bark/")) @pytest.mark.parametrize("script_name", ("woof/woof", "woof/woof/")) def test_script_name_is_respected(