Skip to content

Commit d40af69

Browse files
committed
Working
1 parent 882dd24 commit d40af69

File tree

4 files changed

+93 additions, -75 deletions (168 lines changed)

4 files changed

+93 additions, -75 deletions (168 lines changed)

src/stac_auth_proxy/app.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
7979
default_public=settings.default_public,
8080
public_endpoints=settings.public_endpoints,
8181
private_endpoints=settings.private_endpoints,
82+
oidc_config_url=settings.oidc_discovery_internal_url,
8283
)
8384

8485
if settings.openapi_spec_endpoint:

src/stac_auth_proxy/config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ class Settings(BaseSettings):
4646
openapi_spec_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None)
4747

4848
signer_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None)
49-
signer_asset_expression: str = Field(default=r".*")
49+
signer_asset_expression: str = Field(default=r"^s3://.*$")
5050

5151
# Auth
5252
default_public: bool = False

src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py

Lines changed: 71 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,13 @@
22

33
import logging
44
import re
5-
from dataclasses import dataclass
5+
from dataclasses import dataclass, field
66
from itertools import chain
77
from typing import Any, Optional
88
from urllib.parse import urlparse
99

10+
import httpx
11+
from pydantic import HttpUrl
1012
from starlette.requests import Request
1113
from starlette.types import ASGIApp
1214

@@ -17,7 +19,7 @@
1719
logger = logging.getLogger(__name__)
1820

1921

20-
@dataclass(frozen=True)
22+
@dataclass
2123
class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
2224
"""Middleware to add the authentication extension to the response."""
2325

@@ -30,22 +32,48 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
3032
private_endpoints: EndpointMethods
3133
public_endpoints: EndpointMethods
3234

33-
signing_scheme: str = "signed_url_auth"
34-
auth_scheme: str = "oauth"
35+
oidc_config_url: Optional[HttpUrl] = None
36+
signing_scheme_name: str = "signed_url_auth"
37+
auth_scheme_name: str = "oauth"
38+
auth_scheme: dict[str, Any] = field(default_factory=dict)
39+
extension_url: str = (
40+
"https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
41+
)
42+
43+
def __post_init__(self):
44+
"""Load after initialization."""
45+
if self.oidc_config_url and not self.auth_scheme:
46+
# Retrieve OIDC configuration and extract authorization and token URLs
47+
oidc_config = httpx.get(str(self.oidc_config_url)).json()
48+
self.auth_scheme = {
49+
"type": "oauth2",
50+
"description": "requires an authentication token",
51+
"flows": {
52+
"authorizationCode": {
53+
"authorizationUrl": oidc_config.get("authorization_endpoint"),
54+
"tokenUrl": oidc_config.get("token_endpoint"),
55+
"scopes": {
56+
k: k
57+
for k in sorted(oidc_config.get("scopes_supported", []))
58+
},
59+
},
60+
},
61+
}
3562

3663
def should_transform_response(self, request: Request) -> bool:
3764
"""Determine if the response should be transformed."""
38-
print(f"{request.url=!s}")
39-
return True
65+
# Match STAC catalog, collection, or item URLs with a single regex
66+
return bool(
67+
re.match(
68+
r"^(/|/collections(/[^/]+(/items/[^/]+)?)?|/search)$", request.url.path
69+
)
70+
)
4071

4172
def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
4273
"""Augment the STAC Item with auth information."""
43-
extension = (
44-
"https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
45-
)
4674
extensions = doc.setdefault("stac_extensions", [])
47-
if extension not in extensions:
48-
extensions.append(extension)
75+
if self.extension_url not in extensions:
76+
extensions.append(self.extension_url)
4977

