Skip to content

Commit 855183a

Browse files
authored
fix: ensure OPTIONS requests are sent upstream without auth check (#76)
Allow OPTIONS requests through without performing auth checks. Also, ensure that we don't add auth requirements to OPTIONS endpoints when augment OpenAPI spec. Finally, Claude helped to write tests and recommended we catch `jwt.exceptions.PyJWKClientError` errors when validating tokens (this came up when it added a test with invalid tokens) Closes #75
1 parent 8a09873 commit 855183a

File tree

6 files changed

+227
-13
lines changed

6 files changed

+227
-13
lines changed

src/stac_auth_proxy/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ async def lifespan(app: FastAPI):
104104
upstream=str(settings.upstream_url),
105105
override_host=settings.override_host,
106106
).proxy_request,
107-
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
107+
methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
108108
)
109109

110110
#

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
8585
return await self.app(scope, receive, send)
8686

8787
request = Request(scope)
88+
89+
# Skip authentication for OPTIONS requests, https://fetch.spec.whatwg.org/#cors-protocol-and-credentials
90+
if request.method == "OPTIONS":
91+
return await self.app(scope, receive, send)
92+
8893
match = find_match(
8994
request.url.path,
9095
request.method,
@@ -148,7 +153,11 @@ def validate_token(
148153
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
149154
audience=self.allowed_jwt_audiences,
150155
)
151-
except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e:
156+
except (
157+
jwt.exceptions.InvalidTokenError,
158+
jwt.exceptions.DecodeError,
159+
jwt.exceptions.PyJWKClientError,
160+
) as e:
152161
logger.error("InvalidTokenError: %r", e)
153162
raise HTTPException(
154163
status_code=status.HTTP_401_UNAUTHORIZED,

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
6262
# Add security to private endpoints
6363
for path, method_config in data["paths"].items():
6464
for method, config in method_config.items():
65+
if method == "options":
66+
# OPTIONS requests are not authenticated, https://fetch.spec.whatwg.org/#cors-protocol-and-credentials
67+
continue
6568
match = find_match(
6669
path,
6770
method,

tests/conftest.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,35 +87,50 @@ def source_api():
8787

8888
# Default responses for each endpoint
8989
default_responses = {
90-
"/": {"GET": {"id": "Response from GET@"}},
91-
"/conformance": {"GET": {"conformsTo": ["http://example.com/conformance"]}},
92-
"/queryables": {"GET": {"queryables": {}}},
90+
"/": {
91+
"GET": {"id": "Response from GET@"},
92+
"OPTIONS": {"id": "Response from OPTIONS@"},
93+
},
94+
"/conformance": {
95+
"GET": {"conformsTo": ["http://example.com/conformance"]},
96+
"OPTIONS": {"id": "Response from OPTIONS@"},
97+
},
98+
"/queryables": {
99+
"GET": {"queryables": {}},
100+
"OPTIONS": {"id": "Response from OPTIONS@"},
101+
},
93102
"/search": {
94103
"GET": {"type": "FeatureCollection", "features": []},
95104
"POST": {"type": "FeatureCollection", "features": []},
105+
"OPTIONS": {"id": "Response from OPTIONS@"},
96106
},
97107
"/collections": {
98108
"GET": {"collections": []},
99109
"POST": {"id": "Response from POST@"},
110+
"OPTIONS": {"id": "Response from OPTIONS@"},
100111
},
101112
"/collections/{collection_id}": {
102113
"GET": {"id": "Response from GET@"},
103114
"PUT": {"id": "Response from PUT@"},
104115
"PATCH": {"id": "Response from PATCH@"},
105116
"DELETE": {"id": "Response from DELETE@"},
117+
"OPTIONS": {"id": "Response from OPTIONS@"},
106118
},
107119
"/collections/{collection_id}/items": {
108120
"GET": {"type": "FeatureCollection", "features": []},
109121
"POST": {"id": "Response from POST@"},
122+
"OPTIONS": {"id": "Response from OPTIONS@"},
110123
},
111124
"/collections/{collection_id}/items/{item_id}": {
112125
"GET": {"id": "Response from GET@"},
113126
"PUT": {"id": "Response from PUT@"},
114127
"PATCH": {"id": "Response from PATCH@"},
115128
"DELETE": {"id": "Response from DELETE@"},
129+
"OPTIONS": {"id": "Response from OPTIONS@"},
116130
},
117131
"/collections/{collection_id}/bulk_items": {
118132
"POST": {"id": "Response from POST@"},
133+
"OPTIONS": {"id": "Response from OPTIONS@"},
119134
},
120135
}
121136

tests/test_authn.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,178 @@ def test_scopes(
167167
)
168168
expected_status_code = 200 if expected_permitted else 401
169169
assert response.status_code == expected_status_code
170+
171+
172+
@pytest.mark.parametrize(
173+
"path,default_public,private_endpoints",
174+
[
175+
("/", False, {}),
176+
("/collections", False, {}),
177+
("/search", False, {}),
178+
("/collections", True, {r"^/collections$": [("POST", "collection:create")]}),
179+
("/search", True, {r"^/search$": [("POST", "search:write")]}),
180+
(
181+
"/collections/example-collection/items",
182+
True,
183+
{r"^/collections/.*/items$": [("POST", "item:create")]},
184+
),
185+
],
186+
)
187+
def test_options_bypass_auth(
188+
path, default_public, private_endpoints, source_api_server
189+
):
190+
"""OPTIONS requests should bypass authentication regardless of endpoint configuration."""
191+
test_app = app_factory(
192+
upstream_url=source_api_server,
193+
default_public=default_public,
194+
private_endpoints=private_endpoints,
195+
)
196+
client = TestClient(test_app)
197+
response = client.options(path)
198+
assert response.status_code == 200, "OPTIONS request should bypass authentication"
199+
200+
201+
@pytest.mark.parametrize(
202+
"path,method,default_public,private_endpoints,expected_status",
203+
[
204+
# Test that non-OPTIONS requests still require auth when endpoints are private
205+
("/collections", "GET", False, {}, 403),
206+
("/collections", "POST", False, {}, 403),
207+
("/search", "GET", False, {}, 403),
208+
# Test that OPTIONS requests bypass auth even when endpoints are private
209+
("/collections", "OPTIONS", False, {}, 200),
210+
("/search", "OPTIONS", False, {}, 200),
211+
# Test with specific private endpoint configurations
212+
(
213+
"/collections",
214+
"POST",
215+
True,
216+
{r"^/collections$": [("POST", "collection:create")]},
217+
403,
218+
),
219+
(
220+
"/collections",
221+
"OPTIONS",
222+
True,
223+
{r"^/collections$": [("POST", "collection:create")]},
224+
200,
225+
),
226+
],
227+
)
228+
def test_options_vs_other_methods_auth_behavior(
229+
path, method, default_public, private_endpoints, expected_status, source_api_server
230+
):
231+
"""Compare authentication behavior between OPTIONS and other HTTP methods."""
232+
test_app = app_factory(
233+
upstream_url=source_api_server,
234+
default_public=default_public,
235+
private_endpoints=private_endpoints,
236+
)
237+
client = TestClient(test_app)
238+
response = client.request(method=method, url=path, headers={})
239+
assert response.status_code == expected_status
240+
241+
242+
@pytest.mark.parametrize(
243+
"path,method,default_public,private_endpoints,expected_status",
244+
[
245+
# Test that requests with valid auth succeed
246+
("/collections", "GET", False, {}, 200),
247+
("/collections", "POST", False, {}, 200),
248+
("/search", "GET", False, {}, 200),
249+
("/collections", "OPTIONS", False, {}, 200),
250+
("/search", "OPTIONS", False, {}, 200),
251+
# Test with specific private endpoint configurations
252+
(
253+
"/collections",
254+
"POST",
255+
True,
256+
{r"^/collections$": [("POST", "collection:create")]},
257+
200,
258+
),
259+
(
260+
"/collections",
261+
"OPTIONS",
262+
True,
263+
{r"^/collections$": [("POST", "collection:create")]},
264+
200,
265+
),
266+
],
267+
)
268+
def test_options_vs_other_methods_with_valid_auth(
269+
path,
270+
method,
271+
default_public,
272+
private_endpoints,
273+
expected_status,
274+
source_api_server,
275+
token_builder,
276+
):
277+
"""Compare authentication behavior between OPTIONS and other HTTP methods with valid auth."""
278+
test_app = app_factory(
279+
upstream_url=source_api_server,
280+
default_public=default_public,
281+
private_endpoints=private_endpoints,
282+
)
283+
valid_auth_token = token_builder({"scope": "collection:create"})
284+
client = TestClient(test_app)
285+
response = client.request(
286+
method=method,
287+
url=path,
288+
headers={"Authorization": f"Bearer {valid_auth_token}"},
289+
)
290+
assert response.status_code == expected_status
291+
292+
293+
@pytest.mark.parametrize(
294+
"invalid_token,expected_status",
295+
[
296+
("Bearer invalid-token", 401),
297+
(
298+
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
299+
401,
300+
),
301+
("InvalidFormat", 401),
302+
("Bearer", 401),
303+
("", 403), # No auth header returns 403, not 401
304+
],
305+
)
306+
def test_with_invalid_tokens_fails(invalid_token, expected_status, source_api_server):
307+
"""GET requests should fail with invalid or malformed tokens."""
308+
test_app = app_factory(
309+
upstream_url=source_api_server,
310+
default_public=False, # All endpoints private
311+
private_endpoints={},
312+
)
313+
client = TestClient(test_app)
314+
response = client.get("/collections", headers={"Authorization": invalid_token})
315+
assert (
316+
response.status_code == expected_status
317+
), f"GET request should fail with token: {invalid_token}"
318+
319+
response = client.options("/collections", headers={"Authorization": invalid_token})
320+
assert (
321+
response.status_code == 200
322+
), f"OPTIONS request should succeed with token: {invalid_token}"
323+
324+
325+
def test_options_requests_with_cors_headers(source_api_server):
326+
"""OPTIONS requests should work properly with CORS headers."""
327+
test_app = app_factory(
328+
upstream_url=source_api_server,
329+
default_public=False, # All endpoints private
330+
private_endpoints={},
331+
)
332+
client = TestClient(test_app)
333+
334+
# Test OPTIONS request with CORS headers
335+
cors_headers = {
336+
"Origin": "https://example.com",
337+
"Access-Control-Request-Method": "POST",
338+
"Access-Control-Request-Headers": "Content-Type,Authorization",
339+
}
340+
341+
response = client.options("/collections", headers=cors_headers)
342+
assert (
343+
response.status_code == 200
344+
), "OPTIONS request with CORS headers should succeed"

tests/test_openapi.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_oidc_in_openapi_spec_public_endpoints(
129129
source_api: FastAPI, source_api_server: str
130130
):
131131
"""When OpenAPI spec endpoint is set & endpoints are marked public, those endpoints are not marked private in the spec."""
132-
public = {r"^/queryables$": ["GET"], r"^/api": ["GET"]}
132+
public = {r"^/queryables$": ["GET"], r"^/api$": ["GET"]}
133133
app = app_factory(
134134
upstream_url=source_api_server,
135135
openapi_spec_endpoint=source_api.openapi_url,
@@ -140,17 +140,29 @@ def test_oidc_in_openapi_spec_public_endpoints(
140140

141141
openapi = client.get(source_api.openapi_url).raise_for_status().json()
142142

143-
expected_auth = {"/queryables": ["GET"]}
143+
expected_required_auth = {"/queryables": ["GET"]}
144144
for path, method_config in openapi["paths"].items():
145145
for method, config in method_config.items():
146146
security = config.get("security")
147+
148+
if method == "options":
149+
assert (
150+
not security
151+
), f"OPTIONS {path} requests should not require authentication"
152+
continue
153+
147154
if security:
148-
assert path not in expected_auth
149-
else:
150-
assert path in expected_auth
151-
assert any(
152-
method.casefold() == m.casefold() for m in expected_auth[path]
153-
)
155+
assert (
156+
path not in expected_required_auth
157+
), f"Path {path} should not require authentication"
158+
continue
159+
160+
assert (
161+
path in expected_required_auth
162+
), f"Path {path} should require authentication"
163+
assert any(
164+
method.casefold() == m.casefold() for m in expected_required_auth[path]
165+
)
154166

155167

156168
def test_auth_scheme_name_override(source_api: FastAPI, source_api_server: str):

0 commit comments

Comments
 (0)