Skip to content

Commit b54059b

Browse files
authored
fix(openapi): remove upstream servers (#90)
Ensure that any existing servers field from the upstream API is removed. closes #74
1 parent ece9f73 commit b54059b

File tree

2 files changed

+284
-0
lines changed

2 files changed

+284
-0
lines changed

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def should_transform_response(self, request: Request, scope: Scope) -> bool:
4747

4848
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
4949
"""Augment the OpenAPI spec with auth information."""
50+
# Remove any existing servers field from upstream API
51+
# This ensures we don't have conflicting server declarations
52+
if "servers" in data:
53+
del data["servers"]
54+
5055
# Add servers field with root path if root_path is set
5156
if self.root_path:
5257
data["servers"] = [{"url": self.root_path}]

tests/test_openapi.py

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,282 @@ def test_no_root_path_in_openapi_spec(source_api: FastAPI, source_api_server: st
232232
assert response.status_code == 200
233233
openapi = response.json()
234234
assert "servers" not in openapi
235+
236+
237+
def test_upstream_servers_removed_when_root_path_set(
238+
source_api: FastAPI, source_api_server: str, source_api_responses
239+
):
240+
"""When upstream API has servers field and proxy has root_path, upstream servers are removed and replaced with proxy servers."""
241+
# Configure upstream API to return a servers field
242+
upstream_servers = [{"url": "https://upstream-api.com/stage"}]
243+
# Add the /api endpoint to the responses
244+
source_api_responses["/api"] = {
245+
"GET": {
246+
"openapi": "3.0.0",
247+
"info": {"title": "Test API", "version": "1.0.0"},
248+
"paths": {},
249+
"servers": upstream_servers,
250+
}
251+
}
252+
253+
root_path = "/api/v1"
254+
app = app_factory(
255+
upstream_url=source_api_server,
256+
openapi_spec_endpoint=source_api.openapi_url,
257+
root_path=root_path,
258+
)
259+
client = TestClient(app)
260+
response = client.get(root_path + source_api.openapi_url)
261+
assert response.status_code == 200
262+
openapi = response.json()
263+
264+
# Verify upstream servers are removed and replaced with proxy servers
265+
assert "servers" in openapi
266+
assert openapi["servers"] == [{"url": root_path}]
267+
assert openapi["servers"] != upstream_servers
268+
269+
270+
def test_upstream_servers_removed_when_no_root_path(
271+
source_api: FastAPI, source_api_server: str, source_api_responses
272+
):
273+
"""When upstream API has servers field and proxy has no root_path, upstream servers are removed and no servers field is added."""
274+
# Configure upstream API to return a servers field
275+
upstream_servers = [{"url": "https://upstream-api.com/stage"}]
276+
# Add the /api endpoint to the responses
277+
source_api_responses["/api"] = {
278+
"GET": {
279+
"openapi": "3.0.0",
280+
"info": {"title": "Test API", "version": "1.0.0"},
281+
"paths": {},
282+
"servers": upstream_servers,
283+
}
284+
}
285+
286+
app = app_factory(
287+
upstream_url=source_api_server,
288+
openapi_spec_endpoint=source_api.openapi_url,
289+
root_path="", # No root path
290+
)
291+
client = TestClient(app)
292+
response = client.get(source_api.openapi_url)
293+
assert response.status_code == 200
294+
openapi = response.json()
295+
296+
# Verify upstream servers are removed and no servers field is added
297+
assert "servers" not in openapi
298+
299+
300+
def test_no_servers_field_when_upstream_has_none(
301+
source_api: FastAPI, source_api_server: str, source_api_responses
302+
):
303+
"""When upstream API has no servers field and proxy has no root_path, no servers field is added."""
304+
# Configure upstream API to return no servers field
305+
source_api_responses["/api"] = {
306+
"GET": {
307+
"openapi": "3.0.0",
308+
"info": {"title": "Test API", "version": "1.0.0"},
309+
"paths": {},
310+
# No servers field
311+
}
312+
}
313+
314+
app = app_factory(
315+
upstream_url=source_api_server,
316+
openapi_spec_endpoint=source_api.openapi_url,
317+
root_path="", # No root path
318+
)
319+
client = TestClient(app)
320+
response = client.get(source_api.openapi_url)
321+
assert response.status_code == 200
322+
openapi = response.json()
323+
324+
# Verify no servers field is added
325+
assert "servers" not in openapi
326+
327+
328+
def test_multiple_upstream_servers_removed(
329+
source_api: FastAPI, source_api_server: str, source_api_responses
330+
):
331+
"""When upstream API has multiple servers, all are removed and replaced with proxy server."""
332+
# Configure upstream API to return multiple servers
333+
upstream_servers = [
334+
{"url": "https://upstream-api.com/stage"},
335+
{"url": "https://upstream-api.com/prod"},
336+
{
337+
"url": "https://staging.upstream-api.com",
338+
"description": "Staging environment",
339+
},
340+
]
341+
source_api_responses["/api"] = {
342+
"GET": {
343+
"openapi": "3.0.0",
344+
"info": {"title": "Test API", "version": "1.0.0"},
345+
"paths": {},
346+
"servers": upstream_servers,
347+
}
348+
}
349+
350+
root_path = "/api/v1"
351+
app = app_factory(
352+
upstream_url=source_api_server,
353+
openapi_spec_endpoint=source_api.openapi_url,
354+
root_path=root_path,
355+
)
356+
client = TestClient(app)
357+
response = client.get(root_path + source_api.openapi_url)
358+
assert response.status_code == 200
359+
openapi = response.json()
360+
361+
# Verify all upstream servers are removed and replaced with proxy server
362+
assert "servers" in openapi
363+
assert openapi["servers"] == [{"url": root_path}]
364+
assert len(openapi["servers"]) == 1
365+
assert openapi["servers"] != upstream_servers
366+
367+
368+
def test_upstream_servers_with_variables_removed(
369+
source_api: FastAPI, source_api_server: str, source_api_responses
370+
):
371+
"""When upstream API has servers with variables, they are removed and replaced with proxy server."""
372+
# Configure upstream API to return servers with variables
373+
upstream_servers = [
374+
{
375+
"url": "https://{environment}.upstream-api.com/{version}",
376+
"variables": {
377+
"environment": {"default": "prod", "enum": ["dev", "staging", "prod"]},
378+
"version": {"default": "v1", "enum": ["v1", "v2"]},
379+
},
380+
}
381+
]
382+
source_api_responses["/api"] = {
383+
"GET": {
384+
"openapi": "3.0.0",
385+
"info": {"title": "Test API", "version": "1.0.0"},
386+
"paths": {},
387+
"servers": upstream_servers,
388+
}
389+
}
390+
391+
root_path = "/api/v1"
392+
app = app_factory(
393+
upstream_url=source_api_server,
394+
openapi_spec_endpoint=source_api.openapi_url,
395+
root_path=root_path,
396+
)
397+
client = TestClient(app)
398+
response = client.get(root_path + source_api.openapi_url)
399+
assert response.status_code == 200
400+
openapi = response.json()
401+
402+
# Verify upstream servers with variables are removed and replaced with proxy server
403+
assert "servers" in openapi
404+
assert openapi["servers"] == [{"url": root_path}]
405+
assert len(openapi["servers"]) == 1
406+
assert openapi["servers"] != upstream_servers
407+
408+
409+
def test_malformed_servers_field_handled(
410+
source_api: FastAPI, source_api_server: str, source_api_responses
411+
):
412+
"""When upstream API has malformed servers field, it is removed and replaced with proxy server."""
413+
# Configure upstream API to return malformed servers field
414+
source_api_responses["/api"] = {
415+
"GET": {
416+
"openapi": "3.0.0",
417+
"info": {"title": "Test API", "version": "1.0.0"},
418+
"paths": {},
419+
"servers": "invalid_servers_field", # Should be a list
420+
}
421+
}
422+
423+
root_path = "/api/v1"
424+
app = app_factory(
425+
upstream_url=source_api_server,
426+
openapi_spec_endpoint=source_api.openapi_url,
427+
root_path=root_path,
428+
)
429+
client = TestClient(app)
430+
response = client.get(root_path + source_api.openapi_url)
431+
assert response.status_code == 200
432+
openapi = response.json()
433+
434+
# Verify malformed servers field is removed and replaced with proxy server
435+
assert "servers" in openapi
436+
assert openapi["servers"] == [{"url": root_path}]
437+
assert isinstance(openapi["servers"], list)
438+
439+
440+
def test_empty_servers_list_removed(
441+
source_api: FastAPI, source_api_server: str, source_api_responses
442+
):
443+
"""When upstream API has empty servers list, it is removed and replaced with proxy server."""
444+
# Configure upstream API to return empty servers list
445+
source_api_responses["/api"] = {
446+
"GET": {
447+
"openapi": "3.0.0",
448+
"info": {"title": "Test API", "version": "1.0.0"},
449+
"paths": {},
450+
"servers": [], # Empty list
451+
}
452+
}
453+
454+
root_path = "/api/v1"
455+
app = app_factory(
456+
upstream_url=source_api_server,
457+
openapi_spec_endpoint=source_api.openapi_url,
458+
root_path=root_path,
459+
)
460+
client = TestClient(app)
461+
response = client.get(root_path + source_api.openapi_url)
462+
assert response.status_code == 200
463+
openapi = response.json()
464+
465+
# Verify empty servers list is removed and replaced with proxy server
466+
assert "servers" in openapi
467+
assert openapi["servers"] == [{"url": root_path}]
468+
assert len(openapi["servers"]) == 1
469+
470+
471+
@pytest.mark.parametrize("root_path", [None, "/api/v1"])
472+
def test_servers_are_replaced_with_proxy_server(root_path: str):
473+
"""Test that verifies upstream servers are replaced with proxy server."""
474+
from unittest.mock import Mock
475+
476+
from stac_auth_proxy.middleware.UpdateOpenApiMiddleware import OpenApiMiddleware
477+
478+
# Test data with upstream servers
479+
test_data = {
480+
"openapi": "3.0.0",
481+
"info": {"title": "Test API", "version": "1.0.0"},
482+
"paths": {},
483+
"servers": [
484+
{"url": "https://upstream-api.com/stage"},
485+
{"url": "https://upstream-api.com/prod"},
486+
],
487+
}
488+
489+
# Create middleware instance
490+
middleware = OpenApiMiddleware(
491+
app=Mock(),
492+
openapi_spec_path="/api",
493+
oidc_discovery_url="https://example.com/.well-known/openid-configuration",
494+
private_endpoints={},
495+
public_endpoints={},
496+
default_public=True,
497+
root_path=root_path,
498+
)
499+
500+
# Test the middleware behavior
501+
result = middleware.transform_json(test_data.copy(), Mock())
502+
503+
# Verify that only the proxy server remains
504+
if root_path:
505+
assert "servers" in result
506+
assert len(result["servers"]) == 1
507+
assert result["servers"][0]["url"] == root_path
508+
else:
509+
assert "servers" not in result
510+
511+
# Verify upstream servers are gone
512+
for server in test_data["servers"]:
513+
assert server not in result.get("servers", [])

0 commit comments

Comments
 (0)