Skip to content

Commit 4ea48ae

Browse files
committed
In progress
1 parent e462d3c commit 4ea48ae

File tree

5 files changed

+172
-2
lines changed

5 files changed

+172
-2
lines changed

src/stac_auth_proxy/app.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from fastapi import FastAPI
1212

1313
from .config import Settings
14-
from .handlers import HealthzHandler, ReverseProxyHandler
14+
from .handlers import HealthzHandler, ReverseProxyHandler, S3AssetSigner
1515
from .lifespan import LifespanManager, ServerHealthCheck
1616
from .middleware import (
1717
AddProcessTimeHeaderMiddleware,
@@ -57,6 +57,14 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
5757
prefix=settings.healthz_prefix,
5858
)
5959

60+
if settings.signer_endpoint:
61+
# TODO: Warn/error if endpoint is public
62+
app.add_api_route(
63+
settings.signer_endpoint,
64+
S3AssetSigner(bucket_pattern=settings.signer_endpoint).endpoint,
65+
methods=["POST"],
66+
)
67+
6068
app.add_api_route(
6169
"/{path:path}",
6270
ReverseProxyHandler(upstream=str(settings.upstream_url)).stream,
@@ -76,6 +84,11 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
7684
default_public=settings.default_public,
7785
)
7886

87+
# signers={
88+
# schema: endpoint
89+
# for schema, endpoint in {"s3": settings.signer_endpoint}.items()
90+
# if endpoint
91+
# },
7992
if settings.items_filter:
8093
app.add_middleware(
8194
ApplyCql2FilterMiddleware,

src/stac_auth_proxy/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ class Settings(BaseSettings):
4545
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
4646
openapi_spec_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None)
4747

48+
signer_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None)
49+
signer_asset_expression: Optional[str] = Field(default=r".*")
50+
4851
# Auth
4952
default_public: bool = False
5053
public_endpoints: EndpointMethodsNoScope = {

src/stac_auth_proxy/handlers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22

33
from .healthz import HealthzHandler
44
from .reverse_proxy import ReverseProxyHandler
5+
from .s3_asset_signer import S3AssetSigner
56

6-
__all__ = ["ReverseProxyHandler", "HealthzHandler"]
7+
__all__ = ["ReverseProxyHandler", "HealthzHandler", "S3AssetSigner"]
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import logging
2+
import re
3+
from dataclasses import dataclass
4+
5+
import boto3
6+
from botocore.exceptions import ClientError
7+
from fastapi import HTTPException
8+
9+
10+
@dataclass
11+
class S3AssetSigner:
12+
bucket_pattern: str = r".*"
13+
14+
def endpoint(self, payload: "S3AssetSignerPayload", expiration: int = 3600) -> str:
15+
"""Generate a presigned URL to share an S3 object."""
16+
if not re.match(self.bucket_pattern, payload.bucket_name):
17+
return HTTPException(status_code=404, detail="Item not found")
18+
19+
try:
20+
return boto3.client("s3").generate_presigned_url(
21+
"get_object",
22+
Params={"Bucket": payload.bucket_name, "Key": payload.object_name},
23+
ExpiresIn=expiration,
24+
)
25+
except ClientError as e:
26+
logging.error(e)
27+
return HTTPException(status_code=500, detail="Internal server error")
28+
29+
30+
@dataclass
31+
class S3AssetSignerPayload:
32+
"""Signs S3 assets."""
33+
34+
bucket_name: str
35+
object_name: str
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
2+
3+
import gzip
4+
import zlib
5+
from dataclasses import dataclass, field
6+
from typing import Any
7+
8+
import brotli
9+
from starlette.requests import Request
10+
from starlette.types import ASGIApp, Receive, Scope, Send
11+
12+
from ..utils.requests import find_match
13+
14+
ENCODING_HANDLERS = {
15+
"gzip": gzip,
16+
"deflate": zlib,
17+
"br": brotli,
18+
}
19+
20+
21+
@dataclass(frozen=True)
22+
class AuthorizationExtension:
23+
"""Middleware to add the OpenAPI spec to the response."""
24+
25+
app: ASGIApp
26+
signers: dict[str, str] = field(default_factory=dict)
27+
28+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
29+
"""Add the OpenAPI spec to the response."""
30+
if scope["type"] != "http" or Request(scope).url.path != self.openapi_spec_path:
31+
return await self.app(scope, receive, send)
32+
33+
# TODO: test if asset path matches
34+
# start_message: Optional[Message] = None
35+
# body = b""
36+
37+
# async def augment_oidc_spec(message: Message):
38+
# nonlocal start_message
39+
# nonlocal body
40+
# if message["type"] == "http.response.start":
41+
# # NOTE: Because we are modifying the response body, we will need to update
42+
# # the content-length header. However, headers are sent before we see the
43+
# # body. To handle this, we delay sending the http.response.start message
44+
# # until after we alter the body.
45+
# start_message = message
46+
# return
47+
# elif message["type"] != "http.response.body":
48+
# return await send(message)
49+
50+
# body += message["body"]
51+
52+
# # Skip body chunks until all chunks have been received
53+
# if message["more_body"]:
54+
# return
55+
56+
# # Maybe decompress the body
57+
# headers = MutableHeaders(scope=start_message)
58+
# content_encoding = headers.get("content-encoding", "").lower()
59+
# handler = None
60+
# if content_encoding:
61+
# handler = ENCODING_HANDLERS.get(content_encoding)
62+
# assert handler, f"Unsupported content encoding: {content_encoding}"
63+
# body = (
64+
# handler.decompress(body)
65+
# if content_encoding != "deflate"
66+
# else handler.decompress(body, -zlib.MAX_WBITS)
67+
# )
68+
69+
# # Augment the spec
70+
# body = dict_to_bytes(self.augment_spec(json.loads(body)))
71+
72+
# # Maybe re-compress the body
73+
# if handler:
74+
# body = handler.compress(body)
75+
76+
# # Update the content-length header
77+
# headers["content-length"] = str(len(body))
78+
# assert start_message, "Expected start_message to be set"
79+
# start_message["headers"] = [
80+
# (key.encode(), value.encode()) for key, value in headers.items()
81+
# ]
82+
83+
# # Send http.response.start
84+
# await send(start_message)
85+
86+
# # Send http.response.body
87+
# await send(
88+
# {
89+
# "type": "http.response.body",
90+
# "body": body,
91+
# "more_body": False,
92+
# }
93+
# )
94+
95+
return await self.app(scope, receive, augment_oidc_spec)
96+
97+
def augment_spec(self, openapi_spec) -> dict[str, Any]:
98+
"""Augment the OpenAPI spec with auth information."""
99+
components = openapi_spec.setdefault("components", {})
100+
securitySchemes = components.setdefault("securitySchemes", {})
101+
securitySchemes[self.oidc_auth_scheme_name] = {
102+
"type": "openIdConnect",
103+
"openIdConnectUrl": self.oidc_config_url,
104+
}
105+
for path, method_config in openapi_spec["paths"].items():
106+
for method, config in method_config.items():
107+
match = find_match(
108+
path,
109+
method,
110+
self.private_endpoints,
111+
self.public_endpoints,
112+
self.default_public,
113+
)
114+
if match.is_private:
115+
config.setdefault("security", []).append(
116+
{self.oidc_auth_scheme_name: match.required_scopes}
117+
)
118+
return openapi_spec

0 commit comments

Comments
 (0)