Skip to content

Commit 28424ff

Browse files
authored
feat: add eoapi-auth-util to stac api (#467)
1 parent 9412dc2 commit 28424ff

File tree

10 files changed

+80
-135
lines changed

10 files changed

+80
-135
lines changed

.example.env

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ VEDA_CLIENT_ID=
3232
VEDA_CLIENT_SECRET=secret
3333
VEDA_DATA_ACCESS_ROLE_ARN=
3434
VEDA_COGNITO_DOMAIN=
35+
VEDA_OPENID_CONFIGURATION_URL=
36+
VEDA_KEYCLOAK_CLIENT_ID=
3537

3638
STAC_BROWSER_BUCKET=
3739
STAC_URL=

local/Dockerfile.stac

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@ RUN pip install --upgrade pip
1212
RUN pip install boto3
1313

1414
COPY stac_api/runtime /tmp/stac
15-
16-
COPY common/auth /tmp/stac/common/auth
17-
RUN pip install /tmp/stac/common/auth
1815
RUN pip install /tmp/stac
1916
RUN rm -rf /tmp/stac
2017

stac_api/infrastructure/config.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,9 @@ class vedaSTACSettings(BaseSettings):
4545
description="Description of the STAC Catalog",
4646
)
4747

48-
userpool_id: Optional[str] = Field(
49-
None, description="The Cognito Userpool used for authentication"
50-
)
51-
cognito_domain: Optional[AnyHttpUrl] = Field(
52-
None,
53-
description="The base url of the Cognito domain for authorization and token urls",
54-
)
55-
client_id: Optional[str] = Field(None, description="The Cognito APP client ID")
56-
client_secret: Optional[str] = Field(
57-
"", description="The Cognito APP client secret"
48+
keycloak_client_id: Optional[str] = Field(None, description="Auth client ID")
49+
openid_configuration_url: Optional[AnyHttpUrl] = Field(
50+
None, description="OpenID config url"
5851
)
5952
stac_enable_transactions: bool = Field(
6053
False, description="Whether to enable transactions endpoints"
@@ -69,10 +62,10 @@ def check_transaction_fields(cls, values):
6962
"""
7063
Validates the existence of auth env vars in case stac_enable_transactions is True
7164
"""
72-
if values.get("stac_enable_transactions"):
65+
if values.get("stac_enable_transactions") == "True":
7366
missing_fields = [
7467
field
75-
for field in ["userpool_id", "cognito_domain", "client_id"]
68+
for field in ["keycloak_client_id", "openid_configuration_url"]
7669
if not values.get(field)
7770
]
7871
if missing_fields:

stac_api/infrastructure/construct.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,10 @@ def __init__(
4242
"VEDA_STAC_PROJECT_DESCRIPTION": veda_stac_settings.project_description,
4343
"VEDA_STAC_ROOT_PATH": veda_stac_settings.stac_root_path,
4444
"VEDA_STAC_STAGE": stage,
45-
"VEDA_STAC_USERPOOL_ID": veda_stac_settings.userpool_id,
46-
"VEDA_STAC_CLIENT_ID": veda_stac_settings.client_id,
47-
"VEDA_STAC_COGNITO_DOMAIN": str(veda_stac_settings.cognito_domain),
45+
"VEDA_STAC_CLIENT_ID": veda_stac_settings.keycloak_client_id,
46+
"VEDA_STAC_OPENID_CONFIGURATION_URL": str(
47+
veda_stac_settings.openid_configuration_url
48+
),
4849
"VEDA_STAC_ENABLE_TRANSACTIONS": str(
4950
veda_stac_settings.stac_enable_transactions
5051
),

stac_api/runtime/Dockerfile

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ WORKDIR /tmp
55
COPY stac_api/runtime /tmp/stac
66

77
RUN pip install "mangum" "plpygis>=0.2.1" /tmp/stac -t /asset --no-binary pydantic
8-
COPY common/auth /tmp/stac/common/auth
9-
RUN pip install /tmp/stac/common/auth -t /asset
108
RUN rm -rf /tmp/stac
119

1210
# Reduce package size and remove useless files

stac_api/runtime/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"aws_xray_sdk>=2.6.0,<3",
2121
"pystac[validation]==1.10.1",
2222
"pydantic>2",
23+
"eoapi-auth-utils==0.3.0",
2324
]
2425

2526
extra_reqs = {

stac_api/runtime/src/app.py

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from src.extension import TiTilerExtension
1313

1414
from fastapi import APIRouter, FastAPI
15-
from fastapi.params import Depends
1615
from fastapi.responses import ORJSONResponse
1716
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
1817
from starlette.middleware import Middleware
@@ -25,9 +24,10 @@
2524
from .api import VedaStacApi
2625
from .core import VedaCrudClient
2726
from .monitoring import LoggerRouteHandler, logger, metrics, tracer
28-
from .routes import add_route_dependencies
2927
from .validation import ValidationMiddleware
3028

29+
from eoapi.auth_utils import OpenIdConnectAuth, OpenIdConnectSettings
30+
3131
try:
3232
from importlib.resources import files as resources_files # type: ignore
3333
except ImportError:
@@ -38,6 +38,7 @@
3838
templates = Jinja2Templates(directory=str(resources_files(__package__) / "templates")) # type: ignore
3939

4040
tiles_settings = TilesApiSettings()
41+
auth_settings = OpenIdConnectSettings(_env_prefix="VEDA_STAC_")
4142

4243

4344
@asynccontextmanager
@@ -56,11 +57,12 @@ async def lifespan(app: FastAPI):
5657
root_path=api_settings.root_path,
5758
swagger_ui_init_oauth=(
5859
{
59-
"appName": "Cognito",
60-
"clientId": api_settings.client_id,
60+
"appName": "STAC API",
61+
"clientId": auth_settings.client_id,
6162
"usePkceWithAuthorizationCodeGrant": True,
63+
"scopes": "openid stac:item:create stac:item:update stac:item:delete stac:collection:create stac:collection:update stac:collection:delete",
6264
}
63-
if api_settings.client_id
65+
if auth_settings.client_id
6466
else {}
6567
),
6668
lifespan=lifespan,
@@ -88,41 +90,35 @@ async def lifespan(app: FastAPI):
8890
allow_headers=["*"],
8991
)
9092

91-
if api_settings.enable_transactions:
92-
from veda_auth import VedaAuth
93-
94-
auth = VedaAuth(api_settings)
95-
# Require auth for all endpoints that create, modify or delete data.
96-
add_route_dependencies(
97-
app.router.routes,
98-
[
99-
{"path": "/collections", "method": "POST", "type": "http"},
100-
{"path": "/collections/{collectionId}", "method": "PUT", "type": "http"},
101-
{"path": "/collections/{collectionId}", "method": "DELETE", "type": "http"},
102-
{
103-
"path": "/collections/{collectionId}/items",
104-
"method": "POST",
105-
"type": "http",
106-
},
107-
{
108-
"path": "/collections/{collectionId}/items/{itemId}",
109-
"method": "PUT",
110-
"type": "http",
111-
},
112-
{
113-
"path": "/collections/{collectionId}/items/{itemId}",
114-
"method": "DELETE",
115-
"type": "http",
116-
},
117-
{
118-
"path": "/collections/{collectionId}/bulk_items",
119-
"method": "POST",
120-
"type": "http",
121-
},
122-
],
123-
[Depends(auth.validated_token)],
93+
if api_settings.enable_transactions and auth_settings.client_id:
94+
oidc_auth = OpenIdConnectAuth(
95+
openid_configuration_url=auth_settings.openid_configuration_url,
96+
allowed_jwt_audiences="account",
12497
)
12598

99+
restricted_prefixes_methods = {
100+
"/collections": [("POST", "stac:collection:create")],
101+
"/collections/{collection_id}": [
102+
("PUT", "stac:collection:update"),
103+
("DELETE", "stac:collection:delete"),
104+
],
105+
"/collections/{collection_id}/items": [("POST", "stac:item:create")],
106+
"/collections/{collection_id}/items/{item_id}": [
107+
("PUT", "stac:item:update"),
108+
("DELETE", "stac:item:delete"),
109+
],
110+
"/collections/{collection_id}/bulk_items": [("POST", "stac:item:create")],
111+
}
112+
113+
for route in app.router.routes:
114+
method_scopes = restricted_prefixes_methods.get(route.path)
115+
if not method_scopes:
116+
continue
117+
for method, scope in method_scopes:
118+
if method not in route.methods:
119+
continue
120+
oidc_auth.apply_auth_dependencies(route, required_token_scopes=[scope])
121+
126122
if tiles_settings.titiler_endpoint:
127123
# Register to the TiTiler extension to the api
128124
extension = TiTilerExtension()

stac_api/runtime/src/config.py

Lines changed: 4 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Optional
88

99
import boto3
10-
from pydantic import AnyHttpUrl, Field, field_validator, model_validator
10+
from pydantic import AnyHttpUrl, Field, field_validator
1111
from pydantic_settings import BaseSettings, SettingsConfigDict
1212

1313
from fastapi.responses import ORJSONResponse
@@ -63,55 +63,14 @@ class _ApiSettings(BaseSettings):
6363
root_path: Optional[str] = None
6464
pgstac_secret_arn: Optional[str] = None
6565
stage: Optional[str] = None
66-
67-
userpool_id: Optional[str] = Field(
68-
"", description="The Cognito Userpool used for authentication"
69-
)
70-
cognito_domain: Optional[AnyHttpUrl] = Field(
71-
None,
72-
description="The base url of the Cognito domain for authorization and token urls",
73-
)
74-
client_id: Optional[str] = Field(None, description="The Cognito APP client ID")
75-
client_secret: Optional[str] = Field(
76-
"", description="The Cognito APP client secret"
66+
client_id: Optional[str] = Field(None, description="Auth client ID")
67+
openid_configuration_url: Optional[AnyHttpUrl] = Field(
68+
None, description="OpenID config url"
7769
)
7870
enable_transactions: bool = Field(
7971
False, description="Whether to enable transactions"
8072
)
8173

82-
@model_validator(mode="before")
83-
def check_transaction_fields(cls, values):
84-
enable_transactions = values.get("enable_transactions")
85-
86-
if enable_transactions:
87-
missing_fields = [
88-
field
89-
for field in ["userpool_id", "cognito_domain", "client_id"]
90-
if not values.get(field)
91-
]
92-
if missing_fields:
93-
raise ValueError(
94-
f"When 'enable_transactions' is True, the following fields must be provided: {', '.join(missing_fields)}"
95-
)
96-
return values
97-
98-
@property
99-
def jwks_url(self) -> AnyHttpUrl:
100-
"""JWKS url"""
101-
if self.userpool_id:
102-
region = self.userpool_id.split("_")[0]
103-
return f"https://cognito-idp.{region}.amazonaws.com/{self.userpool_id}/.well-known/jwks.json"
104-
105-
@property
106-
def cognito_authorization_url(self) -> AnyHttpUrl:
107-
"""Cognito user pool authorization url"""
108-
return f"{self.cognito_domain}/oauth2/authorize"
109-
110-
@property
111-
def cognito_token_url(self) -> AnyHttpUrl:
112-
"""Cognito user pool token and refresh url"""
113-
return f"{self.cognito_domain}/oauth2/token"
114-
11574
@field_validator("cors_origins")
11675
@classmethod
11776
def parse_cors_origin(cls, v):

stac_api/runtime/src/routes.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

stac_api/runtime/tests/conftest.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88

99
import os
10+
from unittest.mock import MagicMock, patch
1011

1112
import pytest
1213
from httpx import ASGITransport, AsyncClient
@@ -224,13 +225,11 @@ def test_environ():
224225
os.environ["AWS_SECURITY_TOKEN"] = "testing"
225226
os.environ["AWS_SESSION_TOKEN"] = "testing"
226227
os.environ["AWS_REGION"] = "us-west-2"
227-
os.environ["VEDA_STAC_USERPOOL_ID"] = "us-west-2_FAKEUSERPOOL"
228228
os.environ["VEDA_STAC_CLIENT_ID"] = "Xdjkfghadsfkdsadfjas"
229-
os.environ["VEDA_STAC_CLIENT_SECRET"] = "dsakfjdsalfkjadslfjalksfj"
230229
os.environ[
231-
"VEDA_STAC_COGNITO_DOMAIN"
232-
] = "https://fake.auth.us-west-2.amazoncognito.com"
233-
os.environ["VEDA_STAC_ENABLE_TRANSACTIONS"] = "TRUE"
230+
"VEDA_STAC_OPENID_CONFIGURATION_URL"
231+
] = "https://example.com/.well-known/openid-configuration"
232+
os.environ["VEDA_STAC_ENABLE_TRANSACTIONS"] = "True"
234233

235234
# Config mocks
236235
os.environ["POSTGRES_USER"] = "username"
@@ -251,6 +250,28 @@ def override_validated_token():
251250
return "fake_token"
252251

253252

253+
def override_jwks_client():
254+
"""
255+
Mock function to override jwks uri.
256+
257+
Returns:
258+
str: A fake jwks url.
259+
"""
260+
return "https://example.com/jwks"
261+
262+
263+
@pytest.fixture(autouse=True)
264+
def mock_auth():
265+
"""Mock the OpenIdConnectAuth class to bypass actual OIDC calls."""
266+
with patch("eoapi.auth_utils.OpenIdConnectAuth") as mock:
267+
# Create a mock instance
268+
mock_instance = MagicMock()
269+
mock_instance.valid_token_dependency = override_validated_token
270+
mock_instance.jwks_client = override_jwks_client
271+
mock.return_value = mock_instance
272+
yield mock_instance
273+
274+
254275
@pytest.fixture
255276
async def app():
256277
"""
@@ -286,9 +307,11 @@ async def api_client(app):
286307
Yields:
287308
TestClient: The TestClient instance for API testing.
288309
"""
289-
from src.app import auth
310+
from src.app import oidc_auth
290311

291-
app.dependency_overrides[auth.validated_token] = override_validated_token
312+
app.dependency_overrides[
313+
oidc_auth.valid_token_dependency
314+
] = override_validated_token
292315
base_url = "http://test"
293316

294317
async with AsyncClient(

0 commit comments

Comments
 (0)