Skip to content

Commit 4ccbf24

Browse files
committed
Support compressed bodies
1 parent 9231fa3 commit 4ccbf24

File tree

1 file changed

+48
-25
lines changed

1 file changed

+48
-25
lines changed

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,21 @@
77
from dataclasses import dataclass
88
from typing import Any
99

10+
from starlette.datastructures import MutableHeaders
1011
from starlette.requests import Request
1112
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1213

1314
from ..config import EndpointMethods
1415
from ..utils.requests import dict_to_bytes
1516

1617

18+
ENCODING_HANDLERS = {
19+
"gzip": gzip,
20+
"deflate": zlib,
21+
"br": brotli,
22+
}
23+
24+
1725
@dataclass(frozen=True)
1826
class OpenApiMiddleware:
1927
"""Middleware to add the OpenAPI spec to the response."""
@@ -30,41 +38,56 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3038
if scope["type"] != "http" or Request(scope).url.path != self.openapi_spec_path:
3139
return await self.app(scope, receive, send)
3240

33-
total_body = b""
41+
start_message = None
42+
body = b""
3443

3544
async def augment_oidc_spec(message: Message):
36-
if message["type"] != "http.response.body":
45+
nonlocal start_message
46+
nonlocal body
47+
if message["type"] == "http.response.start":
48+
# NOTE: Because we are modifying the response body, we will need to update
49+
# the content-length header. However, headers are sent before we see the
50+
# body. To handle this, we delay sending the http.response.start message
51+
# until after we alter the body.
52+
start_message = message
53+
return
54+
elif message["type"] != "http.response.body":
3755
return await send(message)
3856

39-
# TODO: Make more robust to handle non-JSON responses
57+
body += message["body"]
4058

41-
nonlocal total_body
59+
# Skip body chunks until all chunks have been received
60+
if message["more_body"]:
61+
return
4262

43-
total_body += message["body"]
63+
# Maybe decompress the body
64+
headers = MutableHeaders(scope=start_message)
65+
content_encoding = headers.get("content-encoding", "").lower()
66+
handler = None
67+
if content_encoding:
68+
handler = ENCODING_HANDLERS.get(content_encoding)
69+
assert handler, f"Unsupported content encoding: {content_encoding}"
70+
body = handler.decompress(body)
4471

45-
# Pass empty body chunks until all chunks have been received
46-
if message["more_body"]:
47-
return await send({**message, "body": b""})
48-
49-
# Handle compressed responses
50-
# content_encoding = (
51-
# message.get("headers", {})
52-
# .get(b"content-encoding", b"")
53-
# .decode()
54-
# .lower()
55-
# )
56-
# if content_encoding:
57-
# if "gzip" in content_encoding:
58-
# total_body = gzip.decompress(total_body)
59-
# elif "deflate" in content_encoding:
60-
# total_body = zlib.decompress(total_body)
61-
# elif "br" in content_encoding:
62-
# total_body = brotli.decompress(total_body)
63-
# print(f"{message=}")
72+
# Augment the spec
73+
body = dict_to_bytes(self.augment_spec(json.loads(body)))
74+
75+
# Maybe re-compress the body
76+
if handler:
77+
body = handler.compress(body)
78+
79+
# Update the content-length header
80+
headers["content-length"] = str(len(body))
81+
start_message["headers"] = headers.items()
82+
83+
# Send http.response.start
84+
await send(start_message)
85+
86+
# Send http.response.body
6487
await send(
6588
{
6689
"type": "http.response.body",
67-
"body": dict_to_bytes(self.augment_spec(json.loads(total_body))),
90+
"body": body,
6891
"more_body": False,
6992
}
7093
)

0 commit comments

Comments
 (0)