Skip to content

Commit 8c17e21

Browse files
committed
Fix OpenAPI augmentation to correctly mark paths as private or public
1 parent 3c45287 commit 8c17e21

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

src/stac_auth_proxy/app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
3838
OpenApiMiddleware,
3939
openapi_spec_path=settings.openapi_spec_endpoint,
4040
oidc_config_url=str(settings.oidc_discovery_url),
41+
public_endpoints=settings.public_endpoints,
4142
private_endpoints=settings.private_endpoints,
4243
default_public=settings.default_public,
4344
)

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class OpenApiMiddleware:
3131
openapi_spec_path: str
3232
oidc_config_url: str
3333
private_endpoints: EndpointMethods
34+
public_endpoints: EndpointMethods
3435
default_public: bool
3536
oidc_auth_scheme_name: str = "oidcAuth"
3637

@@ -97,6 +98,23 @@ async def augment_oidc_spec(message: Message):
9798

9899
return await self.app(scope, receive, augment_oidc_spec)
99100

101+
# def augment_spec(self, openapi_spec) -> dict[str, Any]:
102+
# """Augment the OpenAPI spec with auth information."""
103+
# components = openapi_spec.setdefault("components", {})
104+
# securitySchemes = components.setdefault("securitySchemes", {})
105+
# securitySchemes[self.oidc_auth_scheme_name] = {
106+
# "type": "openIdConnect",
107+
# "openIdConnectUrl": self.oidc_config_url,
108+
# }
109+
# for path, method_config in openapi_spec["paths"].items():
110+
# for method, config in method_config.items():
111+
# for private_method in self.private_endpoints.get(path, []):
112+
# if method.casefold() == private_method.casefold():
113+
# config.setdefault("security", []).append(
114+
# {self.oidc_auth_scheme_name: []}
115+
# )
116+
# return openapi_spec
117+
100118
def augment_spec(self, openapi_spec) -> dict[str, Any]:
101119
"""Augment the OpenAPI spec with auth information."""
102120
components = openapi_spec.setdefault("components", {})
@@ -107,9 +125,24 @@ def augment_spec(self, openapi_spec) -> dict[str, Any]:
107125
}
108126
for path, method_config in openapi_spec["paths"].items():
109127
for method, config in method_config.items():
110-
for private_method in self.private_endpoints.get(path, []):
111-
if method.casefold() == private_method.casefold():
112-
config.setdefault("security", []).append(
113-
{self.oidc_auth_scheme_name: []}
114-
)
128+
requires_auth = (
129+
self.path_matches(path, method, self.private_endpoints)
130+
if self.default_public
131+
else not self.path_matches(path, method, self.public_endpoints)
132+
)
133+
if requires_auth:
134+
config.setdefault("security", []).append(
135+
{self.oidc_auth_scheme_name: []}
136+
)
115137
return openapi_spec
138+
139+
@staticmethod
140+
def path_matches(path: str, method: str, endpoints: dict[str, list[str]]) -> bool:
141+
"""Check if the given path and method match any of the regex patterns and methods in the endpoints."""
142+
for pattern, endpoint_methods in endpoints.items():
143+
if not re.match(pattern, path):
144+
continue
145+
for endpoint_method in endpoint_methods:
146+
if method.casefold() == endpoint_method.casefold():
147+
return True
148+
return False

0 commit comments

Comments
 (0)