Skip to content

Commit d2f442a

Browse files
committed
refactor: decompress response within reverse-proxy tooling
1 parent 21cf1c2 commit d2f442a

File tree

5 files changed

+21
-36
lines changed

5 files changed

+21
-36
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
1414
"jinja2>=3.1.4",
1515
"pydantic-settings>=2.6.1",
1616
"pyjwt>=2.10.1",
17+
"starlette-cramjam>=0.4.0",
1718
"uvicorn>=0.32.1",
1819
]
1920
description = "STAC authentication proxy with FastAPI"

src/stac_auth_proxy/app.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Optional
1010

1111
from fastapi import FastAPI
12+
from starlette_cramjam.middleware import CompressionMiddleware
1213

1314
from .config import Settings
1415
from .handlers import HealthzHandler, ReverseProxyHandler
@@ -56,7 +57,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
5657

5758
app.add_api_route(
5859
"/{path:path}",
59-
ReverseProxyHandler(upstream=str(settings.upstream_url)).stream,
60+
ReverseProxyHandler(upstream=str(settings.upstream_url)).proxy_request,
6061
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
6162
)
6263

@@ -90,6 +91,10 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
9091
oidc_config_url=settings.oidc_discovery_internal_url,
9192
)
9293

94+
app.add_middleware(
95+
CompressionMiddleware,
96+
)
97+
9398
app.add_middleware(
9499
AddProcessTimeHeaderMiddleware,
95100
)

src/stac_auth_proxy/handlers/reverse_proxy.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66

77
import httpx
88
from fastapi import Request
9-
from starlette.background import BackgroundTask
109
from starlette.datastructures import MutableHeaders
11-
from starlette.responses import StreamingResponse
10+
from starlette.responses import Response
1211

1312
logger = logging.getLogger(__name__)
1413

@@ -28,7 +27,7 @@ def __post_init__(self):
2827
timeout=self.timeout,
2928
)
3029

31-
async def proxy_request(self, request: Request) -> httpx.Response:
30+
async def proxy_request(self, request: Request) -> Response:
3231
"""Proxy a request to the upstream STAC API."""
3332
headers = MutableHeaders(request.headers)
3433
headers.setdefault("X-Forwarded-For", request.client.host)
@@ -60,14 +59,15 @@ async def proxy_request(self, request: Request) -> httpx.Response:
6059
f"Received response status {rp_resp.status_code!r} from {rp_req.url} in {proxy_time:.3f}s"
6160
)
6261
rp_resp.headers["X-Upstream-Time"] = f"{proxy_time:.3f}"
63-
return rp_resp
6462

65-
async def stream(self, request: Request) -> StreamingResponse:
66-
"""Transparently proxy a request to the upstream STAC API."""
67-
rp_resp = await self.proxy_request(request)
68-
return StreamingResponse(
69-
rp_resp.aiter_raw(),
63+
# We read the content here to make use of HTTPX's decompression, ensuring we have
64+
# non-compressed content for the middleware to work with.
65+
content = await rp_resp.aread()
66+
if rp_resp.headers.get("Content-Encoding"):
67+
del rp_resp.headers["Content-Encoding"]
68+
69+
return Response(
70+
content=content,
7071
status_code=rp_resp.status_code,
71-
headers=rp_resp.headers,
72-
background=BackgroundTask(rp_resp.aclose),
72+
headers=dict(rp_resp.headers),
7373
)

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,16 @@
11
"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
22

3-
import gzip
43
import json
5-
import zlib
64
from dataclasses import dataclass
75
from typing import Any, Optional
86

9-
import brotli
107
from starlette.datastructures import MutableHeaders
118
from starlette.requests import Request
129
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1310

1411
from ..config import EndpointMethods
1512
from ..utils.requests import dict_to_bytes, find_match
1613

17-
ENCODING_HANDLERS = {
18-
"gzip": gzip,
19-
"deflate": zlib,
20-
"br": brotli,
21-
}
22-
2314

2415
@dataclass(frozen=True)
2516
class OpenApiMiddleware:
@@ -57,29 +48,15 @@ async def augment_oidc_spec(message: Message):
5748
body += message["body"]
5849

5950
# Skip body chunks until all chunks have been received
60-
if message["more_body"]:
51+
if message.get("more_body"):
6152
return
6253

6354
# Maybe decompress the body
6455
headers = MutableHeaders(scope=start_message)
65-
content_encoding = headers.get("content-encoding", "").lower()
66-
handler = None
67-
if content_encoding:
68-
handler = ENCODING_HANDLERS.get(content_encoding)
69-
assert handler, f"Unsupported content encoding: {content_encoding}"
70-
body = (
71-
handler.decompress(body)
72-
if content_encoding != "deflate"
73-
else handler.decompress(body, -zlib.MAX_WBITS)
74-
)
7556

7657
# Augment the spec
7758
body = dict_to_bytes(self.augment_spec(json.loads(body)))
7859

79-
# Maybe re-compress the body
80-
if handler:
81-
body = handler.compress(body)
82-
8360
# Update the content-length header
8461
headers["content-length"] = str(len(body))
8562
assert start_message, "Expected start_message to be set"

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)