5078
# TODO: Should we add this to items even if the assets don't match the asset expression?
5179
# auth:schemes
@@ -55,64 +83,41 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
5583
# - Catalogs
5684
# - Collections
5785
# - Item Properties
58-
# "auth:schemes": {
59-
# "oauth": {
60-
# "type": "oauth2",
61-
# "description": "requires a login and user token",
62-
# "flows": {
63-
# "authorizationUrl": "https://example.com/oauth/authorize",
64-
# "tokenUrl": "https://example.com/oauth/token",
65-
# "scopes": {}
66-
# }
67-
# }
68-
# }
69-
# TODO: Add directly to Collections & Catalogs doc
70-
if "properties" in doc:
71-
schemes = doc["properties"].setdefault("auth:schemes", {})
72-
schemes[self.auth_scheme] = {
73-
"type": "oauth2",
74-
"description": "requires a login and user token",
86+
scheme_loc = doc["properties"] if "properties" in doc else doc
87+
schemes = scheme_loc.setdefault("auth:schemes", {})
88+
schemes[self.auth_scheme_name] = self.auth_scheme
89+
if self.signing_endpoint:
90+
schemes[self.signing_scheme_name] = {
91+
"type": "signedUrl",
92+
"description": "Requires an authentication API",
7593
"flows": {
76-
# TODO: Get authorizationUrl and tokenUrl from config
7794
"authorizationCode": {
78-
"authorizationUrl": "https://example.com/oauth/authorize",
79-
"tokenUrl": "https://example.com/oauth/token",
80-
"scopes": {},
81-
},
82-
},
83-
}
84-
if self.signing_endpoint:
85-
schemes[self.signing_scheme] = {
86-
"type": "signedUrl",
87-
"description": "Requires an authentication API",
88-
"flows": {
89-
"authorizationCode": {
90-
"authorizationApi": self.signing_endpoint,
91-
"method": "POST",
92-
"parameters": {
93-
"bucket": {
94-
"in": "body",
95-
"required": True,
96-
"description": "asset bucket",
97-
"schema": {
98-
"type": "string",
99-
"examples": "example-bucket",
100-
},
95+
"authorizationApi": self.signing_endpoint,
96+
"method": "POST",
97+
"parameters": {
98+
"bucket": {
99+
"in": "body",
100+
"required": True,
101+
"description": "asset bucket",
102+
"schema": {
103+
"type": "string",
104+
"examples": "example-bucket",
101105
},
102-
"key": {
103-
"in": "body",
104-
"required": True,
105-
"description": "asset key",
106-
"schema": {
107-
"type": "string",
108-
"examples": "path/to/example/asset.xyz",
109-
},
106+
},
107+
"key": {
108+
"in": "body",
109+
"required": True,
110+
"description": "asset key",
111+
"schema": {
112+
"type": "string",
113+
"examples": "path/to/example/asset.xyz",
110114
},
111115
},
112-
"responseField": "signed_url",
113-
}
114-
},
115-
}
116+
},
117+
"responseField": "signed_url",
118+
}
119+
},
120+
}
116121

117122
# auth:refs
118123
# ---
@@ -123,7 +128,7 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
123128
logger.warning("Asset %s has no href", asset)
124129
continue
125130
if re.match(self.signed_asset_expression, asset["href"]):
126-
asset.setdefault("auth:refs", []).append(self.signing_scheme)
131+
asset.setdefault("auth:refs", []).append(self.signing_scheme_name)
127132

128133
# Annotate links with "auth:refs": [auth_scheme]
129134
links = chain(
@@ -136,7 +141,6 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
136141
),
137142
)
138143
for link in links:
139-
print(f"{link['href']=!s}")
140144
if "href" not in link:
141145
logger.warning("Link %s has no href", link)
142146
continue
@@ -148,6 +152,6 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
148152
default_public=self.default_public,
149153
)
150154
if match.is_private:
151-
link.setdefault("auth:refs", []).append(self.auth_scheme)
155+
link.setdefault("auth:refs", []).append(self.auth_scheme_name)
152156

153157
return doc

src/stac_auth_proxy/utils/middleware.py

Lines changed: 20 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,15 +2,17 @@
22

33
import gzip
44
import json
5+
import re
56
import zlib
67
from abc import ABC, abstractmethod
78
from typing import Any, Optional
89

910
import brotli
10-
from starlette.datastructures import MutableHeaders
11+
from starlette.datastructures import Headers, MutableHeaders
1112
from starlette.requests import Request
1213
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1314

15+
# TODO: Consider using a single middleware to handle all compression/decompression
1416
ENCODING_HANDLERS = {
1517
"gzip": gzip,
1618
"deflate": zlib,
@@ -22,6 +24,9 @@ class JsonResponseMiddleware(ABC):
2224
"""Base class for middleware that transforms JSON response bodies."""
2325

2426
app: ASGIApp
27+
json_content_type_expr: str = (
28+
r"application/vnd\.oai\.openapi\+json;.*|application/json|application/geo\+json"
29+
)
2530

2631
@abstractmethod
2732
def should_transform_response(self, request: Request) -> bool:
@@ -35,7 +40,7 @@ def should_transform_response(self, request: Request) -> bool:
3540
-------
3641
bool: True if the response should be transformed
3742
"""
38-
pass
43+
return request.headers.get("accept") == "application/json"
3944

4045
@abstractmethod
4146
def transform_json(self, data: Any) -> Any:
@@ -62,16 +67,23 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6267

6368
start_message: Optional[Message] = None
6469
body = b""
70+
not_json = False
6571

6672
async def process_message(message: Message) -> None:
6773
nonlocal start_message
6874
nonlocal body
69-
75+
nonlocal not_json
7076
if message["type"] == "http.response.start":
7177
# Delay sending start message until we've processed the body
78+
if not re.match(
79+
self.json_content_type_expr,
80+
Headers(scope=message).get("content-type", ""),
81+
):
82+
not_json = True
83+
return await send(message)
7284
start_message = message
7385
return
74-
elif message["type"] != "http.response.body":
86+
elif message["type"] != "http.response.body" or not_json:
7587
return await send(message)
7688

7789
body += message["body"]
@@ -94,9 +106,10 @@ async def process_message(message: Message) -> None:
94106
)
95107

96108
# Transform the JSON body
97-
data = json.loads(body)
98-
transformed = self.transform_json(data)
99-
body = json.dumps(transformed).encode()
109+
if body:
110+
data = json.loads(body)
111+
transformed = self.transform_json(data)
112+
body = json.dumps(transformed).encode()
100113

101114
# Re-compress if necessary
102115
if handler:

0 commit comments

Comments
 (0)