Skip to content

Commit 0b7032b

Browse files
committed
Handle compressed openapi specs
TODO: Need test
1 parent 00fe51b commit 0b7032b

File tree

1 file changed

+55
-11
lines changed

1 file changed

+55
-11
lines changed

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
11
"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
22

3+
import brotli
4+
import gzip
35
import json
6+
import re
7+
import zlib
48
from dataclasses import dataclass
59
from typing import Any
610

11+
from starlette.datastructures import MutableHeaders
712
from starlette.requests import Request
813
from starlette.types import ASGIApp, Message, Receive, Scope, Send
914

1015
from ..config import EndpointMethods
1116
from ..utils.requests import dict_to_bytes
1217

1318

19+
ENCODING_HANDLERS = {
20+
"gzip": gzip,
21+
"deflate": zlib,
22+
"br": brotli,
23+
}
24+
25+
1426
@dataclass(frozen=True)
1527
class OpenApiMiddleware:
1628
"""Middleware to add the OpenAPI spec to the response."""
@@ -27,26 +39,58 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2739
if scope["type"] != "http" or Request(scope).url.path != self.openapi_spec_path:
2840
return await self.app(scope, receive, send)
2941

30-
total_body = b""
42+
start_message = None
43+
body = b""
3144

3245
async def augment_oidc_spec(message: Message):
33-
if message["type"] != "http.response.body":
46+
nonlocal start_message
47+
nonlocal body
48+
if message["type"] == "http.response.start":
49+
# NOTE: Because we are modifying the response body, we will need to update
50+
# the content-length header. However, headers are sent before we see the
51+
# body. To handle this, we delay sending the http.response.start message
52+
# until after we alter the body.
53+
start_message = message
54+
return
55+
elif message["type"] != "http.response.body":
3456
return await send(message)
3557

36-
# TODO: Make more robust to handle non-JSON responses
37-
38-
nonlocal total_body
58+
body += message["body"]
3959

40-
total_body += message["body"]
41-
42-
# Pass empty body chunks until all chunks have been received
60+
# Skip body chunks until all chunks have been received
4361
if message["more_body"]:
44-
return await send({**message, "body": b""})
45-
62+
return
63+
64+
# Maybe decompress the body
65+
headers = MutableHeaders(scope=start_message)
66+
content_encoding = headers.get("content-encoding", "").lower()
67+
handler = None
68+
if content_encoding:
69+
handler = ENCODING_HANDLERS.get(content_encoding)
70+
assert handler, f"Unsupported content encoding: {content_encoding}"
71+
body = handler.decompress(body)
72+
73+
# Augment the spec
74+
body = dict_to_bytes(self.augment_spec(json.loads(body)))
75+
76+
# Maybe re-compress the body
77+
if handler:
78+
body = handler.compress(body)
79+
80+
# Update the content-length header
81+
headers["content-length"] = str(len(body))
82+
start_message["headers"] = [
83+
(key.encode(), value.encode()) for key, value in headers.items()
84+
]
85+
86+
# Send http.response.start
87+
await send(start_message)
88+
89+
# Send http.response.body
4690
await send(
4791
{
4892
"type": "http.response.body",
49-
"body": dict_to_bytes(self.augment_spec(json.loads(total_body))),
93+
"body": body,
5094
"more_body": False,
5195
}
5296
)

0 commit comments

Comments
 (0)