Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def should_transform_response(self, request: Request, scope: Scope) -> bool:

def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
"""Augment the OpenAPI spec with auth information."""
# Remove any existing servers field from upstream API
# This ensures we don't have conflicting server declarations
if "servers" in data:
del data["servers"]

# Add servers field with root path if root_path is set
if self.root_path:
data["servers"] = [{"url": self.root_path}]
Expand Down
279 changes: 279 additions & 0 deletions tests/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,282 @@ def test_no_root_path_in_openapi_spec(source_api: FastAPI, source_api_server: st
assert response.status_code == 200
openapi = response.json()
assert "servers" not in openapi


def test_upstream_servers_removed_when_root_path_set(
source_api: FastAPI, source_api_server: str, source_api_responses
):
"""When upstream API has servers field and proxy has root_path, upstream servers are removed and replaced with proxy servers."""
# Configure upstream API to return a servers field
upstream_servers = [{"url": "https://upstream-api.com/stage"}]
# Add the /api endpoint to the responses
source_api_responses["/api"] = {
"GET": {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"paths": {},
"servers": upstream_servers,
}
}

root_path = "/api/v1"
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
root_path=root_path,
)
client = TestClient(app)
response = client.get(root_path + source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()

# Verify upstream servers are removed and replaced with proxy servers
assert "servers" in openapi
assert openapi["servers"] == [{"url": root_path}]
assert openapi["servers"] != upstream_servers


def test_upstream_servers_removed_when_no_root_path(
source_api: FastAPI, source_api_server: str, source_api_responses
):
"""When upstream API has servers field and proxy has no root_path, upstream servers are removed and no servers field is added."""
# Configure upstream API to return a servers field
upstream_servers = [{"url": "https://upstream-api.com/stage"}]
# Add the /api endpoint to the responses
source_api_responses["/api"] = {
"GET": {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"paths": {},
"servers": upstream_servers,
}
}

app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
root_path="", # No root path
)
client = TestClient(app)
response = client.get(source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()

# Verify upstream servers are removed and no servers field is added
assert "servers" not in openapi


def test_no_servers_field_when_upstream_has_none(
source_api: FastAPI, source_api_server: str, source_api_responses
):
"""When upstream API has no servers field and proxy has no root_path, no servers field is added."""
# Configure upstream API to return no servers field
source_api_responses["/api"] = {
"GET": {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"paths": {},
# No servers field
}
}

app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
root_path="", # No root path
)
client = TestClient(app)
response = client.get(source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()

# Verify no servers field is added
assert "servers" not in openapi


def test_multiple_upstream_servers_removed(
source_api: FastAPI, source_api_server: str, source_api_responses
):
"""When upstream API has multiple servers, all are removed and replaced with proxy server."""
# Configure upstream API to return multiple servers
upstream_servers = [
{"url": "https://upstream-api.com/stage"},
{"url": "https://upstream-api.com/prod"},
{
"url": "https://staging.upstream-api.com",
"description": "Staging environment",
},
]
source_api_responses["/api"] = {
"GET": {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"paths": {},
"servers": upstream_servers,
}
}

root_path = "/api/v1"
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
root_path=root_path,
)
client = TestClient(app)
response = client.get(root_path + source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()

# Verify all upstream servers are removed and replaced with proxy server
assert "servers" in openapi
assert openapi["servers"] == [{"url": root_path}]
assert len(openapi["servers"]) == 1
assert openapi["servers"] != upstream_servers


def test_upstream_servers_with_variables_removed(
source_api: FastAPI, source_api_server: str, source_api_responses
):
"""When upstream API has servers with variables, they are removed and replaced with proxy server."""
# Configure upstream API to return servers with variables
upstream_servers = [
{
"url": "https://{environment}.upstream-api.com/{version}",
"variables": {
"environment": {"default": "prod", "enum": ["dev", "staging", "prod"]},
"version": {"default": "v1", "enum": ["v1", "v2"]},
},
}
]
source_api_responses["/api"] = {
"GET": {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"paths": {},
"servers": upstream_servers,
}
}

root_path = "/api/v1"
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
root_path=root_path,
)
client = TestClient(app)
response = client.get(root_path + source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()

# Verify upstream servers with variables are removed and replaced with proxy server
assert "servers" in openapi
assert openapi["servers"] == [{"url": root_path}]
assert len(openapi["servers"]) == 1
assert openapi["servers"] != upstream_servers


def test_malformed_servers_field_handled(
source_api: FastAPI, source_api_server: str, source_api_responses
):
"""When upstream API has malformed servers field, it is removed and replaced with proxy server."""
# Configure upstream API to return malformed servers field
source_api_responses["/api"] = {
"GET": {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"paths": {},
"servers": "invalid_servers_field", # Should be a list
}
}

root_path = "/api/v1"
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
root_path=root_path,
)
client = TestClient(app)
response = client.get(root_path + source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()

# Verify malformed servers field is removed and replaced with proxy server
assert "servers" in openapi
assert openapi["servers"] == [{"url": root_path}]
assert isinstance(openapi["servers"], list)


def test_empty_servers_list_removed(
source_api: FastAPI, source_api_server: str, source_api_responses
):
"""When upstream API has empty servers list, it is removed and replaced with proxy server."""
# Configure upstream API to return empty servers list
source_api_responses["/api"] = {
"GET": {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"paths": {},
"servers": [], # Empty list
}
}

root_path = "/api/v1"
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
root_path=root_path,
)
client = TestClient(app)
response = client.get(root_path + source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()

# Verify empty servers list is removed and replaced with proxy server
assert "servers" in openapi
assert openapi["servers"] == [{"url": root_path}]
assert len(openapi["servers"]) == 1


@pytest.mark.parametrize("root_path", [None, "/api/v1"])
def test_servers_are_replaced_with_proxy_server(root_path: str):
"""Test that verifies upstream servers are replaced with proxy server."""
from unittest.mock import Mock

from stac_auth_proxy.middleware.UpdateOpenApiMiddleware import OpenApiMiddleware

# Test data with upstream servers
test_data = {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"paths": {},
"servers": [
{"url": "https://upstream-api.com/stage"},
{"url": "https://upstream-api.com/prod"},
],
}

# Create middleware instance
middleware = OpenApiMiddleware(
app=Mock(),
openapi_spec_path="/api",
oidc_discovery_url="https://example.com/.well-known/openid-configuration",
private_endpoints={},
public_endpoints={},
default_public=True,
root_path=root_path,
)

# Test the middleware behavior
result = middleware.transform_json(test_data.copy(), Mock())

# Verify that only the proxy server remains
if root_path:
assert "servers" in result
assert len(result["servers"]) == 1
assert result["servers"][0]["url"] == root_path
else:
assert "servers" not in result

# Verify upstream servers are gone
for server in test_data["servers"]:
assert server not in result.get("servers", [])
Loading