Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 59 additions & 37 deletions stac_api/runtime/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import uuid
from unittest.mock import MagicMock, patch

import pypgstac
import pytest
from httpx import ASGITransport, AsyncClient
from pystac import STACObjectType
Expand Down Expand Up @@ -257,41 +258,44 @@ def override_jwks_client():
return "https://example.com/jwks"


@pytest.fixture(autouse=True)
def mock_auth():
"""Mock the stac_auth_proxy to bypass actual OIDC calls while preserving normal operation."""
# Define the mock JWT payload that validate_token should return
mock_jwt_payload = {
@pytest.fixture
def mock_jwt_payload(request):
"""
Mock function to return a mock JWT payload with the requested scopes. Supports
customization of scopes like so:

```py
@pytest.mark.parametrize(
"mock_jwt_payload", ["openid stac:collection:delete"], indirect=True
)
```

Returns:
dict: A mock JWT payload with the requested scopes.
"""
return {
"sub": "test-user",
"preferred_username": "test-user",
"aud": "account",
"iss": "https://example.com",
"iat": 1700000000,
"exp": 1700003600,
"scope": "openid stac:item:create stac:item:update stac:item:delete stac:collection:create stac:collection:update stac:collection:delete",
"scope": request.param or "openid",
}

# Import the class to patch its method
from stac_auth_proxy.middleware import EnforceAuthMiddleware

# Mock the OidcService to prevent OIDC discovery requests
with patch(
"stac_auth_proxy.middleware.EnforceAuthMiddleware.OidcService"
) as mock_oidc_service:
# Create a mock OIDC service that doesn't make network calls
mock_oidc_instance = MagicMock()
mock_oidc_instance.metadata = {
"jwks_uri": "https://example.com/jwks",
"issuer": "https://example.com",
}
mock_oidc_instance.jwks_client = MagicMock()
mock_oidc_service.return_value = mock_oidc_instance

# Mock the EnforceAuthMiddleware validate_token method on the class
with patch.object(
EnforceAuthMiddleware, "validate_token", return_value=mock_jwt_payload
):
yield
@pytest.fixture(autouse=True)
def mock_auth(mock_jwt_payload):
"""Mock the stac_auth_proxy to bypass actual OIDC calls while preserving normal operation."""
# Define the mock JWT payload that validate_token should return
with patch("jwt.decode", return_value=mock_jwt_payload) as mock_decode:
# Also need to mock the JWKS client to prevent network calls
with patch("jwt.PyJWKClient") as mock_jwks_client:
mock_jwks_instance = MagicMock()
mock_jwks_instance.get_signing_key_from_jwt.return_value.key = "mock-key"
mock_jwks_client.return_value = mock_jwks_instance

yield mock_jwt_payload


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -415,18 +419,36 @@ async def collection_in_db(app: FastAPI, api_client, valid_stac_collection):
This fixture posts a valid collection before a test runs and yields
the collection ID.
"""
# Create the collection
response = await api_client.post(
f"{app.root_path}/collections", json=valid_stac_collection
)
from src.app import app

# Ensure the setup was successful before the test proceeds
# The setup is successful if the collection was created (201) or if it
# already existed (409). Any other status code is a failure.
assert response.status_code in [201, 409]
# TODO: Use pypgstac to load collection to avoid need to mock auth
# async with app.state.writepool.acquire() as connection:
# # Open a transaction.
# async with connection.transaction():
# # Run the query passing the request argument.
# result = await connection.fetchval("select 2 ^ $1", power)
# with app.state.writepool as conn:
# with conn.cursor() as cursor:
# cursor.execute(
# "SELECT id FROM collections WHERE id = %s",
# (valid_stac_collection["id"],),
# )
# result = cursor.fetchone()

with patch.object(EnforceAuthMiddleware, "oidc_config", return_value={}):
with patch.object(EnforceAuthMiddleware, "validate_token", return_value={}):
# Create the collection
response = await api_client.post(
f"{app.root_path}/collections", json=valid_stac_collection
)

# Ensure the setup was successful before the test proceeds
assert response.status_code == 201

yield valid_stac_collection["id"]

await api_client.delete(
f"{app.root_path}/collections/{valid_stac_collection['id']}"
)
with patch.object(EnforceAuthMiddleware, "oidc_config", return_value={}):
with patch.object(EnforceAuthMiddleware, "validate_token", return_value={}):
await api_client.delete(
f"{app.root_path}/collections/{valid_stac_collection['id']}"
)
Loading
Loading