From 0a9408148c596992e7961e58efba19abe08a1c2b Mon Sep 17 00:00:00 2001 From: amarouane-ABDELHAK Date: Thu, 10 Jul 2025 16:01:28 -0500 Subject: [PATCH 01/33] new code --- stac_api/runtime/src/app.py | 329 ++++++++++++- stac_api/runtime/src/new_app.py | 600 ++++++++++++++++++++++++ stac_api/runtime/src/old_working_app.py | 565 ++++++++++++++++++++++ 3 files changed, 1489 insertions(+), 5 deletions(-) create mode 100644 stac_api/runtime/src/new_app.py create mode 100644 stac_api/runtime/src/old_working_app.py diff --git a/stac_api/runtime/src/app.py b/stac_api/runtime/src/app.py index acb06260..d5e8e3a1 100644 --- a/stac_api/runtime/src/app.py +++ b/stac_api/runtime/src/app.py @@ -1,8 +1,10 @@ -"""FastAPI application using PGStac. +"""FastAPI application using PGStac with integrated tenant filtering. Based on https://github.com/developmentseed/eoAPI/tree/master/src/eoapi/stac """ +import json from contextlib import asynccontextmanager +from typing import Dict, Any, Optional from aws_lambda_powertools.metrics import MetricUnit from src.config import TilesApiSettings, api_settings @@ -12,7 +14,7 @@ from src.config import post_request_model as POSTModel from src.extension import TiTilerExtension -from fastapi import APIRouter, FastAPI +from fastapi import APIRouter, FastAPI, Request as FastAPIRequest, Depends, Path, HTTPException from fastapi.responses import ORJSONResponse from stac_fastapi.pgstac.db import close_db_connection, connect_to_db from starlette.middleware import Middleware @@ -26,7 +28,7 @@ from .core import VedaCrudClient from .monitoring import LoggerRouteHandler, logger, metrics, tracer from .validation import ValidationMiddleware - +import os from eoapi.auth_utils import OpenIdConnectAuth, OpenIdConnectSettings try: @@ -42,6 +44,152 @@ auth_settings = OpenIdConnectSettings(_env_prefix="VEDA_STAC_") +class TenantAwareVedaCrudClient(VedaCrudClient): + """Extended CRUD client that applies tenant filtering.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def all_collections(self, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): + """Get all collections with optional tenant filtering.""" + + # Call the parent method + collections = await super().all_collections(request, **kwargs) + + # If tenant is specified, filter the results + if tenant and hasattr(collections, 'collections'): + filtered_collections = [ + col for col in collections.collections + if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant + ] + collections.collections = filtered_collections + if hasattr(collections, 'context') and hasattr(collections.context, 'returned'): + collections.context.returned = len(filtered_collections) + elif tenant and isinstance(collections, dict) and 'collections' in collections: + filtered_collections = [ + col for col in collections['collections'] + if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant + ] + collections['collections'] = filtered_collections + if 'numberReturned' in collections: + collections['numberReturned'] = len(filtered_collections) + + return collections + + async def _validate_tenant_access(self, collection: dict, tenant: str, collection_id: str = ""): + """Raise HTTP 404 if the collection does not belong to the given tenant.""" + collection_tenant = collection.get("tenant") or collection.get("properties", {}).get("tenant") + if collection_tenant != tenant: + detail = f"Collection {collection_id} not found for tenant {tenant}" if collection_id else "Collection not found" + raise HTTPException(status_code=404, detail=detail) + + async def get_collection(self, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): + """Get collection with tenant filtering.""" + collection = await super().get_collection(collection_id, request, **kwargs) + + if tenant and collection: + await self._validate_tenant_access(collection, tenant, collection_id) + + return collection + + async def item_collection( + self, + collection_id: str, + request: FastAPIRequest, + tenant: Optional[str] = None, + limit: int = 10, # Add limit + token: Optional[str] = None, # Add token + **kwargs, + ): + """Get items with tenant filtering.""" + if tenant: + logger.info(f"Filtering items by tenant: {tenant} with token: {token}") + + # Your existing tenant validation logic is good + collection = await super().get_collection(collection_id, request, **kwargs) + if not collection: + raise HTTPException( + status_code=404, + detail=f"Collection {collection_id} not found for tenant {tenant}", + ) + await self._validate_tenant_access(collection, tenant, collection_id) + + # Pass the pagination parameters to the parent method + return await super().item_collection( + collection_id=collection_id, + request=request, + limit=limit, + token=token, + **kwargs + ) + async def get_item(self, item_id: str, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): + """Get item with tenant filtering.""" + if tenant: + logger.info(f"Filtering item {item_id} in collection {collection_id} by tenant: {tenant}") + + # Fetch and validate the collection belongs to the tenant + collection = await super().get_collection(collection_id, request, **kwargs) + if not collection: + raise HTTPException( + status_code=404, + detail=f"Collection {collection_id} not found for tenant {tenant}" + ) + await self._validate_tenant_access(collection, tenant, collection_id) + + return await super().get_item(item_id, collection_id, request, **kwargs) + async def post_search( + self, + search_request, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs + ): + """Search with tenant filtering.""" + if tenant: + logger.info(f"Filtering search by tenant: {tenant}") + tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} + + if search_request.filter: + # If a filter already exists, combine with an 'and' + search_request.filter = { + "op": "and", + "args": [ + search_request.filter, + tenant_filter, + ], + } + else: + search_request.filter = tenant_filter + search_request.filter_lang = "cql2-json" + + return await super().post_search(search_request, request, **kwargs) + + async def get_search( + self, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs, + ): + """GET search with tenant filtering.""" + if tenant: + tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} + + if "filter" in kwargs and kwargs["filter"]: + # Combine with existing filter + kwargs["filter"] = { + "op": "and", + "args": [ + kwargs["filter"], + tenant_filter, + ], + } + else: + kwargs["filter"] = tenant_filter + kwargs["filter-lang"] = "cql2-json" + + # The CoreCrudClient.get_search will use the modified kwargs + return await super().get_search(request, **kwargs) + @asynccontextmanager async def lifespan(app: FastAPI): """Get a database connection on startup, close it on shutdown.""" @@ -72,16 +220,173 @@ async def lifespan(app: FastAPI): description=api_settings.project_description, settings=api_settings, extensions=PgStacExtensions, - client=VedaCrudClient(pgstac_search_model=POSTModel), + client=TenantAwareVedaCrudClient(pgstac_search_model=POSTModel), search_get_request_model=GETModel, search_post_request_model=POSTModel, items_get_request_model=items_get_request_model, response_class=ORJSONResponse, - middlewares=[Middleware(CompressionMiddleware), Middleware(ValidationMiddleware)], + middlewares=[ + Middleware(CompressionMiddleware), + Middleware(ValidationMiddleware), + ], router=APIRouter(route_class=LoggerRouteHandler), ) app = api.app +# Add tenant-specific routes +tenant_router = APIRouter(redirect_slashes=True) + +@tenant_router.get("/{tenant}/collections") +async def get_tenant_collections( + tenant: str = Path(..., description="Tenant identifier"), + request: FastAPIRequest = None, +): + """Get collections for a specific tenant.""" + logger.info(f"Getting collections for tenant: {tenant}") + collections = await api.client.all_collections(request, tenant=tenant) + + return collections + + +@tenant_router.get("/{tenant}/collections/{collection_id}") +async def get_tenant_collection( + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + request: FastAPIRequest = None, +): + """Get a specific collection for a tenant.""" + logger.info(f"Getting collection {collection_id} for tenant: {tenant}") + collection = await api.client.get_collection(collection_id, request, tenant=tenant) + + return collection + + +@tenant_router.get("/{tenant}/collections/{collection_id}/items") +async def get_tenant_collection_items( + request: FastAPIRequest, # It's good practice to have request as the first arg + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + limit: int = 10, # Add limit + token: Optional[str] = None, # Add token +): + """Get items from a collection for a specific tenant.""" + logger.info(f"Getting items from collection {collection_id} for tenant: {tenant}") + + # Pass the captured parameters to the client method + items = await api.client.item_collection( + collection_id=collection_id, + request=request, + tenant=tenant, + limit=limit, + token=token + ) + + return items + + +@tenant_router.get("/{tenant}/collections/{collection_id}/items/{item_id}") +async def get_tenant_item( + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + item_id: str = Path(..., description="Item identifier"), + request: FastAPIRequest = None, +): + """Get a specific item for a tenant.""" + logger.info(f"======> Getting item {item_id} from collection {collection_id} for tenant: {tenant} <=====") + return await api.client.get_item(item_id, collection_id, request, tenant=tenant) + + +@tenant_router.get("/{tenant}/search") +async def get_tenant_search( + tenant: str = Path(..., description="Tenant identifier"), + request: FastAPIRequest = None, +): + """Search items for a specific tenant using GET.""" + logger.info(f"GET search for tenant: {tenant}") + return await api.client.get_search(request, tenant=tenant) + + +@tenant_router.get("/{tenant}/search") +async def get_tenant_search( + request: FastAPIRequest, # Request should come first + tenant: str = Path(..., description="Tenant identifier"), + # Add ALL possible GET search parameters here that stac-fastapi uses + collections: Optional[str] = None, + ids: Optional[str] = None, + bbox: Optional[str] = None, + datetime: Optional[str] = None, + limit: int = 10, + query: Optional[str] = None, + token: Optional[str] = None, + filter_lang: Optional[str] = None, + filter: Optional[str] = None, + sortby: Optional[str] = None, + # **kwargs: Any # Avoid using this if possible, be explicit +): + """Search items for a specific tenant using GET.""" + logger.info(f"GET search for tenant: {tenant}") + + # The base `get_search` method in stac-fastapi unpacks the request itself. + # What's important is that our `TenantAwareVedaCrudClient.get_search` can access these params. + # The default behavior of stac-fastapi `get_search` is to parse these from the request query params. + # Our modification in the client (step 1) will handle the tenant injection. + + # We create a dictionary of the GET parameters to pass them explicitly + # to avoid ambiguity. + params = { + "collections": collections.split(",") if collections else None, + "ids": ids.split(",") if ids else None, + "bbox": [float(x) for x in bbox.split(",")] if bbox else None, + "datetime": datetime, + "limit": limit, + "query": json.loads(query) if query else None, + "token": token, + "filter-lang": filter_lang, + "filter": json.loads(filter) if filter else None, + "sortby": sortby, + } + # Filter out None values + clean_params = {k: v for k, v in params.items() if v is not None} + + search_result = await api.client.get_search(request, tenant=tenant, **clean_params) + + return search_result + + +@tenant_router.get("/{tenant}/") +async def get_tenant_landing_page( + tenant: str = Path(..., description="Tenant identifier"), + request: FastAPIRequest = None, +): + """Get landing page for a specific tenant.""" + logger.info(f"Getting landing page for tenant: {tenant}") + + # Get the base landing page by calling the method on the CLIENT, not the API object + # Corrected line: + base_landing = await api.client.landing_page(request=request) + + # The rest of your logic for modifying the links is correct + if isinstance(base_landing, ORJSONResponse): + # The client returns a response object, so we need to decode its content + body = base_landing.body + tenant_landing = json.loads(body) + + # Update title to include tenant + if 'title' in tenant_landing: + tenant_landing['title'] = f"{tenant.upper()} - {tenant_landing['title']}" + + # Return a new JSONResponse with the modified content + return ORJSONResponse(tenant_landing) + + # Fallback in case the response is not what we expect + return base_landing + +# Include the tenant router +app.include_router(tenant_router, tags=["Tenant-specific endpoints"]) + +# Add tenant-only enforcement middleware (set to False if you want to keep original routes) +# app.add_middleware(TenantOnlyMiddleware, enforce_tenant_only=True) + # Set all CORS enabled origins if api_settings.cors_origins: app.add_middleware( @@ -138,6 +443,20 @@ async def viewer_page(request: Request): ) +@app.get("/{tenant}/index.html", response_class=HTMLResponse) +async def tenant_viewer_page(request: Request, tenant: str): + """Tenant-specific search viewer.""" + return templates.TemplateResponse( + "stac-viewer.html", + { + "request": request, + "endpoint": str(request.url).replace("/index.html", f"/{tenant}"), + "tenant": tenant + }, + media_type="text/html", + ) + + # If the correlation header is used in the UI, we can analyze traces that originate from a given user or client @app.middleware("http") async def add_correlation_id(request: Request, call_next): diff --git a/stac_api/runtime/src/new_app.py b/stac_api/runtime/src/new_app.py new file mode 100644 index 00000000..cf8215d7 --- /dev/null +++ b/stac_api/runtime/src/new_app.py @@ -0,0 +1,600 @@ +"""FastAPI application using PGStac with integrated tenant filtering. +Based on https://github.com/developmentseed/eoAPI/tree/master/src/eoapi/stac +""" + +import json +from contextlib import asynccontextmanager +from typing import Dict, Any, Optional + +from aws_lambda_powertools.metrics import MetricUnit +from src.config import TilesApiSettings, api_settings +from src.config import extensions as PgStacExtensions +from src.config import get_request_model as GETModel +from src.config import items_get_request_model +from src.config import post_request_model as POSTModel +from src.extension import TiTilerExtension + +from fastapi import APIRouter, FastAPI, Request as FastAPIRequest, Depends, Path, HTTPException +from fastapi.responses import ORJSONResponse +from stac_fastapi.pgstac.db import close_db_connection, connect_to_db +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import HTMLResponse, JSONResponse, Response +from starlette.templating import Jinja2Templates +from starlette.types import ASGIApp +from starlette_cramjam.middleware import CompressionMiddleware + +from .api import VedaStacApi +from .core import VedaCrudClient +from .monitoring import LoggerRouteHandler, logger, metrics, tracer +from .validation import ValidationMiddleware +import os +from eoapi.auth_utils import OpenIdConnectAuth, OpenIdConnectSettings + +try: + from importlib.resources import files as resources_files # type: ignore +except ImportError: + # Try backported to PY<39 `importlib_resources`. + from importlib_resources import files as resources_files # type: ignore + + +templates = Jinja2Templates(directory=str(resources_files(__package__) / "templates")) # type: ignore + +def update_links_with_tenant(data: dict, tenant: str) -> dict: + """Update all links in a response to include tenant prefix.""" + base_url = os.getenv('BASE_URL', "http://localhost:8081") + def update_link(link: dict): + if 'href' in link: + href = link['href'] + if href.startswith(f'{base_url}/'): + # Replace base URL with tenant-prefixed URL + link['href'] = href.replace(f'{base_url}/', f'{base_url}/{tenant}/') + elif href.startswith('/') and not href.startswith(f'/{tenant}/'): + # Handle relative URLs + link['href'] = f'/{tenant}{href}' + + # Update main response links + if 'links' in data and isinstance(data['links'], list): + for link in data['links']: + update_link(link) + + # Update collection links + if 'collections' in data and isinstance(data['collections'], list): + for collection in data['collections']: + if 'links' in collection and isinstance(collection['links'], list): + for link in collection['links']: + update_link(link) + + # Update item/feature links + if 'features' in data and isinstance(data['features'], list): + for item in data['features']: + if 'links' in item and isinstance(item['links'], list): + for link in item['links']: + update_link(link) + + return data + + +# class TenantOnlyMiddleware(BaseHTTPMiddleware): +# """Middleware to enforce tenant-only access to STAC endpoints.""" + +# def __init__(self, app: ASGIApp, enforce_tenant_only: bool = True): +# super().__init__(app) +# self.enforce_tenant_only = enforce_tenant_only +# # Endpoints that should be tenant-only +# self.tenant_only_endpoints = { +# '/collections', '/search', '/conformance' +# } +# # Endpoints that are allowed without tenant prefix +# self.allowed_endpoints = { +# '/docs', '/openapi.json', '/index.html', '/health', '/' +# } + +# async def dispatch(self, request: Request, call_next): +# """Check if tenant-only enforcement should be applied.""" + +# if not self.enforce_tenant_only: +# return await call_next(request) + +# path = request.url.path + +# # Allow certain endpoints without tenant +# if any(path.startswith(endpoint) for endpoint in self.allowed_endpoints): +# return await call_next(request) + +# # Check if this is a tenant-only endpoint accessed without tenant +# if any(path.startswith(endpoint) for endpoint in self.tenant_only_endpoints): +# return JSONResponse( +# status_code=400, +# content={ +# "detail": f"This endpoint requires a tenant prefix. Use /{'{tenant}'}{path} instead." +# } +# ) + +# # Allow all other requests +# return await call_next(request) + + +tiles_settings = TilesApiSettings() +auth_settings = OpenIdConnectSettings(_env_prefix="VEDA_STAC_") + + +class TenantAwareVedaCrudClient(VedaCrudClient): + """Extended CRUD client that applies tenant filtering.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def all_collections(self, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): + """Get all collections with optional tenant filtering.""" + if tenant: + # Add tenant filter to the database query + # This assumes your collections table has a tenant column + # Adjust based on your actual database schema + logger.info(f"Filtering collections by tenant: {tenant}") + + # Call the parent method + collections = await super().all_collections(request, **kwargs) + + # If tenant is specified, filter the results + if tenant and hasattr(collections, 'collections'): + filtered_collections = [ + col for col in collections.collections + if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant + ] + collections.collections = filtered_collections + if hasattr(collections, 'context') and hasattr(collections.context, 'returned'): + collections.context.returned = len(filtered_collections) + elif tenant and isinstance(collections, dict) and 'collections' in collections: + filtered_collections = [ + col for col in collections['collections'] + if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant + ] + collections['collections'] = filtered_collections + if 'numberReturned' in collections: + collections['numberReturned'] = len(filtered_collections) + + return collections + + async def _validate_tenant_access(self, collection: dict, tenant: str, collection_id: str = ""): + """Raise HTTP 404 if the collection does not belong to the given tenant.""" + collection_tenant = collection.get("tenant") or collection.get("properties", {}).get("tenant") + if collection_tenant != tenant: + detail = f"Collection {collection_id} not found for tenant {tenant}" if collection_id else "Collection not found" + raise HTTPException(status_code=404, detail=detail) + + async def get_collection(self, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): + """Get collection with tenant filtering.""" + collection = await super().get_collection(collection_id, request, **kwargs) + + if tenant and collection: + await self._validate_tenant_access(collection, tenant, collection_id) + + return collection + + async def item_collection( + self, + collection_id: str, + request: FastAPIRequest, + tenant: Optional[str] = None, + limit: int = 10, # Add limit + token: Optional[str] = None, # Add token + **kwargs, + ): + """Get items with tenant filtering.""" + if tenant: + logger.info(f"Filtering items by tenant: {tenant} with token: {token}") + + # Your existing tenant validation logic is good + collection = await super().get_collection(collection_id, request, **kwargs) + if not collection: + raise HTTPException( + status_code=404, + detail=f"Collection {collection_id} not found for tenant {tenant}", + ) + await self._validate_tenant_access(collection, tenant, collection_id) + + # Pass the pagination parameters to the parent method + return await super().item_collection( + collection_id=collection_id, + request=request, + limit=limit, + token=token, + **kwargs + ) + async def get_item(self, item_id: str, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): + """Get item with tenant filtering.""" + if tenant: + logger.info(f"Filtering item {item_id} in collection {collection_id} by tenant: {tenant}") + + # Fetch and validate the collection belongs to the tenant + collection = await super().get_collection(collection_id, request, **kwargs) + if not collection: + raise HTTPException( + status_code=404, + detail=f"Collection {collection_id} not found for tenant {tenant}" + ) + await self._validate_tenant_access(collection, tenant, collection_id) + + return await super().get_item(item_id, collection_id, request, **kwargs) + async def post_search( + self, + search_request: POSTModel, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs + ): + """Search with tenant filtering.""" + if tenant: + logger.info(f"Filtering search by tenant: {tenant}") + # IMPORTANT: You must actually filter the search by the tenant. + # This assumes you have a 'tenant' property in your collection's 'properties'. + # pgstac search function can take a filter + tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} + + if search_request.filter: + # If a filter already exists, combine with an 'and' + search_request.filter = { + "op": "and", + "args": [ + search_request.filter, + tenant_filter, + ], + } + else: + search_request.filter = tenant_filter + search_request.filter_lang = "cql2-json" + + return await super().post_search(search_request, request, **kwargs) + + async def get_search( + self, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs, + ): + """GET search with tenant filtering.""" + if tenant: + logger.info(f"Filtering GET search by tenant: {tenant}") + # IMPORTANT: You must also modify the GET search to filter by tenant. + # This requires modifying the kwargs that will be used to build the search request. + tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} + + if "filter" in kwargs and kwargs["filter"]: + # Combine with existing filter + kwargs["filter"] = { + "op": "and", + "args": [ + kwargs["filter"], + tenant_filter, + ], + } + else: + kwargs["filter"] = tenant_filter + kwargs["filter-lang"] = "cql2-json" + + # The CoreCrudClient.get_search will use the modified kwargs + return await super().get_search(request, **kwargs) + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Get a database connection on startup, close it on shutdown.""" + await connect_to_db(app) + yield + await close_db_connection(app) + + +# Create the base STAC API +api = VedaStacApi( + app=FastAPI( + title=f"{api_settings.project_name} STAC API", + openapi_url="/openapi.json", + docs_url="/docs", + root_path=api_settings.root_path, + swagger_ui_init_oauth=( + { + "appName": "STAC API", + "clientId": auth_settings.client_id, + "usePkceWithAuthorizationCodeGrant": True, + "scopes": "openid stac:item:create stac:item:update stac:item:delete stac:collection:create stac:collection:update stac:collection:delete", + } + if auth_settings.client_id + else {} + ), + lifespan=lifespan, + ), + title=f"{api_settings.project_name} STAC API", + description=api_settings.project_description, + settings=api_settings.load_postgres_settings(), + extensions=PgStacExtensions, + client=TenantAwareVedaCrudClient(post_request_model=POSTModel), + search_get_request_model=GETModel, + search_post_request_model=POSTModel, + items_get_request_model=items_get_request_model, + response_class=ORJSONResponse, + middlewares=[ + Middleware(CompressionMiddleware), + Middleware(ValidationMiddleware), + ], + router=APIRouter(route_class=LoggerRouteHandler), +) +app = api.app + +# Add tenant-specific routes +tenant_router = APIRouter(redirect_slashes=True) + +@tenant_router.get("/{tenant}/collections") +async def get_tenant_collections( + tenant: str = Path(..., description="Tenant identifier"), + request: FastAPIRequest = None, +): + """Get collections for a specific tenant.""" + logger.info(f"Getting collections for tenant: {tenant}") + collections = await api.client.all_collections(request, tenant=tenant) + + # Update links to include tenant prefix + if collections and isinstance(collections, dict): + collections = update_links_with_tenant(collections, tenant) + + return collections + + +@tenant_router.get("/{tenant}/collections/{collection_id}") +async def get_tenant_collection( + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + request: FastAPIRequest = None, +): + """Get a specific collection for a tenant.""" + logger.info(f"Getting collection {collection_id} for tenant: {tenant}") + collection = await api.client.get_collection(collection_id, request, tenant=tenant) + + # Update links to include tenant prefix + if collection and isinstance(collection, dict): + collection = update_links_with_tenant(collection, tenant) + + return collection + + +@tenant_router.get("/{tenant}/collections/{collection_id}/items") +async def get_tenant_collection_items( + request: FastAPIRequest, # It's good practice to have request as the first arg + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + limit: int = 10, # Add limit + token: Optional[str] = None, # Add token +): + """Get items from a collection for a specific tenant.""" + logger.info(f"Getting items from collection {collection_id} for tenant: {tenant}") + + # Pass the captured parameters to the client method + items = await api.client.item_collection( + collection_id=collection_id, + request=request, + tenant=tenant, + limit=limit, + token=token + ) + + # Your link updater will correctly rewrite the new `next` link if one is present + if items and isinstance(items, dict): + items = update_links_with_tenant(items, tenant) + + return items + + +@tenant_router.get("/{tenant}/collections/{collection_id}/items/{item_id}") +async def get_tenant_item( + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + item_id: str = Path(..., description="Item identifier"), + request: FastAPIRequest = None, +): + """Get a specific item for a tenant.""" + logger.info(f"======> Getting item {item_id} from collection {collection_id} for tenant: {tenant} <=====") + return await api.client.get_item(item_id, collection_id, request, tenant=tenant) + + +@tenant_router.get("/{tenant}/search") +async def get_tenant_search( + tenant: str = Path(..., description="Tenant identifier"), + request: FastAPIRequest = None, +): + """Search items for a specific tenant using GET.""" + logger.info(f"GET search for tenant: {tenant}") + return await api.client.get_search(request, tenant=tenant) + + +@tenant_router.get("/{tenant}/search") +async def get_tenant_search( + request: FastAPIRequest, # Request should come first + tenant: str = Path(..., description="Tenant identifier"), + # Add ALL possible GET search parameters here that stac-fastapi uses + collections: Optional[str] = None, + ids: Optional[str] = None, + bbox: Optional[str] = None, + datetime: Optional[str] = None, + limit: int = 10, + query: Optional[str] = None, + token: Optional[str] = None, + filter_lang: Optional[str] = None, + filter: Optional[str] = None, + sortby: Optional[str] = None, + # **kwargs: Any # Avoid using this if possible, be explicit +): + """Search items for a specific tenant using GET.""" + logger.info(f"GET search for tenant: {tenant}") + + # The base `get_search` method in stac-fastapi unpacks the request itself. + # What's important is that our `TenantAwareVedaCrudClient.get_search` can access these params. + # The default behavior of stac-fastapi `get_search` is to parse these from the request query params. + # Our modification in the client (step 1) will handle the tenant injection. + + # We create a dictionary of the GET parameters to pass them explicitly + # to avoid ambiguity. + params = { + "collections": collections.split(",") if collections else None, + "ids": ids.split(",") if ids else None, + "bbox": [float(x) for x in bbox.split(",")] if bbox else None, + "datetime": datetime, + "limit": limit, + "query": json.loads(query) if query else None, + "token": token, + "filter-lang": filter_lang, + "filter": json.loads(filter) if filter else None, + "sortby": sortby, + } + # Filter out None values + clean_params = {k: v for k, v in params.items() if v is not None} + + search_result = await api.client.get_search(request, tenant=tenant, **clean_params) + + # Update links + if search_result and isinstance(search_result, dict): + search_result = update_links_with_tenant(search_result, tenant) + + return search_result + + +@tenant_router.get("/{tenant}/") +async def get_tenant_landing_page( + tenant: str = Path(..., description="Tenant identifier"), + request: FastAPIRequest = None, +): + """Get landing page for a specific tenant.""" + logger.info(f"Getting landing page for tenant: {tenant}") + + # Get the base landing page by calling the method on the CLIENT, not the API object + # Corrected line: + base_landing = await api.client.landing_page(request=request) + + # The rest of your logic for modifying the links is correct + if isinstance(base_landing, ORJSONResponse): + # The client returns a response object, so we need to decode its content + body = base_landing.body + tenant_landing = json.loads(body) + + # Update links to include tenant prefix + if 'links' in tenant_landing: + # Using your update_links_with_tenant function is more robust + tenant_landing = update_links_with_tenant(tenant_landing, tenant) + + # Update title to include tenant + if 'title' in tenant_landing: + tenant_landing['title'] = f"{tenant.upper()} - {tenant_landing['title']}" + + # Return a new JSONResponse with the modified content + return ORJSONResponse(tenant_landing) + + # Fallback in case the response is not what we expect + return base_landing + +# Include the tenant router +app.include_router(tenant_router, tags=["Tenant-specific endpoints"]) + +# Add tenant-only enforcement middleware (set to False if you want to keep original routes) +# app.add_middleware(TenantOnlyMiddleware, enforce_tenant_only=True) + +# Set all CORS enabled origins +if api_settings.cors_origins: + app.add_middleware( + CORSMiddleware, + allow_origins=api_settings.cors_origins, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "OPTIONS"], + allow_headers=["*"], + ) + +if api_settings.enable_transactions and auth_settings.client_id: + oidc_auth = OpenIdConnectAuth( + openid_configuration_url=auth_settings.openid_configuration_url, + allowed_jwt_audiences="account", + ) + + restricted_prefixes_methods = { + "/collections": [("POST", "stac:collection:create")], + "/collections/{collection_id}": [ + ("PUT", "stac:collection:update"), + ("DELETE", "stac:collection:delete"), + ], + "/collections/{collection_id}/items": [("POST", "stac:item:create")], + "/collections/{collection_id}/items/{item_id}": [ + ("PUT", "stac:item:update"), + ("DELETE", "stac:item:delete"), + ], + "/collections/{collection_id}/bulk_items": [("POST", "stac:item:create")], + } + + for route in app.router.routes: + method_scopes = restricted_prefixes_methods.get(route.path) + if not method_scopes: + continue + for method, scope in method_scopes: + if method not in route.methods: + continue + oidc_auth.apply_auth_dependencies(route, required_token_scopes=[scope]) + +if tiles_settings.titiler_endpoint: + # Register to the TiTiler extension to the api + extension = TiTilerExtension() + extension.register(api.app, tiles_settings.titiler_endpoint) + + +@app.get("/index.html", response_class=HTMLResponse) +async def viewer_page(request: Request): + """Search viewer.""" + path = api_settings.root_path or "" + return templates.TemplateResponse( + "stac-viewer.html", + {"request": request, "endpoint": str(request.url).replace("/index.html", path)}, + media_type="text/html", + ) + + +@app.get("/{tenant}/index.html", response_class=HTMLResponse) +async def tenant_viewer_page(request: Request, tenant: str): + """Tenant-specific search viewer.""" + return templates.TemplateResponse( + "stac-viewer.html", + { + "request": request, + "endpoint": str(request.url).replace("/index.html", f"/{tenant}"), + "tenant": tenant + }, + media_type="text/html", + ) + + +# If the correlation header is used in the UI, we can analyze traces that originate from a given user or client +@app.middleware("http") +async def add_correlation_id(request: Request, call_next): + """Add correlation ids to all requests and subsequent logs/traces""" + # Get correlation id from X-Correlation-Id header + corr_id = request.headers.get("x-correlation-id") + if not corr_id: + try: + # If empty, use request id from aws context + corr_id = request.scope["aws.context"].aws_request_id + except KeyError: + # If empty, use uuid + corr_id = "local" + # Add correlation id to logs + logger.set_correlation_id(corr_id) + # Add correlation id to traces + tracer.put_annotation(key="correlation_id", value=corr_id) + + response = await tracer.capture_method(call_next)(request) + # Return correlation header in response + response.headers["X-Correlation-Id"] = corr_id + logger.info("Request completed") + return response + + +@app.exception_handler(Exception) +async def validation_exception_handler(request, err): + """Handle exceptions that aren't caught elsewhere""" + metrics.add_metric(name="UnhandledExceptions", unit=MetricUnit.Count, value=1) + logger.error("Unhandled exception") + return JSONResponse(status_code=500, content={"detail": "Internal Server Error"}) diff --git a/stac_api/runtime/src/old_working_app.py b/stac_api/runtime/src/old_working_app.py new file mode 100644 index 00000000..1a9f18d0 --- /dev/null +++ b/stac_api/runtime/src/old_working_app.py @@ -0,0 +1,565 @@ +"""FastAPI application using PGStac with integrated tenant filtering. +Based on https://github.com/developmentseed/eoAPI/tree/master/src/eoapi/stac +""" + +import json +from contextlib import asynccontextmanager +from typing import Dict, Any, Optional + +from aws_lambda_powertools.metrics import MetricUnit +from src.config import TilesApiSettings, api_settings +from src.config import extensions as PgStacExtensions +from src.config import get_request_model as GETModel +from src.config import items_get_request_model +from src.config import post_request_model as POSTModel +from src.extension import TiTilerExtension + +from fastapi import APIRouter, FastAPI, Request as FastAPIRequest, Depends, Path, HTTPException +from fastapi.responses import ORJSONResponse +from stac_fastapi.pgstac.db import close_db_connection, connect_to_db +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import HTMLResponse, JSONResponse, Response +from starlette.templating import Jinja2Templates +from starlette.types import ASGIApp +from starlette_cramjam.middleware import CompressionMiddleware + +from .api import VedaStacApi +from .core import VedaCrudClient +from .monitoring import LoggerRouteHandler, logger, metrics, tracer +from .validation import ValidationMiddleware +import os +from eoapi.auth_utils import OpenIdConnectAuth, OpenIdConnectSettings + +try: + from importlib.resources import files as resources_files # type: ignore +except ImportError: + # Try backported to PY<39 `importlib_resources`. + from importlib_resources import files as resources_files # type: ignore + + +templates = Jinja2Templates(directory=str(resources_files(__package__) / "templates")) # type: ignore + +def update_links_with_tenant(data: dict, tenant: str) -> dict: + """Update all links in a response to include tenant prefix.""" + base_url = os.getenv('BASE_URL', "http://localhost:8081") + def update_link(link: dict): + if 'href' in link and not "token=next:" in link['href']: + href = link['href'] + if href.startswith(f'{base_url}/'): + # Replace base URL with tenant-prefixed URL + link['href'] = href.replace(f'{base_url}/', f'{base_url}/{tenant}/') + elif href.startswith('/') and not href.startswith(f'/{tenant}/'): + # Handle relative URLs + link['href'] = f'/{tenant}{href}' + + # Update main response links + if 'links' in data and isinstance(data['links'], list): + for link in data['links']: + update_link(link) + + # Update collection links + if 'collections' in data and isinstance(data['collections'], list): + for collection in data['collections']: + if 'links' in collection and isinstance(collection['links'], list): + for link in collection['links']: + update_link(link) + + # Update item/feature links + if 'features' in data and isinstance(data['features'], list): + for item in data['features']: + if 'links' in item and isinstance(item['links'], list): + for link in item['links']: + update_link(link) + + return data + + + +tiles_settings = TilesApiSettings() +auth_settings = OpenIdConnectSettings(_env_prefix="VEDA_STAC_") + + +class TenantAwareVedaCrudClient(VedaCrudClient): + """Extended CRUD client that applies tenant filtering.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def all_collections(self, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): + """Get all collections with optional tenant filtering.""" + if tenant: + # Add tenant filter to the database query + # This assumes your collections table has a tenant column + # Adjust based on your actual database schema + logger.info(f"Filtering collections by tenant: {tenant}") + + # Call the parent method + collections = await super().all_collections(request, **kwargs) + + # If tenant is specified, filter the results + if tenant and hasattr(collections, 'collections'): + filtered_collections = [ + col for col in collections.collections + if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant + ] + collections.collections = filtered_collections + if hasattr(collections, 'context') and hasattr(collections.context, 'returned'): + collections.context.returned = len(filtered_collections) + elif tenant and isinstance(collections, dict) and 'collections' in collections: + filtered_collections = [ + col for col in collections['collections'] + if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant + ] + collections['collections'] = filtered_collections + if 'numberReturned' in collections: + collections['numberReturned'] = len(filtered_collections) + + return collections + + async def _validate_tenant_access(self, collection: dict, tenant: str, collection_id: str = ""): + """Raise HTTP 404 if the collection does not belong to the given tenant.""" + collection_tenant = collection.get("tenant") or collection.get("properties", {}).get("tenant") + if collection_tenant != tenant: + detail = f"Collection {collection_id} not found for tenant {tenant}" if collection_id else "Collection not found" + raise HTTPException(status_code=404, detail=detail) + + async def get_collection(self, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): + """Get collection with tenant filtering.""" + collection = await super().get_collection(collection_id, request, **kwargs) + + if tenant and collection: + await self._validate_tenant_access(collection, tenant, collection_id) + + return collection + + async def item_collection( + self, + collection_id: str, + request: FastAPIRequest, + tenant: Optional[str] = None, + limit: int = 10, # Add limit + token: Optional[str] = None, # Add token + **kwargs, + ): + """Get items with tenant filtering.""" + if tenant: + logger.info(f"Filtering items by tenant: {tenant} with token: {token}") + + # Your existing tenant validation logic is good + collection = await super().get_collection(collection_id, request, **kwargs) + if not collection: + raise HTTPException( + status_code=404, + detail=f"Collection {collection_id} not found for tenant {tenant}", + ) + await self._validate_tenant_access(collection, tenant, collection_id) + + # Pass the pagination parameters to the parent method + return await super().item_collection( + collection_id=collection_id, + request=request, + limit=limit, + token=token, + **kwargs + ) + async def get_item(self, item_id: str, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): + """Get item with tenant filtering.""" + if tenant: + logger.info(f"Filtering item {item_id} in collection {collection_id} by tenant: {tenant}") + + # Fetch and validate the collection belongs to the tenant + collection = await super().get_collection(collection_id, request, **kwargs) + if not collection: + raise HTTPException( + status_code=404, + detail=f"Collection {collection_id} not found for tenant {tenant}" + ) + await self._validate_tenant_access(collection, tenant, collection_id) + + return await super().get_item(item_id, collection_id, request, **kwargs) + async def post_search( + self, + search_request: POSTModel, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs + ): + """Search with tenant filtering.""" + if tenant: + logger.info(f"Filtering search by tenant: {tenant}") + # IMPORTANT: You must actually filter the search by the tenant. + # This assumes you have a 'tenant' property in your collection's 'properties'. + # pgstac search function can take a filter + tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} + + if search_request.filter: + # If a filter already exists, combine with an 'and' + search_request.filter = { + "op": "and", + "args": [ + search_request.filter, + tenant_filter, + ], + } + else: + search_request.filter = tenant_filter + search_request.filter_lang = "cql2-json" + + return await super().post_search(search_request, request, **kwargs) + + async def get_search( + self, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs, + ): + """GET search with tenant filtering.""" + if tenant: + logger.info(f"Filtering GET search by tenant: {tenant}") + # IMPORTANT: You must also modify the GET search to filter by tenant. + # This requires modifying the kwargs that will be used to build the search request. + tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} + + if "filter" in kwargs and kwargs["filter"]: + # Combine with existing filter + kwargs["filter"] = { + "op": "and", + "args": [ + kwargs["filter"], + tenant_filter, + ], + } + else: + kwargs["filter"] = tenant_filter + kwargs["filter-lang"] = "cql2-json" + + # The CoreCrudClient.get_search will use the modified kwargs + return await super().get_search(request, **kwargs) + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Get a database connection on startup, close it on shutdown.""" + await connect_to_db(app) + yield + await close_db_connection(app) + + +# Create the base STAC API +api = VedaStacApi( + app=FastAPI( + title=f"{api_settings.project_name} STAC API", + openapi_url="/openapi.json", + docs_url="/docs", + root_path=api_settings.root_path, + swagger_ui_init_oauth=( + { + "appName": "STAC API", + "clientId": auth_settings.client_id, + "usePkceWithAuthorizationCodeGrant": True, + "scopes": "openid stac:item:create stac:item:update stac:item:delete stac:collection:create stac:collection:update stac:collection:delete", + } + if auth_settings.client_id + else {} + ), + lifespan=lifespan, + ), + title=f"{api_settings.project_name} STAC API", + description=api_settings.project_description, + settings=api_settings.load_postgres_settings(), + extensions=PgStacExtensions, + client=TenantAwareVedaCrudClient(post_request_model=POSTModel), + search_get_request_model=GETModel, + search_post_request_model=POSTModel, + items_get_request_model=items_get_request_model, + response_class=ORJSONResponse, + middlewares=[ + Middleware(CompressionMiddleware), + Middleware(ValidationMiddleware), + ], + router=APIRouter(route_class=LoggerRouteHandler), +) +app = api.app + +# Add tenant-specific routes +tenant_router = APIRouter(redirect_slashes=True) + +@tenant_router.get("/{tenant}/collections") +async def get_tenant_collections( + tenant: str = Path(..., description="Tenant identifier"), + request: FastAPIRequest = None, +): + """Get collections for a specific tenant.""" + logger.info(f"Getting collections for tenant: {tenant}") + collections = await api.client.all_collections(request, tenant=tenant) + + # Update links to include tenant prefix + if collections and isinstance(collections, dict): + collections = update_links_with_tenant(collections, tenant) + + return collections + + +@tenant_router.get("/{tenant}/collections/{collection_id}") +async def get_tenant_collection( + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + request: FastAPIRequest = None, +): + """Get a specific collection for a tenant.""" + logger.info(f"Getting collection {collection_id} for tenant: {tenant}") + collection = await api.client.get_collection(collection_id, request, tenant=tenant) + + # Update links to include tenant prefix + if collection and isinstance(collection, dict): + collection = update_links_with_tenant(collection, tenant) + + return collection + + +@tenant_router.get("/{tenant}/collections/{collection_id}/items") +async def get_tenant_collection_items( + request: FastAPIRequest, # It's good practice to have request as the first arg + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + limit: int = 10, # Add limit + token: Optional[str] = None, # Add token +): + """Get items from a collection for a specific tenant.""" + logger.info(f"Getting items from collection {collection_id} for tenant: {tenant}") + + # Pass the captured parameters to the client method + items = await api.client.item_collection( + collection_id=collection_id, + request=request, + tenant=tenant, + limit=limit, + token=token + ) + + # Your link updater will correctly rewrite the new `next` link if one is present + if items and isinstance(items, dict): + items = update_links_with_tenant(items, tenant) + + return items + + +@tenant_router.get("/{tenant}/collections/{collection_id}/items/{item_id}") +async def get_tenant_item( + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + item_id: str = Path(..., description="Item identifier"), + request: FastAPIRequest = None, +): + """Get a specific item for a tenant.""" + logger.info(f"======> Getting item {item_id} from collection {collection_id} for tenant: {tenant} <=====") + + item = await api.client.get_item(item_id, collection_id, request, tenant=tenant) + if item and isinstance(item, dict): + item = update_links_with_tenant(item, tenant) + + return item + +@tenant_router.get("/{tenant}/search") +async def get_tenant_search( + tenant: str = Path(..., description="Tenant identifier"), + request: FastAPIRequest = None, +): + """Search items for a specific tenant using GET.""" + logger.info(f"GET search for tenant: {tenant}") + return await api.client.get_search(request, tenant=tenant) + + +@tenant_router.get("/{tenant}/search") +async def get_tenant_search( + request: FastAPIRequest, # Request should come first + tenant: str = Path(..., description="Tenant identifier"), + # Add ALL possible GET search parameters here that stac-fastapi uses + collections: Optional[str] = None, + ids: Optional[str] = None, + bbox: Optional[str] = None, + datetime: Optional[str] = None, + limit: int = 10, + query: Optional[str] = None, + token: Optional[str] = None, + filter_lang: Optional[str] = None, + filter: Optional[str] = None, + sortby: Optional[str] = None, + # **kwargs: Any # Avoid using this if possible, be explicit +): + """Search items for a specific tenant using GET.""" + logger.info(f"GET search for tenant: {tenant}") + + # The base `get_search` method in stac-fastapi unpacks the request itself. + # What's important is that our `TenantAwareVedaCrudClient.get_search` can access these params. + # The default behavior of stac-fastapi `get_search` is to parse these from the request query params. + # Our modification in the client (step 1) will handle the tenant injection. + + # We create a dictionary of the GET parameters to pass them explicitly + # to avoid ambiguity. + params = { + "collections": collections.split(",") if collections else None, + "ids": ids.split(",") if ids else None, + "bbox": [float(x) for x in bbox.split(",")] if bbox else None, + "datetime": datetime, + "limit": limit, + "query": json.loads(query) if query else None, + "token": token, + "filter-lang": filter_lang, + "filter": json.loads(filter) if filter else None, + "sortby": sortby, + } + # Filter out None values + clean_params = {k: v for k, v in params.items() if v is not None} + + search_result = await api.client.get_search(request, tenant=tenant, **clean_params) + + # Update links + if search_result and isinstance(search_result, dict): + search_result = update_links_with_tenant(search_result, tenant) + + return search_result + +@tenant_router.get("/{tenant}/index.html", response_class=HTMLResponse) +async def tenant_viewer_page(request: Request, tenant: str): + """Tenant-specific search viewer.""" + return templates.TemplateResponse( + "stac-viewer.html", + { + "request": request, + "endpoint": str(request.url).replace("/index.html", f"/{tenant}"), + "tenant": tenant + }, + media_type="text/html", + ) +@tenant_router.get("/{tenant}/") +async def get_tenant_landing_page( + tenant: str = Path(..., description="Tenant identifier"), + request: FastAPIRequest = None, +): + """Get landing page for a specific tenant.""" + logger.info(f"Getting landing page for tenant: {tenant}") + + # Get the base landing page by calling the method on the CLIENT, not the API object + # Corrected line: + base_landing = await api.client.landing_page(request=request) + + # The rest of your logic for modifying the links is correct + if isinstance(base_landing, ORJSONResponse): + # The client returns a response object, so we need to decode its content + body = base_landing.body + tenant_landing = json.loads(body) + + # Update links to include tenant prefix + if 'links' in tenant_landing: + # Using your update_links_with_tenant function is more robust + tenant_landing = update_links_with_tenant(tenant_landing, tenant) + + # Update title to include tenant + if 'title' in tenant_landing: + tenant_landing['title'] = f"{tenant.upper()} - {tenant_landing['title']}" + + # Return a new JSONResponse with the modified content + return ORJSONResponse(tenant_landing) + + # Fallback in case the response is not what we expect + return base_landing + +# Include the tenant router +app.include_router(tenant_router, tags=["Tenant-specific endpoints"]) + +# Add tenant-only enforcement middleware (set to False if you want to keep original routes) +# app.add_middleware(TenantOnlyMiddleware, enforce_tenant_only=True) + +# Set all CORS enabled origins +if api_settings.cors_origins: + app.add_middleware( + CORSMiddleware, + allow_origins=api_settings.cors_origins, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "OPTIONS"], + allow_headers=["*"], + ) + +if api_settings.enable_transactions and auth_settings.client_id: + oidc_auth = OpenIdConnectAuth( + openid_configuration_url=auth_settings.openid_configuration_url, + allowed_jwt_audiences="account", + ) + + restricted_prefixes_methods = { + "/collections": [("POST", "stac:collection:create")], + "/collections/{collection_id}": [ + ("PUT", "stac:collection:update"), + ("DELETE", "stac:collection:delete"), + ], + "/collections/{collection_id}/items": [("POST", "stac:item:create")], + "/collections/{collection_id}/items/{item_id}": [ + ("PUT", "stac:item:update"), + ("DELETE", "stac:item:delete"), + ], + "/collections/{collection_id}/bulk_items": [("POST", "stac:item:create")], + } + + for route in app.router.routes: + method_scopes = restricted_prefixes_methods.get(route.path) + if not method_scopes: + continue + for method, scope in method_scopes: + if method not in route.methods: + continue + oidc_auth.apply_auth_dependencies(route, required_token_scopes=[scope]) + +if tiles_settings.titiler_endpoint: + # Register to the TiTiler extension to the api + extension = TiTilerExtension() + extension.register(api.app, tiles_settings.titiler_endpoint) + + +@app.get("/index.html", response_class=HTMLResponse) +async def viewer_page(request: Request): + """Search viewer.""" + path = api_settings.root_path or "" + return templates.TemplateResponse( + "stac-viewer.html", + {"request": request, "endpoint": str(request.url).replace("/index.html", path)}, + media_type="text/html", + ) + + + + + +# If the correlation header is used in the UI, we can analyze traces that originate from a given user or client +@app.middleware("http") +async def add_correlation_id(request: Request, call_next): + """Add correlation ids to all requests and subsequent logs/traces""" + # Get correlation id from X-Correlation-Id header + corr_id = request.headers.get("x-correlation-id") + if not corr_id: + try: + # If empty, use request id from aws context + corr_id = request.scope["aws.context"].aws_request_id + except KeyError: + # If empty, use uuid + corr_id = "local" + # Add correlation id to logs + logger.set_correlation_id(corr_id) + # Add correlation id to traces + tracer.put_annotation(key="correlation_id", value=corr_id) + + response = await tracer.capture_method(call_next)(request) + # Return correlation header in response + response.headers["X-Correlation-Id"] = corr_id + logger.info("Request completed") + return response + + +@app.exception_handler(Exception) +async def validation_exception_handler(request, err): + """Handle exceptions that aren't caught elsewhere""" + metrics.add_metric(name="UnhandledExceptions", unit=MetricUnit.Count, value=1) + logger.error("Unhandled exception") + return JSONResponse(status_code=500, content={"detail": "Internal Server Error"}) From 0b59cd862532becaa5e54809c6ba3860a6cef541 Mon Sep 17 00:00:00 2001 From: amarouane-ABDELHAK Date: Thu, 10 Jul 2025 16:02:29 -0500 Subject: [PATCH 02/33] Support STAC-API multi tenants PoC --- stac_api/runtime/src/new_app.py | 600 ------------------------ stac_api/runtime/src/old_working_app.py | 565 ---------------------- 2 files changed, 1165 deletions(-) delete mode 100644 stac_api/runtime/src/new_app.py delete mode 100644 stac_api/runtime/src/old_working_app.py diff --git a/stac_api/runtime/src/new_app.py b/stac_api/runtime/src/new_app.py deleted file mode 100644 index cf8215d7..00000000 --- a/stac_api/runtime/src/new_app.py +++ /dev/null @@ -1,600 +0,0 @@ -"""FastAPI application using PGStac with integrated tenant filtering. -Based on https://github.com/developmentseed/eoAPI/tree/master/src/eoapi/stac -""" - -import json -from contextlib import asynccontextmanager -from typing import Dict, Any, Optional - -from aws_lambda_powertools.metrics import MetricUnit -from src.config import TilesApiSettings, api_settings -from src.config import extensions as PgStacExtensions -from src.config import get_request_model as GETModel -from src.config import items_get_request_model -from src.config import post_request_model as POSTModel -from src.extension import TiTilerExtension - -from fastapi import APIRouter, FastAPI, Request as FastAPIRequest, Depends, Path, HTTPException -from fastapi.responses import ORJSONResponse -from stac_fastapi.pgstac.db import close_db_connection, connect_to_db -from starlette.middleware import Middleware -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.middleware.cors import CORSMiddleware -from starlette.requests import Request -from starlette.responses import HTMLResponse, JSONResponse, Response -from starlette.templating import Jinja2Templates -from starlette.types import ASGIApp -from starlette_cramjam.middleware import CompressionMiddleware - -from .api import VedaStacApi -from .core import VedaCrudClient -from .monitoring import LoggerRouteHandler, logger, metrics, tracer -from .validation import ValidationMiddleware -import os -from eoapi.auth_utils import OpenIdConnectAuth, OpenIdConnectSettings - -try: - from importlib.resources import files as resources_files # type: ignore -except ImportError: - # Try backported to PY<39 `importlib_resources`. - from importlib_resources import files as resources_files # type: ignore - - -templates = Jinja2Templates(directory=str(resources_files(__package__) / "templates")) # type: ignore - -def update_links_with_tenant(data: dict, tenant: str) -> dict: - """Update all links in a response to include tenant prefix.""" - base_url = os.getenv('BASE_URL', "http://localhost:8081") - def update_link(link: dict): - if 'href' in link: - href = link['href'] - if href.startswith(f'{base_url}/'): - # Replace base URL with tenant-prefixed URL - link['href'] = href.replace(f'{base_url}/', f'{base_url}/{tenant}/') - elif href.startswith('/') and not href.startswith(f'/{tenant}/'): - # Handle relative URLs - link['href'] = f'/{tenant}{href}' - - # Update main response links - if 'links' in data and isinstance(data['links'], list): - for link in data['links']: - update_link(link) - - # Update collection links - if 'collections' in data and isinstance(data['collections'], list): - for collection in data['collections']: - if 'links' in collection and isinstance(collection['links'], list): - for link in collection['links']: - update_link(link) - - # Update item/feature links - if 'features' in data and isinstance(data['features'], list): - for item in data['features']: - if 'links' in item and isinstance(item['links'], list): - for link in item['links']: - update_link(link) - - return data - - -# class TenantOnlyMiddleware(BaseHTTPMiddleware): -# """Middleware to enforce tenant-only access to STAC endpoints.""" - -# def __init__(self, app: ASGIApp, enforce_tenant_only: bool = True): -# super().__init__(app) -# self.enforce_tenant_only = enforce_tenant_only -# # Endpoints that should be tenant-only -# self.tenant_only_endpoints = { -# '/collections', '/search', '/conformance' -# } -# # Endpoints that are allowed without tenant prefix -# self.allowed_endpoints = { -# '/docs', '/openapi.json', '/index.html', '/health', '/' -# } - -# async def dispatch(self, request: Request, call_next): -# """Check if tenant-only enforcement should be applied.""" - -# if not self.enforce_tenant_only: -# return await call_next(request) - -# path = request.url.path - -# # Allow certain endpoints without tenant -# if any(path.startswith(endpoint) for endpoint in self.allowed_endpoints): -# return await call_next(request) - -# # Check if this is a tenant-only endpoint accessed without tenant -# if any(path.startswith(endpoint) for endpoint in self.tenant_only_endpoints): -# return JSONResponse( -# status_code=400, -# content={ -# "detail": f"This endpoint requires a tenant prefix. Use /{'{tenant}'}{path} instead." -# } -# ) - -# # Allow all other requests -# return await call_next(request) - - -tiles_settings = TilesApiSettings() -auth_settings = OpenIdConnectSettings(_env_prefix="VEDA_STAC_") - - -class TenantAwareVedaCrudClient(VedaCrudClient): - """Extended CRUD client that applies tenant filtering.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - async def all_collections(self, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): - """Get all collections with optional tenant filtering.""" - if tenant: - # Add tenant filter to the database query - # This assumes your collections table has a tenant column - # Adjust based on your actual database schema - logger.info(f"Filtering collections by tenant: {tenant}") - - # Call the parent method - collections = await super().all_collections(request, **kwargs) - - # If tenant is specified, filter the results - if tenant and hasattr(collections, 'collections'): - filtered_collections = [ - col for col in collections.collections - if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant - ] - collections.collections = filtered_collections - if hasattr(collections, 'context') and hasattr(collections.context, 'returned'): - collections.context.returned = len(filtered_collections) - elif tenant and isinstance(collections, dict) and 'collections' in collections: - filtered_collections = [ - col for col in collections['collections'] - if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant - ] - collections['collections'] = filtered_collections - if 'numberReturned' in collections: - collections['numberReturned'] = len(filtered_collections) - - return collections - - async def _validate_tenant_access(self, collection: dict, tenant: str, collection_id: str = ""): - """Raise HTTP 404 if the collection does not belong to the given tenant.""" - collection_tenant = collection.get("tenant") or collection.get("properties", {}).get("tenant") - if collection_tenant != tenant: - detail = f"Collection {collection_id} not found for tenant {tenant}" if collection_id else "Collection not found" - raise HTTPException(status_code=404, detail=detail) - - async def get_collection(self, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): - """Get collection with tenant filtering.""" - collection = await super().get_collection(collection_id, request, **kwargs) - - if tenant and collection: - await self._validate_tenant_access(collection, tenant, collection_id) - - return collection - - async def item_collection( - self, - collection_id: str, - request: FastAPIRequest, - tenant: Optional[str] = None, - limit: int = 10, # Add limit - token: Optional[str] = None, # Add token - **kwargs, - ): - """Get items with tenant filtering.""" - if tenant: - logger.info(f"Filtering items by tenant: {tenant} with token: {token}") - - # Your existing tenant validation logic is good - collection = await super().get_collection(collection_id, request, **kwargs) - if not collection: - raise HTTPException( - status_code=404, - detail=f"Collection {collection_id} not found for tenant {tenant}", - ) - await self._validate_tenant_access(collection, tenant, collection_id) - - # Pass the pagination parameters to the parent method - return await super().item_collection( - collection_id=collection_id, - request=request, - limit=limit, - token=token, - **kwargs - ) - async def get_item(self, item_id: str, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): - """Get item with tenant filtering.""" - if tenant: - logger.info(f"Filtering item {item_id} in collection {collection_id} by tenant: {tenant}") - - # Fetch and validate the collection belongs to the tenant - collection = await super().get_collection(collection_id, request, **kwargs) - if not collection: - raise HTTPException( - status_code=404, - detail=f"Collection {collection_id} not found for tenant {tenant}" - ) - await self._validate_tenant_access(collection, tenant, collection_id) - - return await super().get_item(item_id, collection_id, request, **kwargs) - async def post_search( - self, - search_request: POSTModel, - request: FastAPIRequest, - tenant: Optional[str] = None, - **kwargs - ): - """Search with tenant filtering.""" - if tenant: - logger.info(f"Filtering search by tenant: {tenant}") - # IMPORTANT: You must actually filter the search by the tenant. - # This assumes you have a 'tenant' property in your collection's 'properties'. - # pgstac search function can take a filter - tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} - - if search_request.filter: - # If a filter already exists, combine with an 'and' - search_request.filter = { - "op": "and", - "args": [ - search_request.filter, - tenant_filter, - ], - } - else: - search_request.filter = tenant_filter - search_request.filter_lang = "cql2-json" - - return await super().post_search(search_request, request, **kwargs) - - async def get_search( - self, - request: FastAPIRequest, - tenant: Optional[str] = None, - **kwargs, - ): - """GET search with tenant filtering.""" - if tenant: - logger.info(f"Filtering GET search by tenant: {tenant}") - # IMPORTANT: You must also modify the GET search to filter by tenant. - # This requires modifying the kwargs that will be used to build the search request. - tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} - - if "filter" in kwargs and kwargs["filter"]: - # Combine with existing filter - kwargs["filter"] = { - "op": "and", - "args": [ - kwargs["filter"], - tenant_filter, - ], - } - else: - kwargs["filter"] = tenant_filter - kwargs["filter-lang"] = "cql2-json" - - # The CoreCrudClient.get_search will use the modified kwargs - return await super().get_search(request, **kwargs) - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Get a database connection on startup, close it on shutdown.""" - await connect_to_db(app) - yield - await close_db_connection(app) - - -# Create the base STAC API -api = VedaStacApi( - app=FastAPI( - title=f"{api_settings.project_name} STAC API", - openapi_url="/openapi.json", - docs_url="/docs", - root_path=api_settings.root_path, - swagger_ui_init_oauth=( - { - "appName": "STAC API", - "clientId": auth_settings.client_id, - "usePkceWithAuthorizationCodeGrant": True, - "scopes": "openid stac:item:create stac:item:update stac:item:delete stac:collection:create stac:collection:update stac:collection:delete", - } - if auth_settings.client_id - else {} - ), - lifespan=lifespan, - ), - title=f"{api_settings.project_name} STAC API", - description=api_settings.project_description, - settings=api_settings.load_postgres_settings(), - extensions=PgStacExtensions, - client=TenantAwareVedaCrudClient(post_request_model=POSTModel), - search_get_request_model=GETModel, - search_post_request_model=POSTModel, - items_get_request_model=items_get_request_model, - response_class=ORJSONResponse, - middlewares=[ - Middleware(CompressionMiddleware), - Middleware(ValidationMiddleware), - ], - router=APIRouter(route_class=LoggerRouteHandler), -) -app = api.app - -# Add tenant-specific routes -tenant_router = APIRouter(redirect_slashes=True) - -@tenant_router.get("/{tenant}/collections") -async def get_tenant_collections( - tenant: str = Path(..., description="Tenant identifier"), - request: FastAPIRequest = None, -): - """Get collections for a specific tenant.""" - logger.info(f"Getting collections for tenant: {tenant}") - collections = await api.client.all_collections(request, tenant=tenant) - - # Update links to include tenant prefix - if collections and isinstance(collections, dict): - collections = update_links_with_tenant(collections, tenant) - - return collections - - -@tenant_router.get("/{tenant}/collections/{collection_id}") -async def get_tenant_collection( - tenant: str = Path(..., description="Tenant identifier"), - collection_id: str = Path(..., description="Collection identifier"), - request: FastAPIRequest = None, -): - """Get a specific collection for a tenant.""" - logger.info(f"Getting collection {collection_id} for tenant: {tenant}") - collection = await api.client.get_collection(collection_id, request, tenant=tenant) - - # Update links to include tenant prefix - if collection and isinstance(collection, dict): - collection = update_links_with_tenant(collection, tenant) - - return collection - - -@tenant_router.get("/{tenant}/collections/{collection_id}/items") -async def get_tenant_collection_items( - request: FastAPIRequest, # It's good practice to have request as the first arg - tenant: str = Path(..., description="Tenant identifier"), - collection_id: str = Path(..., description="Collection identifier"), - limit: int = 10, # Add limit - token: Optional[str] = None, # Add token -): - """Get items from a collection for a specific tenant.""" - logger.info(f"Getting items from collection {collection_id} for tenant: {tenant}") - - # Pass the captured parameters to the client method - items = await api.client.item_collection( - collection_id=collection_id, - request=request, - tenant=tenant, - limit=limit, - token=token - ) - - # Your link updater will correctly rewrite the new `next` link if one is present - if items and isinstance(items, dict): - items = update_links_with_tenant(items, tenant) - - return items - - -@tenant_router.get("/{tenant}/collections/{collection_id}/items/{item_id}") -async def get_tenant_item( - tenant: str = Path(..., description="Tenant identifier"), - collection_id: str = Path(..., description="Collection identifier"), - item_id: str = Path(..., description="Item identifier"), - request: FastAPIRequest = None, -): - """Get a specific item for a tenant.""" - logger.info(f"======> Getting item {item_id} from collection {collection_id} for tenant: {tenant} <=====") - return await api.client.get_item(item_id, collection_id, request, tenant=tenant) - - -@tenant_router.get("/{tenant}/search") -async def get_tenant_search( - tenant: str = Path(..., description="Tenant identifier"), - request: FastAPIRequest = None, -): - """Search items for a specific tenant using GET.""" - logger.info(f"GET search for tenant: {tenant}") - return await api.client.get_search(request, tenant=tenant) - - -@tenant_router.get("/{tenant}/search") -async def get_tenant_search( - request: FastAPIRequest, # Request should come first - tenant: str = Path(..., description="Tenant identifier"), - # Add ALL possible GET search parameters here that stac-fastapi uses - collections: Optional[str] = None, - ids: Optional[str] = None, - bbox: Optional[str] = None, - datetime: Optional[str] = None, - limit: int = 10, - query: Optional[str] = None, - token: Optional[str] = None, - filter_lang: Optional[str] = None, - filter: Optional[str] = None, - sortby: Optional[str] = None, - # **kwargs: Any # Avoid using this if possible, be explicit -): - """Search items for a specific tenant using GET.""" - logger.info(f"GET search for tenant: {tenant}") - - # The base `get_search` method in stac-fastapi unpacks the request itself. - # What's important is that our `TenantAwareVedaCrudClient.get_search` can access these params. - # The default behavior of stac-fastapi `get_search` is to parse these from the request query params. - # Our modification in the client (step 1) will handle the tenant injection. - - # We create a dictionary of the GET parameters to pass them explicitly - # to avoid ambiguity. - params = { - "collections": collections.split(",") if collections else None, - "ids": ids.split(",") if ids else None, - "bbox": [float(x) for x in bbox.split(",")] if bbox else None, - "datetime": datetime, - "limit": limit, - "query": json.loads(query) if query else None, - "token": token, - "filter-lang": filter_lang, - "filter": json.loads(filter) if filter else None, - "sortby": sortby, - } - # Filter out None values - clean_params = {k: v for k, v in params.items() if v is not None} - - search_result = await api.client.get_search(request, tenant=tenant, **clean_params) - - # Update links - if search_result and isinstance(search_result, dict): - search_result = update_links_with_tenant(search_result, tenant) - - return search_result - - -@tenant_router.get("/{tenant}/") -async def get_tenant_landing_page( - tenant: str = Path(..., description="Tenant identifier"), - request: FastAPIRequest = None, -): - """Get landing page for a specific tenant.""" - logger.info(f"Getting landing page for tenant: {tenant}") - - # Get the base landing page by calling the method on the CLIENT, not the API object - # Corrected line: - base_landing = await api.client.landing_page(request=request) - - # The rest of your logic for modifying the links is correct - if isinstance(base_landing, ORJSONResponse): - # The client returns a response object, so we need to decode its content - body = base_landing.body - tenant_landing = json.loads(body) - - # Update links to include tenant prefix - if 'links' in tenant_landing: - # Using your update_links_with_tenant function is more robust - tenant_landing = update_links_with_tenant(tenant_landing, tenant) - - # Update title to include tenant - if 'title' in tenant_landing: - tenant_landing['title'] = f"{tenant.upper()} - {tenant_landing['title']}" - - # Return a new JSONResponse with the modified content - return ORJSONResponse(tenant_landing) - - # Fallback in case the response is not what we expect - return base_landing - -# Include the tenant router -app.include_router(tenant_router, tags=["Tenant-specific endpoints"]) - -# Add tenant-only enforcement middleware (set to False if you want to keep original routes) -# app.add_middleware(TenantOnlyMiddleware, enforce_tenant_only=True) - -# Set all CORS enabled origins -if api_settings.cors_origins: - app.add_middleware( - CORSMiddleware, - allow_origins=api_settings.cors_origins, - allow_credentials=True, - allow_methods=["GET", "POST", "PUT", "OPTIONS"], - allow_headers=["*"], - ) - -if api_settings.enable_transactions and auth_settings.client_id: - oidc_auth = OpenIdConnectAuth( - openid_configuration_url=auth_settings.openid_configuration_url, - allowed_jwt_audiences="account", - ) - - restricted_prefixes_methods = { - "/collections": [("POST", "stac:collection:create")], - "/collections/{collection_id}": [ - ("PUT", "stac:collection:update"), - ("DELETE", "stac:collection:delete"), - ], - "/collections/{collection_id}/items": [("POST", "stac:item:create")], - "/collections/{collection_id}/items/{item_id}": [ - ("PUT", "stac:item:update"), - ("DELETE", "stac:item:delete"), - ], - "/collections/{collection_id}/bulk_items": [("POST", "stac:item:create")], - } - - for route in app.router.routes: - method_scopes = restricted_prefixes_methods.get(route.path) - if not method_scopes: - continue - for method, scope in method_scopes: - if method not in route.methods: - continue - oidc_auth.apply_auth_dependencies(route, required_token_scopes=[scope]) - -if tiles_settings.titiler_endpoint: - # Register to the TiTiler extension to the api - extension = TiTilerExtension() - extension.register(api.app, tiles_settings.titiler_endpoint) - - -@app.get("/index.html", response_class=HTMLResponse) -async def viewer_page(request: Request): - """Search viewer.""" - path = api_settings.root_path or "" - return templates.TemplateResponse( - "stac-viewer.html", - {"request": request, "endpoint": str(request.url).replace("/index.html", path)}, - media_type="text/html", - ) - - -@app.get("/{tenant}/index.html", response_class=HTMLResponse) -async def tenant_viewer_page(request: Request, tenant: str): - """Tenant-specific search viewer.""" - return templates.TemplateResponse( - "stac-viewer.html", - { - "request": request, - "endpoint": str(request.url).replace("/index.html", f"/{tenant}"), - "tenant": tenant - }, - media_type="text/html", - ) - - -# If the correlation header is used in the UI, we can analyze traces that originate from a given user or client -@app.middleware("http") -async def add_correlation_id(request: Request, call_next): - """Add correlation ids to all requests and subsequent logs/traces""" - # Get correlation id from X-Correlation-Id header - corr_id = request.headers.get("x-correlation-id") - if not corr_id: - try: - # If empty, use request id from aws context - corr_id = request.scope["aws.context"].aws_request_id - except KeyError: - # If empty, use uuid - corr_id = "local" - # Add correlation id to logs - logger.set_correlation_id(corr_id) - # Add correlation id to traces - tracer.put_annotation(key="correlation_id", value=corr_id) - - response = await tracer.capture_method(call_next)(request) - # Return correlation header in response - response.headers["X-Correlation-Id"] = corr_id - logger.info("Request completed") - return response - - -@app.exception_handler(Exception) -async def validation_exception_handler(request, err): - """Handle exceptions that aren't caught elsewhere""" - metrics.add_metric(name="UnhandledExceptions", unit=MetricUnit.Count, value=1) - logger.error("Unhandled exception") - return JSONResponse(status_code=500, content={"detail": "Internal Server Error"}) diff --git a/stac_api/runtime/src/old_working_app.py b/stac_api/runtime/src/old_working_app.py deleted file mode 100644 index 1a9f18d0..00000000 --- a/stac_api/runtime/src/old_working_app.py +++ /dev/null @@ -1,565 +0,0 @@ -"""FastAPI application using PGStac with integrated tenant filtering. -Based on https://github.com/developmentseed/eoAPI/tree/master/src/eoapi/stac -""" - -import json -from contextlib import asynccontextmanager -from typing import Dict, Any, Optional - -from aws_lambda_powertools.metrics import MetricUnit -from src.config import TilesApiSettings, api_settings -from src.config import extensions as PgStacExtensions -from src.config import get_request_model as GETModel -from src.config import items_get_request_model -from src.config import post_request_model as POSTModel -from src.extension import TiTilerExtension - -from fastapi import APIRouter, FastAPI, Request as FastAPIRequest, Depends, Path, HTTPException -from fastapi.responses import ORJSONResponse -from stac_fastapi.pgstac.db import close_db_connection, connect_to_db -from starlette.middleware import Middleware -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.middleware.cors import CORSMiddleware -from starlette.requests import Request -from starlette.responses import HTMLResponse, JSONResponse, Response -from starlette.templating import Jinja2Templates -from starlette.types import ASGIApp -from starlette_cramjam.middleware import CompressionMiddleware - -from .api import VedaStacApi -from .core import VedaCrudClient -from .monitoring import LoggerRouteHandler, logger, metrics, tracer -from .validation import ValidationMiddleware -import os -from eoapi.auth_utils import OpenIdConnectAuth, OpenIdConnectSettings - -try: - from importlib.resources import files as resources_files # type: ignore -except ImportError: - # Try backported to PY<39 `importlib_resources`. - from importlib_resources import files as resources_files # type: ignore - - -templates = Jinja2Templates(directory=str(resources_files(__package__) / "templates")) # type: ignore - -def update_links_with_tenant(data: dict, tenant: str) -> dict: - """Update all links in a response to include tenant prefix.""" - base_url = os.getenv('BASE_URL', "http://localhost:8081") - def update_link(link: dict): - if 'href' in link and not "token=next:" in link['href']: - href = link['href'] - if href.startswith(f'{base_url}/'): - # Replace base URL with tenant-prefixed URL - link['href'] = href.replace(f'{base_url}/', f'{base_url}/{tenant}/') - elif href.startswith('/') and not href.startswith(f'/{tenant}/'): - # Handle relative URLs - link['href'] = f'/{tenant}{href}' - - # Update main response links - if 'links' in data and isinstance(data['links'], list): - for link in data['links']: - update_link(link) - - # Update collection links - if 'collections' in data and isinstance(data['collections'], list): - for collection in data['collections']: - if 'links' in collection and isinstance(collection['links'], list): - for link in collection['links']: - update_link(link) - - # Update item/feature links - if 'features' in data and isinstance(data['features'], list): - for item in data['features']: - if 'links' in item and isinstance(item['links'], list): - for link in item['links']: - update_link(link) - - return data - - - -tiles_settings = TilesApiSettings() -auth_settings = OpenIdConnectSettings(_env_prefix="VEDA_STAC_") - - -class TenantAwareVedaCrudClient(VedaCrudClient): - """Extended CRUD client that applies tenant filtering.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - async def all_collections(self, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): - """Get all collections with optional tenant filtering.""" - if tenant: - # Add tenant filter to the database query - # This assumes your collections table has a tenant column - # Adjust based on your actual database schema - logger.info(f"Filtering collections by tenant: {tenant}") - - # Call the parent method - collections = await super().all_collections(request, **kwargs) - - # If tenant is specified, filter the results - if tenant and hasattr(collections, 'collections'): - filtered_collections = [ - col for col in collections.collections - if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant - ] - collections.collections = filtered_collections - if hasattr(collections, 'context') and hasattr(collections.context, 'returned'): - collections.context.returned = len(filtered_collections) - elif tenant and isinstance(collections, dict) and 'collections' in collections: - filtered_collections = [ - col for col in collections['collections'] - if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant - ] - collections['collections'] = filtered_collections - if 'numberReturned' in collections: - collections['numberReturned'] = len(filtered_collections) - - return collections - - async def _validate_tenant_access(self, collection: dict, tenant: str, collection_id: str = ""): - """Raise HTTP 404 if the collection does not belong to the given tenant.""" - collection_tenant = collection.get("tenant") or collection.get("properties", {}).get("tenant") - if collection_tenant != tenant: - detail = f"Collection {collection_id} not found for tenant {tenant}" if collection_id else "Collection not found" - raise HTTPException(status_code=404, detail=detail) - - async def get_collection(self, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): - """Get collection with tenant filtering.""" - collection = await super().get_collection(collection_id, request, **kwargs) - - if tenant and collection: - await self._validate_tenant_access(collection, tenant, collection_id) - - return collection - - async def item_collection( - self, - collection_id: str, - request: FastAPIRequest, - tenant: Optional[str] = None, - limit: int = 10, # Add limit - token: Optional[str] = None, # Add token - **kwargs, - ): - """Get items with tenant filtering.""" - if tenant: - logger.info(f"Filtering items by tenant: {tenant} with token: {token}") - - # Your existing tenant validation logic is good - collection = await super().get_collection(collection_id, request, **kwargs) - if not collection: - raise HTTPException( - status_code=404, - detail=f"Collection {collection_id} not found for tenant {tenant}", - ) - await self._validate_tenant_access(collection, tenant, collection_id) - - # Pass the pagination parameters to the parent method - return await super().item_collection( - collection_id=collection_id, - request=request, - limit=limit, - token=token, - **kwargs - ) - async def get_item(self, item_id: str, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): - """Get item with tenant filtering.""" - if tenant: - logger.info(f"Filtering item {item_id} in collection {collection_id} by tenant: {tenant}") - - # Fetch and validate the collection belongs to the tenant - collection = await super().get_collection(collection_id, request, **kwargs) - if not collection: - raise HTTPException( - status_code=404, - detail=f"Collection {collection_id} not found for tenant {tenant}" - ) - await self._validate_tenant_access(collection, tenant, collection_id) - - return await super().get_item(item_id, collection_id, request, **kwargs) - async def post_search( - self, - search_request: POSTModel, - request: FastAPIRequest, - tenant: Optional[str] = None, - **kwargs - ): - """Search with tenant filtering.""" - if tenant: - logger.info(f"Filtering search by tenant: {tenant}") - # IMPORTANT: You must actually filter the search by the tenant. - # This assumes you have a 'tenant' property in your collection's 'properties'. - # pgstac search function can take a filter - tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} - - if search_request.filter: - # If a filter already exists, combine with an 'and' - search_request.filter = { - "op": "and", - "args": [ - search_request.filter, - tenant_filter, - ], - } - else: - search_request.filter = tenant_filter - search_request.filter_lang = "cql2-json" - - return await super().post_search(search_request, request, **kwargs) - - async def get_search( - self, - request: FastAPIRequest, - tenant: Optional[str] = None, - **kwargs, - ): - """GET search with tenant filtering.""" - if tenant: - logger.info(f"Filtering GET search by tenant: {tenant}") - # IMPORTANT: You must also modify the GET search to filter by tenant. - # This requires modifying the kwargs that will be used to build the search request. - tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} - - if "filter" in kwargs and kwargs["filter"]: - # Combine with existing filter - kwargs["filter"] = { - "op": "and", - "args": [ - kwargs["filter"], - tenant_filter, - ], - } - else: - kwargs["filter"] = tenant_filter - kwargs["filter-lang"] = "cql2-json" - - # The CoreCrudClient.get_search will use the modified kwargs - return await super().get_search(request, **kwargs) - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Get a database connection on startup, close it on shutdown.""" - await connect_to_db(app) - yield - await close_db_connection(app) - - -# Create the base STAC API -api = VedaStacApi( - app=FastAPI( - title=f"{api_settings.project_name} STAC API", - openapi_url="/openapi.json", - docs_url="/docs", - root_path=api_settings.root_path, - swagger_ui_init_oauth=( - { - "appName": "STAC API", - "clientId": auth_settings.client_id, - "usePkceWithAuthorizationCodeGrant": True, - "scopes": "openid stac:item:create stac:item:update stac:item:delete stac:collection:create stac:collection:update stac:collection:delete", - } - if auth_settings.client_id - else {} - ), - lifespan=lifespan, - ), - title=f"{api_settings.project_name} STAC API", - description=api_settings.project_description, - settings=api_settings.load_postgres_settings(), - extensions=PgStacExtensions, - client=TenantAwareVedaCrudClient(post_request_model=POSTModel), - search_get_request_model=GETModel, - search_post_request_model=POSTModel, - items_get_request_model=items_get_request_model, - response_class=ORJSONResponse, - middlewares=[ - Middleware(CompressionMiddleware), - Middleware(ValidationMiddleware), - ], - router=APIRouter(route_class=LoggerRouteHandler), -) -app = api.app - -# Add tenant-specific routes -tenant_router = APIRouter(redirect_slashes=True) - -@tenant_router.get("/{tenant}/collections") -async def get_tenant_collections( - tenant: str = Path(..., description="Tenant identifier"), - request: FastAPIRequest = None, -): - """Get collections for a specific tenant.""" - logger.info(f"Getting collections for tenant: {tenant}") - collections = await api.client.all_collections(request, tenant=tenant) - - # Update links to include tenant prefix - if collections and isinstance(collections, dict): - collections = update_links_with_tenant(collections, tenant) - - return collections - - -@tenant_router.get("/{tenant}/collections/{collection_id}") -async def get_tenant_collection( - tenant: str = Path(..., description="Tenant identifier"), - collection_id: str = Path(..., description="Collection identifier"), - request: FastAPIRequest = None, -): - """Get a specific collection for a tenant.""" - logger.info(f"Getting collection {collection_id} for tenant: {tenant}") - collection = await api.client.get_collection(collection_id, request, tenant=tenant) - - # Update links to include tenant prefix - if collection and isinstance(collection, dict): - collection = update_links_with_tenant(collection, tenant) - - return collection - - -@tenant_router.get("/{tenant}/collections/{collection_id}/items") -async def get_tenant_collection_items( - request: FastAPIRequest, # It's good practice to have request as the first arg - tenant: str = Path(..., description="Tenant identifier"), - collection_id: str = Path(..., description="Collection identifier"), - limit: int = 10, # Add limit - token: Optional[str] = None, # Add token -): - """Get items from a collection for a specific tenant.""" - logger.info(f"Getting items from collection {collection_id} for tenant: {tenant}") - - # Pass the captured parameters to the client method - items = await api.client.item_collection( - collection_id=collection_id, - request=request, - tenant=tenant, - limit=limit, - token=token - ) - - # Your link updater will correctly rewrite the new `next` link if one is present - if items and isinstance(items, dict): - items = update_links_with_tenant(items, tenant) - - return items - - -@tenant_router.get("/{tenant}/collections/{collection_id}/items/{item_id}") -async def get_tenant_item( - tenant: str = Path(..., description="Tenant identifier"), - collection_id: str = Path(..., description="Collection identifier"), - item_id: str = Path(..., description="Item identifier"), - request: FastAPIRequest = None, -): - """Get a specific item for a tenant.""" - logger.info(f"======> Getting item {item_id} from collection {collection_id} for tenant: {tenant} <=====") - - item = await api.client.get_item(item_id, collection_id, request, tenant=tenant) - if item and isinstance(item, dict): - item = update_links_with_tenant(item, tenant) - - return item - -@tenant_router.get("/{tenant}/search") -async def get_tenant_search( - tenant: str = Path(..., description="Tenant identifier"), - request: FastAPIRequest = None, -): - """Search items for a specific tenant using GET.""" - logger.info(f"GET search for tenant: {tenant}") - return await api.client.get_search(request, tenant=tenant) - - -@tenant_router.get("/{tenant}/search") -async def get_tenant_search( - request: FastAPIRequest, # Request should come first - tenant: str = Path(..., description="Tenant identifier"), - # Add ALL possible GET search parameters here that stac-fastapi uses - collections: Optional[str] = None, - ids: Optional[str] = None, - bbox: Optional[str] = None, - datetime: Optional[str] = None, - limit: int = 10, - query: Optional[str] = None, - token: Optional[str] = None, - filter_lang: Optional[str] = None, - filter: Optional[str] = None, - sortby: Optional[str] = None, - # **kwargs: Any # Avoid using this if possible, be explicit -): - """Search items for a specific tenant using GET.""" - logger.info(f"GET search for tenant: {tenant}") - - # The base `get_search` method in stac-fastapi unpacks the request itself. - # What's important is that our `TenantAwareVedaCrudClient.get_search` can access these params. - # The default behavior of stac-fastapi `get_search` is to parse these from the request query params. - # Our modification in the client (step 1) will handle the tenant injection. - - # We create a dictionary of the GET parameters to pass them explicitly - # to avoid ambiguity. - params = { - "collections": collections.split(",") if collections else None, - "ids": ids.split(",") if ids else None, - "bbox": [float(x) for x in bbox.split(",")] if bbox else None, - "datetime": datetime, - "limit": limit, - "query": json.loads(query) if query else None, - "token": token, - "filter-lang": filter_lang, - "filter": json.loads(filter) if filter else None, - "sortby": sortby, - } - # Filter out None values - clean_params = {k: v for k, v in params.items() if v is not None} - - search_result = await api.client.get_search(request, tenant=tenant, **clean_params) - - # Update links - if search_result and isinstance(search_result, dict): - search_result = update_links_with_tenant(search_result, tenant) - - return search_result - -@tenant_router.get("/{tenant}/index.html", response_class=HTMLResponse) -async def tenant_viewer_page(request: Request, tenant: str): - """Tenant-specific search viewer.""" - return templates.TemplateResponse( - "stac-viewer.html", - { - "request": request, - "endpoint": str(request.url).replace("/index.html", f"/{tenant}"), - "tenant": tenant - }, - media_type="text/html", - ) -@tenant_router.get("/{tenant}/") -async def get_tenant_landing_page( - tenant: str = Path(..., description="Tenant identifier"), - request: FastAPIRequest = None, -): - """Get landing page for a specific tenant.""" - logger.info(f"Getting landing page for tenant: {tenant}") - - # Get the base landing page by calling the method on the CLIENT, not the API object - # Corrected line: - base_landing = await api.client.landing_page(request=request) - - # The rest of your logic for modifying the links is correct - if isinstance(base_landing, ORJSONResponse): - # The client returns a response object, so we need to decode its content - body = base_landing.body - tenant_landing = json.loads(body) - - # Update links to include tenant prefix - if 'links' in tenant_landing: - # Using your update_links_with_tenant function is more robust - tenant_landing = update_links_with_tenant(tenant_landing, tenant) - - # Update title to include tenant - if 'title' in tenant_landing: - tenant_landing['title'] = f"{tenant.upper()} - {tenant_landing['title']}" - - # Return a new JSONResponse with the modified content - return ORJSONResponse(tenant_landing) - - # Fallback in case the response is not what we expect - return base_landing - -# Include the tenant router -app.include_router(tenant_router, tags=["Tenant-specific endpoints"]) - -# Add tenant-only enforcement middleware (set to False if you want to keep original routes) -# app.add_middleware(TenantOnlyMiddleware, enforce_tenant_only=True) - -# Set all CORS enabled origins -if api_settings.cors_origins: - app.add_middleware( - CORSMiddleware, - allow_origins=api_settings.cors_origins, - allow_credentials=True, - allow_methods=["GET", "POST", "PUT", "OPTIONS"], - allow_headers=["*"], - ) - -if api_settings.enable_transactions and auth_settings.client_id: - oidc_auth = OpenIdConnectAuth( - openid_configuration_url=auth_settings.openid_configuration_url, - allowed_jwt_audiences="account", - ) - - restricted_prefixes_methods = { - "/collections": [("POST", "stac:collection:create")], - "/collections/{collection_id}": [ - ("PUT", "stac:collection:update"), - ("DELETE", "stac:collection:delete"), - ], - "/collections/{collection_id}/items": [("POST", "stac:item:create")], - "/collections/{collection_id}/items/{item_id}": [ - ("PUT", "stac:item:update"), - ("DELETE", "stac:item:delete"), - ], - "/collections/{collection_id}/bulk_items": [("POST", "stac:item:create")], - } - - for route in app.router.routes: - method_scopes = restricted_prefixes_methods.get(route.path) - if not method_scopes: - continue - for method, scope in method_scopes: - if method not in route.methods: - continue - oidc_auth.apply_auth_dependencies(route, required_token_scopes=[scope]) - -if tiles_settings.titiler_endpoint: - # Register to the TiTiler extension to the api - extension = TiTilerExtension() - extension.register(api.app, tiles_settings.titiler_endpoint) - - -@app.get("/index.html", response_class=HTMLResponse) -async def viewer_page(request: Request): - """Search viewer.""" - path = api_settings.root_path or "" - return templates.TemplateResponse( - "stac-viewer.html", - {"request": request, "endpoint": str(request.url).replace("/index.html", path)}, - media_type="text/html", - ) - - - - - -# If the correlation header is used in the UI, we can analyze traces that originate from a given user or client -@app.middleware("http") -async def add_correlation_id(request: Request, call_next): - """Add correlation ids to all requests and subsequent logs/traces""" - # Get correlation id from X-Correlation-Id header - corr_id = request.headers.get("x-correlation-id") - if not corr_id: - try: - # If empty, use request id from aws context - corr_id = request.scope["aws.context"].aws_request_id - except KeyError: - # If empty, use uuid - corr_id = "local" - # Add correlation id to logs - logger.set_correlation_id(corr_id) - # Add correlation id to traces - tracer.put_annotation(key="correlation_id", value=corr_id) - - response = await tracer.capture_method(call_next)(request) - # Return correlation header in response - response.headers["X-Correlation-Id"] = corr_id - logger.info("Request completed") - return response - - -@app.exception_handler(Exception) -async def validation_exception_handler(request, err): - """Handle exceptions that aren't caught elsewhere""" - metrics.add_metric(name="UnhandledExceptions", unit=MetricUnit.Count, value=1) - logger.error("Unhandled exception") - return JSONResponse(status_code=500, content={"detail": "Internal Server Error"}) From c73c2b6089ff9906458223eeb96e825e586ae934 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Mon, 15 Sep 2025 17:30:40 -0700 Subject: [PATCH 03/33] feat: add middleware, create tenant context in model --- stac_api/runtime/src/tenant_middleware.py | 66 +++++++++++++++++++++++ stac_api/runtime/src/tenant_models.py | 17 ++++++ 2 files changed, 83 insertions(+) create mode 100644 stac_api/runtime/src/tenant_middleware.py create mode 100644 stac_api/runtime/src/tenant_models.py diff --git a/stac_api/runtime/src/tenant_middleware.py b/stac_api/runtime/src/tenant_middleware.py new file mode 100644 index 00000000..03da3e67 --- /dev/null +++ b/stac_api/runtime/src/tenant_middleware.py @@ -0,0 +1,66 @@ +import logging +from typing import Optional +from fastapi import Request, HTTPException +from starlette.middleware.base import BaseHTTPMiddleware + +from .tenant_models import TenantContext + +logger = logging.getLogger(__name__) + +class TenantMiddleware(BaseHTTPMiddleware): + def __init__(self, app): + super().__init__(app) + + async def dispatch(self, request: Request, call_next): + try: + tenant = self._extract_tenant(request) + + tenant_context = TenantContext( + tenant_id=tenant, + request_id=request.headers.get("X-Correlation-ID"), + ) if tenant else None + + request.state.tenant_context = tenant_context + + if tenant: + logger.info( + f"Tenant access: {tenant} for {request.method} {request.url.path}", + extra={ + "tenant": tenant, + "method": request.method, + "path": request.url.path if tenant_context else None + } + ) + + response = await call_next(request) + + if tenant_context: + response.headers["X-Tenant-ID"] = tenant_context.tenant_id + if tenant_context.request_id: + response.headers["X-Request-ID"] = tenant_context.request_id + + return response + + except Exception as e: + # JT TODO - Put more helpful exception? + logger.warning( + f"Tenant validation failed: {e.detail}", + extra={ + "tenant": getattr(e, 'tenant', None), + "resource_type": getattr(e, 'resource_type', None), + "resource_id": getattr(e, 'resource_id', None) + } + ) + raise HTTPException(status_code=404, detail=e.detail) + + except Exception as e: + logger.error(f"Tenant middleware error: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error") + + def _extract_tenant(self, request: Request) -> Optional[str]: + path = request.url.path + if path.startswith('/api/stac/'): + path_parts = path.replace('/api/stac/', '').split('/') + if path_parts and path_parts[0]: + return path_parts[0] + return None \ No newline at end of file diff --git a/stac_api/runtime/src/tenant_models.py b/stac_api/runtime/src/tenant_models.py new file mode 100644 index 00000000..51c5c3bc --- /dev/null +++ b/stac_api/runtime/src/tenant_models.py @@ -0,0 +1,17 @@ +from typing import Any, Dict, List, Optional +from pydantic import BaseModel, Field, field_validator +from fastapi import HTTPException + +class TenantContext(BaseModel): + tenant_id: str = Field(..., description="Tenant identifier") + request_id: Optional[str] = Field(None, description="Request correlation ID") + + @field_validator('tenant_id') + @classmethod + def validate_tenant_id(cls, v): + if not v or not v.strip(): + raise ValueError("Tenant ID cannot be empty") + if len(v) > 100: + raise ValueError("Tenant ID too long") + return v.strip().lower() + From f60c812d952e4788f0544487e9b8f5929afcb27b Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Mon, 15 Sep 2025 17:36:18 -0700 Subject: [PATCH 04/33] feat: add TenantSearchRequest model --- stac_api/runtime/src/tenant_models.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/stac_api/runtime/src/tenant_models.py b/stac_api/runtime/src/tenant_models.py index 51c5c3bc..cd27722b 100644 --- a/stac_api/runtime/src/tenant_models.py +++ b/stac_api/runtime/src/tenant_models.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, field_validator -from fastapi import HTTPException class TenantContext(BaseModel): tenant_id: str = Field(..., description="Tenant identifier") @@ -15,3 +14,15 @@ def validate_tenant_id(cls, v): raise ValueError("Tenant ID too long") return v.strip().lower() + +class TenantSearchRequest(BaseModel): + """Tenant-aware search request model.""" + + tenant: Optional[str] = Field(None, description="Tenant identifier") + collections: Optional[List[str]] = Field(None, description="Collection IDs to search") + bbox: Optional[List[float]] = Field(None, description="Bounding box") + datetime: Optional[str] = Field(None, description="Datetime range") + limit: int = Field(10, description="Maximum number of results") + token: Optional[str] = Field(None, description="Pagination token") + filter: Optional[Dict[str, Any]] = Field(None, description="CQL2 filter") + filter_lang: str = Field("cql2-text", description="Filter language") From 9ddbb2e1ff0ccb6b02e51fba811e42be2f5cc153 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Mon, 15 Sep 2025 17:38:33 -0700 Subject: [PATCH 05/33] feat: add specific tenant validation error class --- stac_api/runtime/src/tenant_models.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/stac_api/runtime/src/tenant_models.py b/stac_api/runtime/src/tenant_models.py index cd27722b..9790bc95 100644 --- a/stac_api/runtime/src/tenant_models.py +++ b/stac_api/runtime/src/tenant_models.py @@ -26,3 +26,22 @@ class TenantSearchRequest(BaseModel): token: Optional[str] = Field(None, description="Pagination token") filter: Optional[Dict[str, Any]] = Field(None, description="CQL2 filter") filter_lang: str = Field("cql2-text", description="Filter language") + +class TenantValidationError(HTTPException): + def __init__( + self, + resource_type: str, + resource_id: str, + tenant: str, + actual_tenant: Optional[str] = None + ): + self.resource_type = resource_type + self.resource_id = resource_id + self.tenant = tenant + self.actual_tenant = actual_tenant + + detail = f"{resource_type} {resource_id} not found for tenant {tenant}" + if actual_tenant: + detail += f" (found tenant: {actual_tenant})" + + super().__init__(status_code=404, detail=detail) \ No newline at end of file From 223088a99fa2c68decf9a3b2acd76d1dc98756a4 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Mon, 15 Sep 2025 17:41:11 -0700 Subject: [PATCH 06/33] fix: update error handling --- stac_api/runtime/src/tenant_middleware.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/stac_api/runtime/src/tenant_middleware.py b/stac_api/runtime/src/tenant_middleware.py index 03da3e67..8b51f27f 100644 --- a/stac_api/runtime/src/tenant_middleware.py +++ b/stac_api/runtime/src/tenant_middleware.py @@ -3,7 +3,7 @@ from fastapi import Request, HTTPException from starlette.middleware.base import BaseHTTPMiddleware -from .tenant_models import TenantContext +from .tenant_models import TenantContext, TenantValidationError logger = logging.getLogger(__name__) @@ -41,8 +41,7 @@ async def dispatch(self, request: Request, call_next): return response - except Exception as e: - # JT TODO - Put more helpful exception? + except TenantValidtionError as e: logger.warning( f"Tenant validation failed: {e.detail}", extra={ From 4008178f4ddb08d23454ab04a431ec2b7ba1e58e Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Mon, 15 Sep 2025 17:47:32 -0700 Subject: [PATCH 07/33] feat: add functon for adding tenant filter to search --- stac_api/runtime/src/tenant_models.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/stac_api/runtime/src/tenant_models.py b/stac_api/runtime/src/tenant_models.py index 9790bc95..e3bdf898 100644 --- a/stac_api/runtime/src/tenant_models.py +++ b/stac_api/runtime/src/tenant_models.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, field_validator +from fastapi import HTTPException class TenantContext(BaseModel): tenant_id: str = Field(..., description="Tenant identifier") @@ -27,6 +28,29 @@ class TenantSearchRequest(BaseModel): filter: Optional[Dict[str, Any]] = Field(None, description="CQL2 filter") filter_lang: str = Field("cql2-text", description="Filter language") + def add_tenant_filter(self, tenant: str) -> None: + """Add tenant filter to the search request.""" + if not tenant: + return + + # Create tenant filter for properties.tenant + tenant_filter = { + "op": "=", + "args": [ + {"property": "tenant"}, + tenant + ] + } + + # If there's already a filter, combine using AND + if self.filter: + self.filter = { + "op": "and", + "args": [self.filter, tenant_filter] + } + else: + self.filter = tenant_filter + class TenantValidationError(HTTPException): def __init__( self, From bda72b29591a18ecd5c6296e60e40cce2c46a921 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Mon, 15 Sep 2025 17:59:51 -0700 Subject: [PATCH 08/33] fix: add missing conf field for search request --- stac_api/runtime/src/tenant_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stac_api/runtime/src/tenant_models.py b/stac_api/runtime/src/tenant_models.py index e3bdf898..e65ad353 100644 --- a/stac_api/runtime/src/tenant_models.py +++ b/stac_api/runtime/src/tenant_models.py @@ -27,6 +27,7 @@ class TenantSearchRequest(BaseModel): token: Optional[str] = Field(None, description="Pagination token") filter: Optional[Dict[str, Any]] = Field(None, description="CQL2 filter") filter_lang: str = Field("cql2-text", description="Filter language") + conf: Optional[Dict] = None def add_tenant_filter(self, tenant: str) -> None: """Add tenant filter to the search request.""" From 0f789f2b0364eb82712a1b4c54b72a02d3bd1a73 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Mon, 15 Sep 2025 18:08:38 -0700 Subject: [PATCH 09/33] fix: move code from app to tenant_client and tenant_routes --- stac_api/runtime/src/app.py | 311 +------------------------- stac_api/runtime/src/tenant_client.py | 184 +++++++++++++++ stac_api/runtime/src/tenant_routes.py | 258 +++++++++++++++++++++ 3 files changed, 453 insertions(+), 300 deletions(-) create mode 100644 stac_api/runtime/src/tenant_client.py create mode 100644 stac_api/runtime/src/tenant_routes.py diff --git a/stac_api/runtime/src/app.py b/stac_api/runtime/src/app.py index d5e8e3a1..7ab9ccf0 100644 --- a/stac_api/runtime/src/app.py +++ b/stac_api/runtime/src/app.py @@ -28,6 +28,9 @@ from .core import VedaCrudClient from .monitoring import LoggerRouteHandler, logger, metrics, tracer from .validation import ValidationMiddleware +from .tenant_client import TenantAwareVedaCrudClient +from .tenant_middleware import TenantMiddleware +from .tenant_routes import create_tenant_router import os from eoapi.auth_utils import OpenIdConnectAuth, OpenIdConnectSettings @@ -43,153 +46,6 @@ tiles_settings = TilesApiSettings() auth_settings = OpenIdConnectSettings(_env_prefix="VEDA_STAC_") - -class TenantAwareVedaCrudClient(VedaCrudClient): - """Extended CRUD client that applies tenant filtering.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - async def all_collections(self, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): - """Get all collections with optional tenant filtering.""" - - # Call the parent method - collections = await super().all_collections(request, **kwargs) - - # If tenant is specified, filter the results - if tenant and hasattr(collections, 'collections'): - filtered_collections = [ - col for col in collections.collections - if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant - ] - collections.collections = filtered_collections - if hasattr(collections, 'context') and hasattr(collections.context, 'returned'): - collections.context.returned = len(filtered_collections) - elif tenant and isinstance(collections, dict) and 'collections' in collections: - filtered_collections = [ - col for col in collections['collections'] - if col.get('tenant') == tenant or col.get('properties', {}).get('tenant') == tenant - ] - collections['collections'] = filtered_collections - if 'numberReturned' in collections: - collections['numberReturned'] = len(filtered_collections) - - return collections - - async def _validate_tenant_access(self, collection: dict, tenant: str, collection_id: str = ""): - """Raise HTTP 404 if the collection does not belong to the given tenant.""" - collection_tenant = collection.get("tenant") or collection.get("properties", {}).get("tenant") - if collection_tenant != tenant: - detail = f"Collection {collection_id} not found for tenant {tenant}" if collection_id else "Collection not found" - raise HTTPException(status_code=404, detail=detail) - - async def get_collection(self, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): - """Get collection with tenant filtering.""" - collection = await super().get_collection(collection_id, request, **kwargs) - - if tenant and collection: - await self._validate_tenant_access(collection, tenant, collection_id) - - return collection - - async def item_collection( - self, - collection_id: str, - request: FastAPIRequest, - tenant: Optional[str] = None, - limit: int = 10, # Add limit - token: Optional[str] = None, # Add token - **kwargs, - ): - """Get items with tenant filtering.""" - if tenant: - logger.info(f"Filtering items by tenant: {tenant} with token: {token}") - - # Your existing tenant validation logic is good - collection = await super().get_collection(collection_id, request, **kwargs) - if not collection: - raise HTTPException( - status_code=404, - detail=f"Collection {collection_id} not found for tenant {tenant}", - ) - await self._validate_tenant_access(collection, tenant, collection_id) - - # Pass the pagination parameters to the parent method - return await super().item_collection( - collection_id=collection_id, - request=request, - limit=limit, - token=token, - **kwargs - ) - async def get_item(self, item_id: str, collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs): - """Get item with tenant filtering.""" - if tenant: - logger.info(f"Filtering item {item_id} in collection {collection_id} by tenant: {tenant}") - - # Fetch and validate the collection belongs to the tenant - collection = await super().get_collection(collection_id, request, **kwargs) - if not collection: - raise HTTPException( - status_code=404, - detail=f"Collection {collection_id} not found for tenant {tenant}" - ) - await self._validate_tenant_access(collection, tenant, collection_id) - - return await super().get_item(item_id, collection_id, request, **kwargs) - async def post_search( - self, - search_request, - request: FastAPIRequest, - tenant: Optional[str] = None, - **kwargs - ): - """Search with tenant filtering.""" - if tenant: - logger.info(f"Filtering search by tenant: {tenant}") - tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} - - if search_request.filter: - # If a filter already exists, combine with an 'and' - search_request.filter = { - "op": "and", - "args": [ - search_request.filter, - tenant_filter, - ], - } - else: - search_request.filter = tenant_filter - search_request.filter_lang = "cql2-json" - - return await super().post_search(search_request, request, **kwargs) - - async def get_search( - self, - request: FastAPIRequest, - tenant: Optional[str] = None, - **kwargs, - ): - """GET search with tenant filtering.""" - if tenant: - tenant_filter = {"op": "=", "args": [{"property": "collection"}, {"property": "tenant"}, tenant]} - - if "filter" in kwargs and kwargs["filter"]: - # Combine with existing filter - kwargs["filter"] = { - "op": "and", - "args": [ - kwargs["filter"], - tenant_filter, - ], - } - else: - kwargs["filter"] = tenant_filter - kwargs["filter-lang"] = "cql2-json" - - # The CoreCrudClient.get_search will use the modified kwargs - return await super().get_search(request, **kwargs) - @asynccontextmanager async def lifespan(app: FastAPI): """Get a database connection on startup, close it on shutdown.""" @@ -197,6 +53,7 @@ async def lifespan(app: FastAPI): yield await close_db_connection(app) +tenant_client = TenantAwareVedaCrudClient(pgstac_search_model=POSTModel) api = VedaStacApi( app=FastAPI( @@ -220,7 +77,7 @@ async def lifespan(app: FastAPI): description=api_settings.project_description, settings=api_settings, extensions=PgStacExtensions, - client=TenantAwareVedaCrudClient(pgstac_search_model=POSTModel), + client=tenant_client, search_get_request_model=GETModel, search_post_request_model=POSTModel, items_get_request_model=items_get_request_model, @@ -228,164 +85,18 @@ async def lifespan(app: FastAPI): middlewares=[ Middleware(CompressionMiddleware), Middleware(ValidationMiddleware), + Middleware(TenantMiddleware), ], router=APIRouter(route_class=LoggerRouteHandler), ) app = api.app # Add tenant-specific routes -tenant_router = APIRouter(redirect_slashes=True) - -@tenant_router.get("/{tenant}/collections") -async def get_tenant_collections( - tenant: str = Path(..., description="Tenant identifier"), - request: FastAPIRequest = None, -): - """Get collections for a specific tenant.""" - logger.info(f"Getting collections for tenant: {tenant}") - collections = await api.client.all_collections(request, tenant=tenant) - - return collections - - -@tenant_router.get("/{tenant}/collections/{collection_id}") -async def get_tenant_collection( - tenant: str = Path(..., description="Tenant identifier"), - collection_id: str = Path(..., description="Collection identifier"), - request: FastAPIRequest = None, -): - """Get a specific collection for a tenant.""" - logger.info(f"Getting collection {collection_id} for tenant: {tenant}") - collection = await api.client.get_collection(collection_id, request, tenant=tenant) - - return collection - - -@tenant_router.get("/{tenant}/collections/{collection_id}/items") -async def get_tenant_collection_items( - request: FastAPIRequest, # It's good practice to have request as the first arg - tenant: str = Path(..., description="Tenant identifier"), - collection_id: str = Path(..., description="Collection identifier"), - limit: int = 10, # Add limit - token: Optional[str] = None, # Add token -): - """Get items from a collection for a specific tenant.""" - logger.info(f"Getting items from collection {collection_id} for tenant: {tenant}") - - # Pass the captured parameters to the client method - items = await api.client.item_collection( - collection_id=collection_id, - request=request, - tenant=tenant, - limit=limit, - token=token - ) - - return items - - -@tenant_router.get("/{tenant}/collections/{collection_id}/items/{item_id}") -async def get_tenant_item( - tenant: str = Path(..., description="Tenant identifier"), - collection_id: str = Path(..., description="Collection identifier"), - item_id: str = Path(..., description="Item identifier"), - request: FastAPIRequest = None, -): - """Get a specific item for a tenant.""" - logger.info(f"======> Getting item {item_id} from collection {collection_id} for tenant: {tenant} <=====") - return await api.client.get_item(item_id, collection_id, request, tenant=tenant) - - -@tenant_router.get("/{tenant}/search") -async def get_tenant_search( - tenant: str = Path(..., description="Tenant identifier"), - request: FastAPIRequest = None, -): - """Search items for a specific tenant using GET.""" - logger.info(f"GET search for tenant: {tenant}") - return await api.client.get_search(request, tenant=tenant) - - -@tenant_router.get("/{tenant}/search") -async def get_tenant_search( - request: FastAPIRequest, # Request should come first - tenant: str = Path(..., description="Tenant identifier"), - # Add ALL possible GET search parameters here that stac-fastapi uses - collections: Optional[str] = None, - ids: Optional[str] = None, - bbox: Optional[str] = None, - datetime: Optional[str] = None, - limit: int = 10, - query: Optional[str] = None, - token: Optional[str] = None, - filter_lang: Optional[str] = None, - filter: Optional[str] = None, - sortby: Optional[str] = None, - # **kwargs: Any # Avoid using this if possible, be explicit -): - """Search items for a specific tenant using GET.""" - logger.info(f"GET search for tenant: {tenant}") - - # The base `get_search` method in stac-fastapi unpacks the request itself. - # What's important is that our `TenantAwareVedaCrudClient.get_search` can access these params. - # The default behavior of stac-fastapi `get_search` is to parse these from the request query params. - # Our modification in the client (step 1) will handle the tenant injection. - - # We create a dictionary of the GET parameters to pass them explicitly - # to avoid ambiguity. - params = { - "collections": collections.split(",") if collections else None, - "ids": ids.split(",") if ids else None, - "bbox": [float(x) for x in bbox.split(",")] if bbox else None, - "datetime": datetime, - "limit": limit, - "query": json.loads(query) if query else None, - "token": token, - "filter-lang": filter_lang, - "filter": json.loads(filter) if filter else None, - "sortby": sortby, - } - # Filter out None values - clean_params = {k: v for k, v in params.items() if v is not None} - - search_result = await api.client.get_search(request, tenant=tenant, **clean_params) - - return search_result - - -@tenant_router.get("/{tenant}/") -async def get_tenant_landing_page( - tenant: str = Path(..., description="Tenant identifier"), - request: FastAPIRequest = None, -): - """Get landing page for a specific tenant.""" - logger.info(f"Getting landing page for tenant: {tenant}") - - # Get the base landing page by calling the method on the CLIENT, not the API object - # Corrected line: - base_landing = await api.client.landing_page(request=request) - - # The rest of your logic for modifying the links is correct - if isinstance(base_landing, ORJSONResponse): - # The client returns a response object, so we need to decode its content - body = base_landing.body - tenant_landing = json.loads(body) - - # Update title to include tenant - if 'title' in tenant_landing: - tenant_landing['title'] = f"{tenant.upper()} - {tenant_landing['title']}" - - # Return a new JSONResponse with the modified content - return ORJSONResponse(tenant_landing) - - # Fallback in case the response is not what we expect - return base_landing - -# Include the tenant router +logger.info("Creating tenant router...") +tenant_router = create_tenant_router(tenant_client) +logger.info(f"Registering tenant router with {len(tenant_router.routes)} routes") app.include_router(tenant_router, tags=["Tenant-specific endpoints"]) - -# Add tenant-only enforcement middleware (set to False if you want to keep original routes) -# app.add_middleware(TenantOnlyMiddleware, enforce_tenant_only=True) +logger.info("Tenant router registered successfully") # Set all CORS enabled origins if api_settings.cors_origins: @@ -449,7 +160,7 @@ async def tenant_viewer_page(request: Request, tenant: str): return templates.TemplateResponse( "stac-viewer.html", { - "request": request, + "request": request, "endpoint": str(request.url).replace("/index.html", f"/{tenant}"), "tenant": tenant }, diff --git a/stac_api/runtime/src/tenant_client.py b/stac_api/runtime/src/tenant_client.py new file mode 100644 index 00000000..1bec8176 --- /dev/null +++ b/stac_api/runtime/src/tenant_client.py @@ -0,0 +1,184 @@ + +from typing import Any, Dict, Optional, Union + +from fastapi import HTTPException, Request as FastAPIRequest +from stac_fastapi.types.stac import Item, ItemCollection, Collections, Collection, LandingPage +from starlette.requests import Request +from urllib.parse import urlparse + +from .core import VedaCrudClient +from .tenant_models import TenantValidationError + +class TenantAwareVedaCrudClient(VedaCrudClient): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_tenant_from_request(self, request: Request) -> Optional[str]: + if hasattr(request, 'path_params') and 'tenant' in request.path_params: + return request.path_params['tenant'] + return None + + async def get_tenant_collections( + self, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs + ) -> Dict[str, Any]: + collections = await super().all_collections(request, **kwargs) + + collections_dict = collections + + if tenant and isinstance(collections_dict, dict) and 'collections' in collections_dict: + filtered_collections = [ + col for col in collections_dict['collections'] + if col.get('properties', {}).get('tenant') == tenant + ] + collections_dict['collections'] = filtered_collections + if 'numberReturned' in collections_dict: + collections_dict['numberReturned'] = len(filtered_collections) + + return collections_dict + + async def get_collection( + self, + collection_id: str, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs + ) -> Collection: + collection = await super().get_collection(collection_id, request, **kwargs) + + if tenant and collection: + self.validate_tenant_access(collection, tenant, collection_id) + + return collection + + async def item_collection( + self, + collection_id: str, + request: FastAPIRequest, + tenant: Optional[str] = None, + limit: int = 10, + token: Optional[str] = None, + **kwargs, + ) -> ItemCollection: + if tenant: + collection = await super().get_collection(collection_id, request, **kwargs) + if not collection: + raise HTTPException( + status_code=404, + detail=f"Collection {collection_id} not found for tenant {tenant}" + ) + self.validate_tenant_access(collection, tenant, collection_id) + + return await super().item_collection( + collection_id=collection_id, + request=request, + limit=limit, + token=token, + **kwargs + ) + + async def get_item( + self, + item_id: str, + collection_id: str, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs + ) -> Item: + if tenant: + collection = await super().get_collection(collection_id, request, **kwargs) + if not collection: + raise HTTPException( + status_code=404, + detail=f"Collection {collection_id} not found for tenant {tenant}" + ) + self.validate_tenant_access(collection, tenant, collection_id) + + return await super().get_item(item_id, collection_id, request, **kwargs) + + async def post_search( + self, + search_request, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs + ) -> ItemCollection: + result = await super().post_search(search_request, request, **kwargs) + + if tenant: + result = self._filter_search_results_by_tenant(result, tenant) + + return result + + async def get_search( + self, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs, + ) -> ItemCollection: + result = await super().get_search(request, **kwargs) + + if tenant: + result = self._filter_search_results_by_tenant(result, tenant) + + return result + + def _filter_search_results_by_tenant( + self, + result: ItemCollection, + tenant: str + ) -> ItemCollection: + if isinstance(result, dict) and 'features' in result: + filtered_features = [ + feature for feature in result['features'] + if feature.get('properties', {}).get('tenant') == tenant + ] + result['features'] = filtered_features + if 'numberReturned' in result: + result['numberReturned'] = len(filtered_features) + + return result + + async def landing_page( + self, + request: FastAPIRequest, + tenant: Optional[str] = None, + **kwargs + ) -> LandingPage: + landing_page = await super().landing_page(request=request, **kwargs) + + if tenant: + landing_page = self._customize_landing_page_for_tenant(landing_page, tenant) + + return landing_page + + def _customize_landing_page_for_tenant(self, landing_page: LandingPage, tenant: str) -> LandingPage: + if 'title' in landing_page: + landing_page['title'] = f"{tenant.upper()} - {landing_page['title']}" + + if 'links' in landing_page: + for link in landing_page['links']: + if 'href' in link: + href = link['href'] + + skip_links = ['self', 'root', 'service-desc', 'service-doc'] + if link.get('link') in skip_links: + continue + + if href.startswith('http'): + parsed = urlparse(href) + path_parts = parsed.path.split('/') + # a URL should follow this structure scheme://netloc/path;parameters?query#fragment generally + # source: https://docs.python.org/3/library/urllib.parse.html + if len(path_parts) > 2 and path_parts[1] == 'api' and path_parts[2] == 'stac': + new_path = '/'.join(path_parts[:3]) + link['href'] = f"{parsed.scheme}://{parsed.netloc}{new_path}" + else: + link['href'] = f"/{tenant}{href}" + + if 'href' in link and not link['href'].startswith('http'): + link['href'] = f"/{tenant}{link['href']}" + + return landing_page diff --git a/stac_api/runtime/src/tenant_routes.py b/stac_api/runtime/src/tenant_routes.py new file mode 100644 index 00000000..a13d8f45 --- /dev/null +++ b/stac_api/runtime/src/tenant_routes.py @@ -0,0 +1,258 @@ +import json +import logging +from typing import Any, Dict, Optional + +from fastapi import APIRouter, Request, Path, Query, HTTPException +from fastapi.responses import ORJSONResponse +from stac_fastapi.types.stac import Item, ItemCollection + +from .tenant_client import TenantAwareVedaCrudClient +from .tenant_models import TenantSearchRequest + + +logger = logging.getLogger(__name__) + +class TenantRouteHandler: + def __init__(self, client: TenantAwareVedaCrudClient): + self.client = client + + async def get_tenant_collections( + self, + request: Request, + tenant: str = Path(..., description="Tenant identifier"), + ) -> Dict[str, Any]: + logger.info(f"Getting collections for tenant: {tenant}") + + try: + collections = await self.client.get_tenant_collections(request, tenant=tenant) + return collections + except Exception as e: + logger.error(f"Error getting collections for tenant {tenant}: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error") + + async def get_tenant_collection( + self, + request: Request, + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + ) -> Dict: + logger.info(f"Getting collection {collection_id} for tenant: {tenant}") + + try: + collection = await self.client.get_collection( + collection_id, request, tenant=tenant + ) + return collection + except HTTPException: + raise + except Exception as e: + logger.error( + f"Error getting collection {collection_id} for tenant {tenant}: {str(e)}" + ) + raise HTTPException(status_code=500, detail="Internal server error") + + async def get_tenant_collection_items( + self, + request: Request, + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + limit: int = Query(10, description="Maximum number of items to return"), + token: Optional[str] = Query(None, description="Pagination token"), + ) -> ItemCollection: + logger.info(f"Getting items from collection {collection_id} for tenant: {tenant}") + + try: + items = await self.client.item_collection( + collection_id=collection_id, + request=request, + tenant=tenant, + limit=limit, + token=token + ) + return items + except HTTPException: + raise + except Exception as e: + logger.error( + f"Error getting items from collection {collection_id} for tenant {tenant}: {str(e)}" + ) + raise HTTPException(status_code=500, detail="Internal server error") + + async def get_tenant_item( + self, + request: Request, + tenant: str = Path(..., description="Tenant identifier"), + collection_id: str = Path(..., description="Collection identifier"), + item_id: str = Path(..., description="Item identifier"), + ) -> Item: + logger.info(f"Getting item {item_id} from collection {collection_id} for tenant: {tenant}") + + try: + item = await self.client.get_item( + item_id, collection_id, request, tenant=tenant + ) + return item + except HTTPException: + raise + except Exception as e: + logger.error( + f"Error getting item {item_id} from collection {collection_id} for tenant {tenant}: {str(e)}" + ) + raise HTTPException(status_code=500, detail="Internal server error") + + async def get_tenant_search( + self, + request: Request, + tenant: str = Path(..., description="Tenant identifier"), + collections: Optional[str] = Query(None, description="Comma-separated list of collection IDs"), + ids: Optional[str] = Query(None, description="Comma-separated list of item IDs"), + bbox: Optional[str] = Query(None, description="Bounding box"), + datetime: Optional[str] = Query(None, description="Datetime range"), + limit: int = Query(10, description="Maximum number of results"), + query: Optional[str] = Query(None, description="Query parameters"), + token: Optional[str] = Query(None, description="Pagination token"), + filter_lang: Optional[str] = Query("cql2-text", description="Filter language"), + filter: Optional[str] = Query(None, description="CQL2 filter"), + sortby: Optional[str] = Query(None, description="Sort parameters"), + ) -> ItemCollection: + """Search items for a specific tenant using GET.""" + logger.info(f"GET search for tenant: {tenant}") + + try: + search_params = { + "collections": collections.split(",") if collections else None, + "ids": ids.split(",") if ids else None, + "bbox": [float(x) for x in bbox.split(",")] if bbox else None, + "datetime": datetime, + "limit": limit, + "query": json.loads(query) if query else None, + "token": token, + "filter-lang": filter_lang, + "filter": json.loads(filter) if filter else None, + "sortby": sortby, + } + + clean_params = {k: v for k, v in search_params.items() if v is not None} + + search_result = await self.client.get_search( + request, tenant=tenant, **clean_params + ) + + return search_result + except json.JSONDecodeError as e: + logger.error(f"Invalid JSON in search parameters: {str(e)}") + raise HTTPException(status_code=400, detail="Invalid JSON in search parameters") + except HTTPException: + raise + except Exception as e: + logger.error(f"Error performing search for tenant {tenant}: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error") + + async def post_tenant_search( + self, + search_request: TenantSearchRequest, + request: Request, + tenant: str = Path(..., description="Tenant identifier"), + ) -> ItemCollection: + """Search items for a specific tenant using POST.""" + logger.info(f"POST search for tenant: {tenant}") + + try: + search_request.add_tenant_filter(tenant) + + search_result = await self.client.post_search( + search_request, request, tenant=tenant + ) + + return search_result + except HTTPException: + raise + except Exception as e: + logger.error(f"Error performing POST search for tenant {tenant}: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error") + + async def get_tenant_landing_page( + self, + request: Request, + tenant: str = Path(..., description="Tenant identifier"), + ) -> ORJSONResponse: + """Get landing page for a specific tenant.""" + logger.info(f"Getting landing page for tenant: {tenant}") + + try: + landing_page = await self.client.landing_page(request, tenant=tenant) + return ORJSONResponse(content=landing_page) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting landing page for tenant {tenant}: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error") + + +def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: + """Create tenant-specific router """ + + router = APIRouter(redirect_slashes=True) + handler = TenantRouteHandler(client) + + logger.info("Creating tenant router with routes") + + router.add_api_route( + "/{tenant}/collections", + handler.get_tenant_collections, + methods=["GET"], + summary="Get collections for tenant", + description="Retrieve all collections for a specific tenant" + ) + + router.add_api_route( + "/{tenant}/collections/{collection_id}", + handler.get_tenant_collection, + methods=["GET"], + summary="Get tenant collection", + description="Retrieve a specific collection for a tenant" + ) + + router.add_api_route( + "/{tenant}/collections/{collection_id}/items", + handler.get_tenant_collection_items, + methods=["GET"], + summary="Get tenant collection items", + description="Retrieve items from a collection for a tenant" + ) + + router.add_api_route( + "/{tenant}/collections/{collection_id}/items/{item_id}", + handler.get_tenant_item, + methods=["GET"], + summary="Get tenant item", + description="Retrieve a specific item for a tenant" + ) + + # Search endpoints + router.add_api_route( + "/{tenant}/search", + handler.get_tenant_search, + methods=["GET"], + summary="Search tenant items (GET)", + description="Search items for a tenant using GET method" + ) + + router.add_api_route( + "/{tenant}/search", + handler.post_tenant_search, + methods=["POST"], + summary="Search tenant items (POST)", + description="Search items for a tenant using POST method" + ) + + router.add_api_route( + "/{tenant}/", + handler.get_tenant_landing_page, + methods=["GET"], + summary="Get tenant landing page", + description="Retrieve landing page for a specific tenant" + ) + + logger.info(f"Created tenant router with {len(router.routes)} routes") + return router From 0974e7497f263f931e6e5a94a00ef8f1ece6034a Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 16 Sep 2025 09:20:25 -0700 Subject: [PATCH 10/33] fix: update post model which broke from merge changes --- stac_api/runtime/src/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stac_api/runtime/src/app.py b/stac_api/runtime/src/app.py index 436f3f6a..98ae3efd 100644 --- a/stac_api/runtime/src/app.py +++ b/stac_api/runtime/src/app.py @@ -57,7 +57,7 @@ async def lifespan(app: FastAPI): yield await close_db_connection(app) -tenant_client = TenantAwareVedaCrudClient(pgstac_search_model=POSTModel) +tenant_client = TenantAwareVedaCrudClient(pgstac_search_model=post_request_model) api = StacApi( app=FastAPI( From 7dd07473dd833a351eb73a2c3ad3308f371d9f60 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 16 Sep 2025 09:20:42 -0700 Subject: [PATCH 11/33] fix: landing page links generation --- stac_api/runtime/src/tenant_client.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/stac_api/runtime/src/tenant_client.py b/stac_api/runtime/src/tenant_client.py index 1bec8176..bdbbc81c 100644 --- a/stac_api/runtime/src/tenant_client.py +++ b/stac_api/runtime/src/tenant_client.py @@ -1,4 +1,4 @@ - +import logging from typing import Any, Dict, Optional, Union from fastapi import HTTPException, Request as FastAPIRequest @@ -9,6 +9,8 @@ from .core import VedaCrudClient from .tenant_models import TenantValidationError +logger = logging.getLogger(__name__) + class TenantAwareVedaCrudClient(VedaCrudClient): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -160,6 +162,7 @@ def _customize_landing_page_for_tenant(self, landing_page: LandingPage, tenant: if 'links' in landing_page: for link in landing_page['links']: + logger.info("Inspecting links to inject tenant...") if 'href' in link: href = link['href'] @@ -172,13 +175,13 @@ def _customize_landing_page_for_tenant(self, landing_page: LandingPage, tenant: path_parts = parsed.path.split('/') # a URL should follow this structure scheme://netloc/path;parameters?query#fragment generally # source: https://docs.python.org/3/library/urllib.parse.html - if len(path_parts) > 2 and path_parts[1] == 'api' and path_parts[2] == 'stac': - new_path = '/'.join(path_parts[:3]) + if len(path_parts) >= 3 and path_parts[1] == 'api' and path_parts[2] == 'stac': + new_path_parts = path_parts[:3] + [tenant] + path_parts[3:] + new_path = '/'.join(new_path_parts) link['href'] = f"{parsed.scheme}://{parsed.netloc}{new_path}" else: - link['href'] = f"/{tenant}{href}" + if href.startswith('/api/stac'): + link['href'] = href.replace('/api/stac', f'/api/stac/{tenant}') - if 'href' in link and not link['href'].startswith('http'): - link['href'] = f"/{tenant}{link['href']}" return landing_page From 43f4e2616cd9c7f0d271f8a8eaa0210e4b0f832b Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 16 Sep 2025 12:50:14 -0700 Subject: [PATCH 12/33] fix: docstrings, formatting, and unit tests --- stac_api/runtime/src/app.py | 13 +- stac_api/runtime/src/tenant_client.py | 224 +++++++++++++++++----- stac_api/runtime/src/tenant_middleware.py | 37 ++-- stac_api/runtime/src/tenant_models.py | 27 ++- stac_api/runtime/src/tenant_routes.py | 62 ++++-- stac_api/runtime/tests/conftest.py | 68 ++++++- stac_api/runtime/tests/test_extensions.py | 22 ++- 7 files changed, 339 insertions(+), 114 deletions(-) diff --git a/stac_api/runtime/src/app.py b/stac_api/runtime/src/app.py index 98ae3efd..708908ce 100644 --- a/stac_api/runtime/src/app.py +++ b/stac_api/runtime/src/app.py @@ -2,9 +2,7 @@ Based on https://github.com/developmentseed/eoAPI/tree/master/src/eoapi/stac """ -import json from contextlib import asynccontextmanager -from typing import Dict, Any, Optional from aws_lambda_powertools.metrics import MetricUnit from src.config import ( @@ -18,7 +16,7 @@ ) from src.extension import TiTilerExtension -from fastapi import APIRouter, FastAPI, Request as FastAPIRequest, Depends, Path, HTTPException +from fastapi import APIRouter, FastAPI from fastapi.responses import ORJSONResponse from stac_fastapi.api.app import StacApi from stac_fastapi.pgstac.db import close_db_connection, connect_to_db @@ -29,13 +27,12 @@ from starlette.templating import Jinja2Templates from starlette_cramjam.middleware import CompressionMiddleware -from .core import VedaCrudClient from .monitoring import LoggerRouteHandler, logger, metrics, tracer -from .validation import ValidationMiddleware from .tenant_client import TenantAwareVedaCrudClient from .tenant_middleware import TenantMiddleware from .tenant_routes import create_tenant_router -import os +from .validation import ValidationMiddleware + from eoapi.auth_utils import OpenIdConnectAuth, OpenIdConnectSettings try: @@ -50,6 +47,7 @@ tiles_settings = TilesApiSettings() auth_settings = OpenIdConnectSettings(_env_prefix="VEDA_STAC_") + @asynccontextmanager async def lifespan(app: FastAPI): """Get a database connection on startup, close it on shutdown.""" @@ -57,6 +55,7 @@ async def lifespan(app: FastAPI): yield await close_db_connection(app) + tenant_client = TenantAwareVedaCrudClient(pgstac_search_model=post_request_model) api = StacApi( @@ -167,7 +166,7 @@ async def tenant_viewer_page(request: Request, tenant: str): { "request": request, "endpoint": str(request.url).replace("/index.html", f"/{tenant}"), - "tenant": tenant + "tenant": tenant, }, media_type="text/html", ) diff --git a/stac_api/runtime/src/tenant_client.py b/stac_api/runtime/src/tenant_client.py index bdbbc81c..317e5ddb 100644 --- a/stac_api/runtime/src/tenant_client.py +++ b/stac_api/runtime/src/tenant_client.py @@ -1,43 +1,107 @@ +""" +Tenant Client for Tenant Middleware +""" import logging from typing import Any, Dict, Optional, Union +from urllib.parse import urlparse -from fastapi import HTTPException, Request as FastAPIRequest -from stac_fastapi.types.stac import Item, ItemCollection, Collections, Collection, LandingPage +from fastapi import HTTPException +from fastapi import Request as FastAPIRequest +from stac_fastapi.types.stac import Collection, Item, ItemCollection, LandingPage from starlette.requests import Request -from urllib.parse import urlparse from .core import VedaCrudClient from .tenant_models import TenantValidationError logger = logging.getLogger(__name__) -class TenantAwareVedaCrudClient(VedaCrudClient): + +class TenantValidationMixin: + """Tenant Validation Mixin""" + + def validate_tenant_access( + self, + resource: Union[Dict[str, Any], Collection], + tenant: str, + resource_id: str = "", + ) -> None: + """Validate that a collection resource belongs to a tenant""" + resource_tenant = self._extract_tenant_from_resource(resource) + + if resource_tenant != tenant: + raise TenantValidationError( + resource_type="Collection" if "collection" in resource else "Item", + resource_id=resource_id, + tenant=tenant, + actual_tenant=resource_tenant, + ) + + def _extract_tenant_from_resource( + self, resource: Union[Dict[str, Any], Collection] + ) -> Optional[str]: + return resource.get("properties", {}).get("tenant") + + +class TenantAwareVedaCrudClient(VedaCrudClient, TenantValidationMixin): + """Tenant Aware VEDA Crud Client""" + def __init__(self, *args, **kwargs): + """Initializes tenant-aware VEDA CRUD client by extending + the base VEDA CRUD client with tenant functionality such as filtering, + validation, and customized landing page links. + + Args: + *args: positional args passed to parent VedaCrudClient + **kwargs: keyword args passed to parent VedaCrudClient such as + pgstac_search_model + + """ super().__init__(*args, **kwargs) def get_tenant_from_request(self, request: Request) -> Optional[str]: - if hasattr(request, 'path_params') and 'tenant' in request.path_params: - return request.path_params['tenant'] + """Gets tenant string from request + + Args: + request: Incoming request + + Returns: + tenant, if there is one. None otherwise. + + """ + if hasattr(request, "path_params") and "tenant" in request.path_params: + return request.path_params["tenant"] return None async def get_tenant_collections( - self, - request: FastAPIRequest, - tenant: Optional[str] = None, - **kwargs + self, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs ) -> Dict[str, Any]: + """Gets collections belonging to a tenant + + Args: + request: Incoming request + tenant: Tenant ID + + Returns: + Collections belonging to tenant + + """ collections = await super().all_collections(request, **kwargs) collections_dict = collections - if tenant and isinstance(collections_dict, dict) and 'collections' in collections_dict: + if ( + tenant + and isinstance(collections_dict, dict) + and "collections" in collections_dict + ): filtered_collections = [ - col for col in collections_dict['collections'] - if col.get('properties', {}).get('tenant') == tenant + col + for col in collections_dict["collections"] + if col.get("properties", {}).get("tenant") == tenant ] - collections_dict['collections'] = filtered_collections - if 'numberReturned' in collections_dict: - collections_dict['numberReturned'] = len(filtered_collections) + collections_dict["collections"] = filtered_collections + if "numberReturned" in collections_dict: + collections_dict["numberReturned"] = len(filtered_collections) return collections_dict @@ -46,8 +110,10 @@ async def get_collection( collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, - **kwargs + **kwargs, ) -> Collection: + """Get a specific collection belonging to a tenant by collection, tenant IDs""" + collection = await super().get_collection(collection_id, request, **kwargs) if tenant and collection: @@ -64,12 +130,13 @@ async def item_collection( token: Optional[str] = None, **kwargs, ) -> ItemCollection: + """Get all items from collection using collection ID and tenant ID""" if tenant: collection = await super().get_collection(collection_id, request, **kwargs) if not collection: raise HTTPException( status_code=404, - detail=f"Collection {collection_id} not found for tenant {tenant}" + detail=f"Collection {collection_id} not found for tenant {tenant}", ) self.validate_tenant_access(collection, tenant, collection_id) @@ -78,7 +145,7 @@ async def item_collection( request=request, limit=limit, token=token, - **kwargs + **kwargs, ) async def get_item( @@ -87,14 +154,15 @@ async def get_item( collection_id: str, request: FastAPIRequest, tenant: Optional[str] = None, - **kwargs + **kwargs, ) -> Item: + """Get specific item from collection using collection ID and tenant ID""" if tenant: collection = await super().get_collection(collection_id, request, **kwargs) if not collection: raise HTTPException( status_code=404, - detail=f"Collection {collection_id} not found for tenant {tenant}" + detail=f"Collection {collection_id} not found for tenant {tenant}", ) self.validate_tenant_access(collection, tenant, collection_id) @@ -105,8 +173,20 @@ async def post_search( search_request, request: FastAPIRequest, tenant: Optional[str] = None, - **kwargs + **kwargs, ) -> ItemCollection: + """POST Search request with tenant filtering + + Args: + search_request: the search request parameters + request: the FastAPI request object + tenant: optional tenant identifier for filtering search + **kwargs: additional arguments to pass to the parent method + + Returns: + ItemCollection of the filtered search results + + """ result = await super().post_search(search_request, request, **kwargs) if tenant: @@ -120,6 +200,18 @@ async def get_search( tenant: Optional[str] = None, **kwargs, ) -> ItemCollection: + """GET Search request with tenant filtering + + Args: + search_request: the search request parameters + request: the FastAPI request object + tenant: optional tenant identifier for filtering search + **kwargs: additional arguments to pass to the parent method + + Returns: + ItemCollection of the filtered search results + + """ result = await super().get_search(request, **kwargs) if tenant: @@ -128,27 +220,42 @@ async def get_search( return result def _filter_search_results_by_tenant( - self, - result: ItemCollection, - tenant: str + self, result: ItemCollection, tenant: str ) -> ItemCollection: - if isinstance(result, dict) and 'features' in result: + """Internal function to filter search results by tenant + + Args: + result: ItemCollection to filter + tenant: Tenant identifier to filter on + + Returns: + Filtered ItemCollection + """ + if isinstance(result, dict) and "features" in result: filtered_features = [ - feature for feature in result['features'] - if feature.get('properties', {}).get('tenant') == tenant + feature + for feature in result["features"] + if feature.get("properties", {}).get("tenant") == tenant ] - result['features'] = filtered_features - if 'numberReturned' in result: - result['numberReturned'] = len(filtered_features) + result["features"] = filtered_features + if "numberReturned" in result: + result["numberReturned"] = len(filtered_features) return result async def landing_page( - self, - request: FastAPIRequest, - tenant: Optional[str] = None, - **kwargs + self, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs ) -> LandingPage: + """Get or generate landing page if a tenant is provided + + Args: + request: Fast API request object + tenant: Optional tenant identifier + **kwargs: Optional key word args to pass to parent method + + Returns: + Landing Page, customized if tenant provided + """ landing_page = await super().landing_page(request=request, **kwargs) if tenant: @@ -156,32 +263,45 @@ async def landing_page( return landing_page - def _customize_landing_page_for_tenant(self, landing_page: LandingPage, tenant: str) -> LandingPage: - if 'title' in landing_page: - landing_page['title'] = f"{tenant.upper()} - {landing_page['title']}" + def _customize_landing_page_for_tenant( + self, landing_page: LandingPage, tenant: str + ) -> LandingPage: + """ + Customized landing page with tenant route path injected into url + """ + + if "title" in landing_page: + landing_page["title"] = f"{tenant.upper()} - {landing_page['title']}" - if 'links' in landing_page: - for link in landing_page['links']: + if "links" in landing_page: + for link in landing_page["links"]: logger.info("Inspecting links to inject tenant...") - if 'href' in link: - href = link['href'] + if "href" in link: + href = link["href"] - skip_links = ['self', 'root', 'service-desc', 'service-doc'] - if link.get('link') in skip_links: + skip_links = ["self", "root", "service-desc", "service-doc"] + if link.get("link") in skip_links: continue - if href.startswith('http'): + if href.startswith("http"): parsed = urlparse(href) - path_parts = parsed.path.split('/') + path_parts = parsed.path.split("/") # a URL should follow this structure scheme://netloc/path;parameters?query#fragment generally # source: https://docs.python.org/3/library/urllib.parse.html - if len(path_parts) >= 3 and path_parts[1] == 'api' and path_parts[2] == 'stac': + if ( + len(path_parts) >= 3 + and path_parts[1] == "api" + and path_parts[2] == "stac" + ): new_path_parts = path_parts[:3] + [tenant] + path_parts[3:] - new_path = '/'.join(new_path_parts) - link['href'] = f"{parsed.scheme}://{parsed.netloc}{new_path}" + new_path = "/".join(new_path_parts) + link[ + "href" + ] = f"{parsed.scheme}://{parsed.netloc}{new_path}" else: - if href.startswith('/api/stac'): - link['href'] = href.replace('/api/stac', f'/api/stac/{tenant}') - + if href.startswith("/api/stac"): + link["href"] = href.replace( + "/api/stac", f"/api/stac/{tenant}" + ) return landing_page diff --git a/stac_api/runtime/src/tenant_middleware.py b/stac_api/runtime/src/tenant_middleware.py index 8b51f27f..c31dc73b 100644 --- a/stac_api/runtime/src/tenant_middleware.py +++ b/stac_api/runtime/src/tenant_middleware.py @@ -1,12 +1,15 @@ +""" Tenant Middleware for STAC API. Useful for extracting tenant information """ import logging from typing import Optional -from fastapi import Request, HTTPException + +from fastapi import HTTPException, Request from starlette.middleware.base import BaseHTTPMiddleware from .tenant_models import TenantContext, TenantValidationError logger = logging.getLogger(__name__) + class TenantMiddleware(BaseHTTPMiddleware): def __init__(self, app): super().__init__(app) @@ -15,10 +18,14 @@ async def dispatch(self, request: Request, call_next): try: tenant = self._extract_tenant(request) - tenant_context = TenantContext( - tenant_id=tenant, - request_id=request.headers.get("X-Correlation-ID"), - ) if tenant else None + tenant_context = ( + TenantContext( + tenant_id=tenant, + request_id=request.headers.get("X-Correlation-ID"), + ) + if tenant + else None + ) request.state.tenant_context = tenant_context @@ -28,8 +35,8 @@ async def dispatch(self, request: Request, call_next): extra={ "tenant": tenant, "method": request.method, - "path": request.url.path if tenant_context else None - } + "path": request.url.path if tenant_context else None, + }, ) response = await call_next(request) @@ -41,14 +48,14 @@ async def dispatch(self, request: Request, call_next): return response - except TenantValidtionError as e: + except TenantValidationError as e: logger.warning( f"Tenant validation failed: {e.detail}", extra={ - "tenant": getattr(e, 'tenant', None), - "resource_type": getattr(e, 'resource_type', None), - "resource_id": getattr(e, 'resource_id', None) - } + "tenant": getattr(e, "tenant", None), + "resource_type": getattr(e, "resource_type", None), + "resource_id": getattr(e, "resource_id", None), + }, ) raise HTTPException(status_code=404, detail=e.detail) @@ -58,8 +65,8 @@ async def dispatch(self, request: Request, call_next): def _extract_tenant(self, request: Request) -> Optional[str]: path = request.url.path - if path.startswith('/api/stac/'): - path_parts = path.replace('/api/stac/', '').split('/') + if path.startswith("/api/stac/"): + path_parts = path.replace("/api/stac/", "").split("/") if path_parts and path_parts[0]: return path_parts[0] - return None \ No newline at end of file + return None diff --git a/stac_api/runtime/src/tenant_models.py b/stac_api/runtime/src/tenant_models.py index e65ad353..2ad96301 100644 --- a/stac_api/runtime/src/tenant_models.py +++ b/stac_api/runtime/src/tenant_models.py @@ -1,12 +1,15 @@ from typing import Any, Dict, List, Optional + from pydantic import BaseModel, Field, field_validator + from fastapi import HTTPException + class TenantContext(BaseModel): tenant_id: str = Field(..., description="Tenant identifier") request_id: Optional[str] = Field(None, description="Request correlation ID") - @field_validator('tenant_id') + @field_validator("tenant_id") @classmethod def validate_tenant_id(cls, v): if not v or not v.strip(): @@ -20,7 +23,9 @@ class TenantSearchRequest(BaseModel): """Tenant-aware search request model.""" tenant: Optional[str] = Field(None, description="Tenant identifier") - collections: Optional[List[str]] = Field(None, description="Collection IDs to search") + collections: Optional[List[str]] = Field( + None, description="Collection IDs to search" + ) bbox: Optional[List[float]] = Field(None, description="Bounding box") datetime: Optional[str] = Field(None, description="Datetime range") limit: int = Field(10, description="Maximum number of results") @@ -35,30 +40,22 @@ def add_tenant_filter(self, tenant: str) -> None: return # Create tenant filter for properties.tenant - tenant_filter = { - "op": "=", - "args": [ - {"property": "tenant"}, - tenant - ] - } + tenant_filter = {"op": "=", "args": [{"property": "tenant"}, tenant]} # If there's already a filter, combine using AND if self.filter: - self.filter = { - "op": "and", - "args": [self.filter, tenant_filter] - } + self.filter = {"op": "and", "args": [self.filter, tenant_filter]} else: self.filter = tenant_filter + class TenantValidationError(HTTPException): def __init__( self, resource_type: str, resource_id: str, tenant: str, - actual_tenant: Optional[str] = None + actual_tenant: Optional[str] = None, ): self.resource_type = resource_type self.resource_id = resource_id @@ -69,4 +66,4 @@ def __init__( if actual_tenant: detail += f" (found tenant: {actual_tenant})" - super().__init__(status_code=404, detail=detail) \ No newline at end of file + super().__init__(status_code=404, detail=detail) diff --git a/stac_api/runtime/src/tenant_routes.py b/stac_api/runtime/src/tenant_routes.py index a13d8f45..3b519799 100644 --- a/stac_api/runtime/src/tenant_routes.py +++ b/stac_api/runtime/src/tenant_routes.py @@ -1,19 +1,23 @@ +""" Tenant Route Handler """ import json import logging from typing import Any, Dict, Optional -from fastapi import APIRouter, Request, Path, Query, HTTPException +from fastapi import APIRouter, HTTPException, Path, Query, Request from fastapi.responses import ORJSONResponse from stac_fastapi.types.stac import Item, ItemCollection from .tenant_client import TenantAwareVedaCrudClient from .tenant_models import TenantSearchRequest - logger = logging.getLogger(__name__) + class TenantRouteHandler: + """Route handler for tenant-aware STAC API endpoints""" + def __init__(self, client: TenantAwareVedaCrudClient): + """Initializes tenant-aware route handler""" self.client = client async def get_tenant_collections( @@ -21,10 +25,14 @@ async def get_tenant_collections( request: Request, tenant: str = Path(..., description="Tenant identifier"), ) -> Dict[str, Any]: + """Get all collections belonging to a tenant""" + logger.info(f"Getting collections for tenant: {tenant}") try: - collections = await self.client.get_tenant_collections(request, tenant=tenant) + collections = await self.client.get_tenant_collections( + request, tenant=tenant + ) return collections except Exception as e: logger.error(f"Error getting collections for tenant {tenant}: {str(e)}") @@ -36,6 +44,7 @@ async def get_tenant_collection( tenant: str = Path(..., description="Tenant identifier"), collection_id: str = Path(..., description="Collection identifier"), ) -> Dict: + """Get a specific collection belonging to a specific tenant""" logger.info(f"Getting collection {collection_id} for tenant: {tenant}") try: @@ -59,7 +68,10 @@ async def get_tenant_collection_items( limit: int = Query(10, description="Maximum number of items to return"), token: Optional[str] = Query(None, description="Pagination token"), ) -> ItemCollection: - logger.info(f"Getting items from collection {collection_id} for tenant: {tenant}") + """Get all items from a collection filtered by a specific tenant""" + logger.info( + f"Getting items from collection {collection_id} for tenant: {tenant}" + ) try: items = await self.client.item_collection( @@ -67,7 +79,7 @@ async def get_tenant_collection_items( request=request, tenant=tenant, limit=limit, - token=token + token=token, ) return items except HTTPException: @@ -85,7 +97,10 @@ async def get_tenant_item( collection_id: str = Path(..., description="Collection identifier"), item_id: str = Path(..., description="Item identifier"), ) -> Item: - logger.info(f"Getting item {item_id} from collection {collection_id} for tenant: {tenant}") + """Get a specific item for a tenant""" + logger.info( + f"Getting item {item_id} from collection {collection_id} for tenant: {tenant}" + ) try: item = await self.client.get_item( @@ -104,8 +119,12 @@ async def get_tenant_search( self, request: Request, tenant: str = Path(..., description="Tenant identifier"), - collections: Optional[str] = Query(None, description="Comma-separated list of collection IDs"), - ids: Optional[str] = Query(None, description="Comma-separated list of item IDs"), + collections: Optional[str] = Query( + None, description="Comma-separated list of collection IDs" + ), + ids: Optional[str] = Query( + None, description="Comma-separated list of item IDs" + ), bbox: Optional[str] = Query(None, description="Bounding box"), datetime: Optional[str] = Query(None, description="Datetime range"), limit: int = Query(10, description="Maximum number of results"), @@ -115,7 +134,7 @@ async def get_tenant_search( filter: Optional[str] = Query(None, description="CQL2 filter"), sortby: Optional[str] = Query(None, description="Sort parameters"), ) -> ItemCollection: - """Search items for a specific tenant using GET.""" + """Search items for a specific tenant using GET""" logger.info(f"GET search for tenant: {tenant}") try: @@ -141,7 +160,9 @@ async def get_tenant_search( return search_result except json.JSONDecodeError as e: logger.error(f"Invalid JSON in search parameters: {str(e)}") - raise HTTPException(status_code=400, detail="Invalid JSON in search parameters") + raise HTTPException( + status_code=400, detail="Invalid JSON in search parameters" + ) except HTTPException: raise except Exception as e: @@ -154,7 +175,7 @@ async def post_tenant_search( request: Request, tenant: str = Path(..., description="Tenant identifier"), ) -> ItemCollection: - """Search items for a specific tenant using POST.""" + """Search items for a specific tenant using POST""" logger.info(f"POST search for tenant: {tenant}") try: @@ -176,7 +197,8 @@ async def get_tenant_landing_page( request: Request, tenant: str = Path(..., description="Tenant identifier"), ) -> ORJSONResponse: - """Get landing page for a specific tenant.""" + """Get landing page for a specific tenant""" + logger.info(f"Getting landing page for tenant: {tenant}") try: @@ -190,7 +212,7 @@ async def get_tenant_landing_page( def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: - """Create tenant-specific router """ + """Create tenant-specific router""" router = APIRouter(redirect_slashes=True) handler = TenantRouteHandler(client) @@ -202,7 +224,7 @@ def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: handler.get_tenant_collections, methods=["GET"], summary="Get collections for tenant", - description="Retrieve all collections for a specific tenant" + description="Retrieve all collections for a specific tenant", ) router.add_api_route( @@ -210,7 +232,7 @@ def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: handler.get_tenant_collection, methods=["GET"], summary="Get tenant collection", - description="Retrieve a specific collection for a tenant" + description="Retrieve a specific collection for a tenant", ) router.add_api_route( @@ -218,7 +240,7 @@ def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: handler.get_tenant_collection_items, methods=["GET"], summary="Get tenant collection items", - description="Retrieve items from a collection for a tenant" + description="Retrieve items from a collection for a tenant", ) router.add_api_route( @@ -226,7 +248,7 @@ def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: handler.get_tenant_item, methods=["GET"], summary="Get tenant item", - description="Retrieve a specific item for a tenant" + description="Retrieve a specific item for a tenant", ) # Search endpoints @@ -235,7 +257,7 @@ def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: handler.get_tenant_search, methods=["GET"], summary="Search tenant items (GET)", - description="Search items for a tenant using GET method" + description="Search items for a tenant using GET method", ) router.add_api_route( @@ -243,7 +265,7 @@ def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: handler.post_tenant_search, methods=["POST"], summary="Search tenant items (POST)", - description="Search items for a tenant using POST method" + description="Search items for a tenant using POST method", ) router.add_api_route( @@ -251,7 +273,7 @@ def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: handler.get_tenant_landing_page, methods=["GET"], summary="Get tenant landing page", - description="Retrieve landing page for a specific tenant" + description="Retrieve landing page for a specific tenant", ) logger.info(f"Created tenant router with {len(router.routes)} routes") diff --git a/stac_api/runtime/tests/conftest.py b/stac_api/runtime/tests/conftest.py index ede8bad1..6aa2bc89 100644 --- a/stac_api/runtime/tests/conftest.py +++ b/stac_api/runtime/tests/conftest.py @@ -75,6 +75,51 @@ }, } +VALID_COLLECTION_WITH_TENANT = { + "id": "campfire-lst-day-diff", + "type": "Collection", + "links": [], + "title": "Camp Fire Domain: MODIS LST Day Difference", + "extent": { + "spatial": { + "bbox": [ + [ + -121.78460307847297, + 39.59483467430542, + -121.35341172149457, + 39.89994756059251, + ] + ] + }, + "temporal": { + "interval": [["2015-01-01T00:00:00+00:00", "2022-01-01T00:00:00+00:00"]] + }, + }, + "license": "CC0-1.0", + "providers": [ + { + "url": "https://www.earthdata.nasa.gov/dashboard/", + "name": "NASA VEDA", + "roles": ["host"], + } + ], + "summaries": {"datetime": ["2015-01-01T00:00:00Z"]}, + "properties": {"tenant": "fake-tenant"}, + "description": "MODIS WSA Albedo difference from a three-year average of 2015 to 2018 subtracted from a three-year average of 2019-2022. These tri-annual averages represent periods before and after the fire.", + "item_assets": { + "cog_default": { + "type": "image/tiff; application=geotiff; profile=cloud-optimized", + "roles": ["data", "layer"], + "title": "Default COG Layer", + "description": "Cloud optimized default layer to display on map", + } + }, + "stac_version": "1.0.0", + "stac_extensions": [ + "https://stac-extensions.github.io/item-assets/v1.0.0/schema.json" + ], +} + VALID_ITEM = { "id": "OMI_trno2_0.10x0.10_2023_Col3_V4", "bbox": [-180.0, -90.0, 180.0, 90.0], @@ -334,6 +379,17 @@ def valid_stac_collection(): return VALID_COLLECTION +@pytest.fixture +def valid_stac_collection_with_tenant(): + """ + Fixture providing a valid STAC collection with tenant for testing. + + Returns: + dict: A valid STAC collection with tenant. + """ + return VALID_COLLECTION_WITH_TENANT + + @pytest.fixture def invalid_stac_collection(): """ @@ -380,11 +436,17 @@ async def collection_in_db(api_client, valid_stac_collection): the collection ID. """ # Create the collection - response = await api_client.post("/collections", json=valid_stac_collection) + collection_response = await api_client.post( + "/collections", json=valid_stac_collection + ) + collection_with_tenant_response = await api_client.post( + "/collections", json=valid_stac_collection_with_tenant + ) # Ensure the setup was successful before the test proceeds # The setup is successful if the collection was created (201) or if it # already existed (409). Any other status code is a failure. - assert response.status_code in [201, 409] + assert collection_response.status_code in [201, 409] + assert collection_with_tenant_response.status_code in [201, 409] - yield valid_stac_collection["id"] + yield [valid_stac_collection["id"], valid_stac_collection_with_tenant["id"]] diff --git a/stac_api/runtime/tests/test_extensions.py b/stac_api/runtime/tests/test_extensions.py index 8ff18e81..181872b1 100644 --- a/stac_api/runtime/tests/test_extensions.py +++ b/stac_api/runtime/tests/test_extensions.py @@ -21,6 +21,7 @@ collections_endpoint = "/collections" items_endpoint = "/collections/{}/items" bulk_endpoint = "/collections/{}/bulk_items" +tenant_collections_endpoint = "/fake-tenant/collections" class TestList: @@ -120,7 +121,7 @@ async def test_get_collection_by_id(self, api_client, collection_in_db): Test searching for a specific collection by its ID. """ # The `collection_in_db` fixture ensures the collection exists and provides its ID. - collection_id = collection_in_db + collection_id = collection_in_db[0] # Perform a GET request to the /collections endpoint with an "ids" query response = await api_client.get( @@ -141,7 +142,7 @@ async def test_collection_freetext_search_by_title( """ # The `collection_in_db` fixture ensures the collection exists. - collection_id = collection_in_db + collection_id = collection_in_db[0] # Use a unique word from the collection's title for the query. search_term = "precipitation" @@ -156,3 +157,20 @@ async def test_collection_freetext_search_by_title( returned_ids = [col["id"] for col in response_data["collections"]] assert collection_id in returned_ids + + async def test_get_collections_by_tenant(self, api_client, collection_in_db): + """ + Test searching for a specific collection by its ID. + """ + collection_id = collection_in_db[1] + + # Perform a GET request to the /collections endpoint with a tenant + response = await api_client.get( + tenant_collections_endpoint, + ) + + assert response.status_code == 200 + + response_data = response.json() + + assert response_data["collections"][0]["id"] == collection_id From 7f1da50db89fdad537168c19813721980afa2eaf Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 16 Sep 2025 14:47:38 -0700 Subject: [PATCH 13/33] fix: add more docstrings --- stac_api/runtime/src/tenant_middleware.py | 16 ++++++++++++++++ stac_api/runtime/src/tenant_models.py | 11 +++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/stac_api/runtime/src/tenant_middleware.py b/stac_api/runtime/src/tenant_middleware.py index c31dc73b..42009de6 100644 --- a/stac_api/runtime/src/tenant_middleware.py +++ b/stac_api/runtime/src/tenant_middleware.py @@ -11,10 +11,25 @@ class TenantMiddleware(BaseHTTPMiddleware): + """Middleware for tenant-aware STAC API request processing. + + This middleware extracts the tenant identifier from the URL path and creates a context + for downstream processing. It also handles valiadtion errors. + + It will process requests by: + - extracting the tenant from the URL path (/api/stac/{tenant}/...) + - creating a TenantContext with the tenant ID and correlation ID + - handle validation errors + + """ + def __init__(self, app): + """Initializes the tenant middleware""" super().__init__(app) async def dispatch(self, request: Request, call_next): + """Processes incoming requests and extracts the tenant identifier from the URL""" + try: tenant = self._extract_tenant(request) @@ -64,6 +79,7 @@ async def dispatch(self, request: Request, call_next): raise HTTPException(status_code=500, detail="Internal server error") def _extract_tenant(self, request: Request) -> Optional[str]: + """Extracts the tenant identifier from the URL""" path = request.url.path if path.startswith("/api/stac/"): path_parts = path.replace("/api/stac/", "").split("/") diff --git a/stac_api/runtime/src/tenant_models.py b/stac_api/runtime/src/tenant_models.py index 2ad96301..43570455 100644 --- a/stac_api/runtime/src/tenant_models.py +++ b/stac_api/runtime/src/tenant_models.py @@ -1,3 +1,4 @@ +""" Tenant Models for STAC API """ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, field_validator @@ -6,12 +7,15 @@ class TenantContext(BaseModel): + """Context information for tenant-aware request processing""" + tenant_id: str = Field(..., description="Tenant identifier") request_id: Optional[str] = Field(None, description="Request correlation ID") @field_validator("tenant_id") @classmethod def validate_tenant_id(cls, v): + """Validates the tenant ID and also normalizes it to lowercase and trims the whitespace""" if not v or not v.strip(): raise ValueError("Tenant ID cannot be empty") if len(v) > 100: @@ -20,7 +24,7 @@ def validate_tenant_id(cls, v): class TenantSearchRequest(BaseModel): - """Tenant-aware search request model.""" + """Tenant-aware search request model""" tenant: Optional[str] = Field(None, description="Tenant identifier") collections: Optional[List[str]] = Field( @@ -35,7 +39,7 @@ class TenantSearchRequest(BaseModel): conf: Optional[Dict] = None def add_tenant_filter(self, tenant: str) -> None: - """Add tenant filter to the search request.""" + """Add tenant filter to the search request""" if not tenant: return @@ -50,6 +54,8 @@ def add_tenant_filter(self, tenant: str) -> None: class TenantValidationError(HTTPException): + """Exception that can be used to raise tenant validation failures""" + def __init__( self, resource_type: str, @@ -57,6 +63,7 @@ def __init__( tenant: str, actual_tenant: Optional[str] = None, ): + """Initiailizes tenant validation error""" self.resource_type = resource_type self.resource_id = resource_id self.tenant = tenant From b805636de97610bd8d09481b15a25d5a58efc1b1 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 16 Sep 2025 14:49:12 -0700 Subject: [PATCH 14/33] fix: add logging to landing page for debugging --- stac_api/runtime/src/tenant_client.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/stac_api/runtime/src/tenant_client.py b/stac_api/runtime/src/tenant_client.py index 317e5ddb..ff5428db 100644 --- a/stac_api/runtime/src/tenant_client.py +++ b/stac_api/runtime/src/tenant_client.py @@ -256,11 +256,30 @@ async def landing_page( Returns: Landing Page, customized if tenant provided """ + tenant_context = getattr(request.state, 'tenant_context'', None) + + logger.info( + f"Landing page requested for tenant: {tenant}", + extra={ + "tenant_id": tenant, + "request_id": tenant_context.request_id if tenant_context else None, + "endpoint": "landing_page", + } + ) landing_page = await super().landing_page(request=request, **kwargs) if tenant: landing_page = self._customize_landing_page_for_tenant(landing_page, tenant) + logger.info( + f"Landing page customized for tenant: {tenant}", + extra={ + "tenant_id": tenant, + "request_id": tenant_context.request_id if tenant_context else None, + "links_modified": len(landing_page.get("links", [])), + } + ) + return landing_page def _customize_landing_page_for_tenant( From a648cd9ede846b4be666b6cfe41e2ca7f7708c33 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 16 Sep 2025 14:53:04 -0700 Subject: [PATCH 15/33] fix: formatting and typos --- stac_api/runtime/src/tenant_client.py | 31 ++++++++++++++------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/stac_api/runtime/src/tenant_client.py b/stac_api/runtime/src/tenant_client.py index ff5428db..46fcd8c3 100644 --- a/stac_api/runtime/src/tenant_client.py +++ b/stac_api/runtime/src/tenant_client.py @@ -256,29 +256,30 @@ async def landing_page( Returns: Landing Page, customized if tenant provided """ - tenant_context = getattr(request.state, 'tenant_context'', None) + tenant_context = getattr(request.state, "tenant_context", None) logger.info( - f"Landing page requested for tenant: {tenant}", - extra={ - "tenant_id": tenant, - "request_id": tenant_context.request_id if tenant_context else None, - "endpoint": "landing_page", - } - ) + f"Landing page requested for tenant: {tenant}", + extra={ + "tenant_id": tenant, + "request_id": tenant_context.request_id if tenant_context else None, + "endpoint": "landing_page", + }, + ) + landing_page = await super().landing_page(request=request, **kwargs) if tenant: landing_page = self._customize_landing_page_for_tenant(landing_page, tenant) logger.info( - f"Landing page customized for tenant: {tenant}", - extra={ - "tenant_id": tenant, - "request_id": tenant_context.request_id if tenant_context else None, - "links_modified": len(landing_page.get("links", [])), - } - ) + f"Landing page customized for tenant: {tenant}", + extra={ + "tenant_id": tenant, + "request_id": tenant_context.request_id if tenant_context else None, + "links_modified": len(landing_page.get("links", [])), + }, + ) return landing_page From 6d63e54c81d330e2d337f56f80055efc8645e647 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 16 Sep 2025 17:35:02 -0700 Subject: [PATCH 16/33] fix: add logging to landing page fn --- stac_api/runtime/src/tenant_client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/stac_api/runtime/src/tenant_client.py b/stac_api/runtime/src/tenant_client.py index 46fcd8c3..406bad95 100644 --- a/stac_api/runtime/src/tenant_client.py +++ b/stac_api/runtime/src/tenant_client.py @@ -299,8 +299,9 @@ def _customize_landing_page_for_tenant( if "href" in link: href = link["href"] - skip_links = ["self", "root", "service-desc", "service-doc"] - if link.get("link") in skip_links: + skip_rels = ["self", "root", "service-desc", "service-doc", "conformance"] + if link.get("rels") in skip_links: + logger.info(f"Skipping link with rel {link.get('rel')}") continue if href.startswith("http"): From 83ae864add1afb13e9c3202ce843e6cb3c7e9792 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 16 Sep 2025 17:37:45 -0700 Subject: [PATCH 17/33] fix: update old var name --- stac_api/runtime/src/tenant_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stac_api/runtime/src/tenant_client.py b/stac_api/runtime/src/tenant_client.py index 406bad95..86b199b7 100644 --- a/stac_api/runtime/src/tenant_client.py +++ b/stac_api/runtime/src/tenant_client.py @@ -300,7 +300,7 @@ def _customize_landing_page_for_tenant( href = link["href"] skip_rels = ["self", "root", "service-desc", "service-doc", "conformance"] - if link.get("rels") in skip_links: + if link.get("rels") in skip_rels: logger.info(f"Skipping link with rel {link.get('rel')}") continue From 0d0d77fc61ab1fdd6a67e78d71be4deb3aac7b91 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 16 Sep 2025 17:47:35 -0700 Subject: [PATCH 18/33] feat: add more unit tests --- stac_api/runtime/src/tenant_client.py | 8 ++- stac_api/runtime/tests/test_extensions.py | 66 +++++++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/stac_api/runtime/src/tenant_client.py b/stac_api/runtime/src/tenant_client.py index 86b199b7..4718245e 100644 --- a/stac_api/runtime/src/tenant_client.py +++ b/stac_api/runtime/src/tenant_client.py @@ -299,7 +299,13 @@ def _customize_landing_page_for_tenant( if "href" in link: href = link["href"] - skip_rels = ["self", "root", "service-desc", "service-doc", "conformance"] + skip_rels = [ + "self", + "root", + "service-desc", + "service-doc", + "conformance", + ] if link.get("rels") in skip_rels: logger.info(f"Skipping link with rel {link.get('rel')}") continue diff --git a/stac_api/runtime/tests/test_extensions.py b/stac_api/runtime/tests/test_extensions.py index 181872b1..4b561472 100644 --- a/stac_api/runtime/tests/test_extensions.py +++ b/stac_api/runtime/tests/test_extensions.py @@ -174,3 +174,69 @@ async def test_get_collections_by_tenant(self, api_client, collection_in_db): response_data = response.json() assert response_data["collections"][0]["id"] == collection_id + + async def test_tenant_landing_page_customization(self, api_client): + """ + Test that tenant landing page is properly customized for tenant + """ + response = await api_client.get("/fake-tenant/") + assert response.status_code == 200 + + landing_page = response.json() + assert "FAKE-TENANT" in landing_page["title"] + + for link in landing_page.get("links", []): + if link.get("rel") not in [ + "self", + "root", + "service-desc", + "service-doc", + "conformance", + ]: + assert "/fake-tenant/" in link["href"] + + async def test_get_collection_by_id_with_tenant(self, api_client, collection_in_db): + """ + Test searching for a specific collection by its ID and tenant + """ + # The `collection_in_db` fixture ensures the collection exists and provides its ID. + collection_id = collection_in_db[1] + + # Perform a GET request to the /fake-tenant/collections endpoint with an "ids" query + response = await api_client.get( + tenant_collections_endpoint, params={"ids": collection_id} + ) + + assert response.status_code == 200 + + response_data = response.json() + + assert response_data["collections"][0]["id"] == collection_id + + async def test_tenant_validation_error(self, api_client, collection_in_db): + """ + Test that accessing wrong tenant's collection returns 404 + """ + collection_id = collection_in_db[1] + + # Try to access unexistent tenant for collection that exists in fake-tenant + response = await api_client.get(f"/fake-tenant-2/collections/{collection_id}") + assert response.status_code == 404 + + async def test_invalid_tenant_format(self, api_client): + """ + Test handling of invalid tenant formats + """ + + response = await api_client.get("/invalid-tenant-format/collections") + + assert response.status_code in [400, 404] + + async def test_missing_tenant_parameter(self, api_client): + """ + Test behavior when tenant parameter is not supplied in route path + """ + + response = await api_client.get("/collections") + # Should return all collections (no tenant filtering) + assert response.status_code == 200 From bcddd4fdd6a52a44070b9d8e4d2e7b03aa04010d Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 16 Sep 2025 18:45:51 -0700 Subject: [PATCH 19/33] fix: fix typo in tenant_client rel retrieval, update tests --- stac_api/runtime/src/tenant_client.py | 2 +- stac_api/runtime/tests/conftest.py | 7 ++-- stac_api/runtime/tests/test_extensions.py | 47 ++++++++++++++++++----- 3 files changed, 42 insertions(+), 14 deletions(-) diff --git a/stac_api/runtime/src/tenant_client.py b/stac_api/runtime/src/tenant_client.py index 4718245e..7bcf1b6a 100644 --- a/stac_api/runtime/src/tenant_client.py +++ b/stac_api/runtime/src/tenant_client.py @@ -306,7 +306,7 @@ def _customize_landing_page_for_tenant( "service-doc", "conformance", ] - if link.get("rels") in skip_rels: + if link.get("rel") in skip_rels: logger.info(f"Skipping link with rel {link.get('rel')}") continue diff --git a/stac_api/runtime/tests/conftest.py b/stac_api/runtime/tests/conftest.py index 6aa2bc89..04c78a55 100644 --- a/stac_api/runtime/tests/conftest.py +++ b/stac_api/runtime/tests/conftest.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch import pytest +import pytest_asyncio from httpx import ASGITransport, AsyncClient from stac_fastapi.pgstac.db import close_db_connection, connect_to_db @@ -318,7 +319,7 @@ def mock_auth(): yield mock_instance -@pytest.fixture +@pytest_asyncio.fixture async def app(): """ Fixture to initialize the FastAPI application. @@ -339,7 +340,7 @@ async def app(): await close_db_connection(app) -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def api_client(app): """ Fixture to initialize the API client for making requests. @@ -427,7 +428,7 @@ def invalid_stac_item(): return invalid_item -@pytest.fixture +@pytest_asyncio.fixture async def collection_in_db(api_client, valid_stac_collection): """ Fixture to ensure a valid STAC collection exists in the database. diff --git a/stac_api/runtime/tests/test_extensions.py b/stac_api/runtime/tests/test_extensions.py index 4b561472..5e187b46 100644 --- a/stac_api/runtime/tests/test_extensions.py +++ b/stac_api/runtime/tests/test_extensions.py @@ -16,7 +16,7 @@ - /Collections search by id and free text search """ - +import pytest collections_endpoint = "/collections" items_endpoint = "/collections/{}/items" @@ -34,6 +34,7 @@ class TestList: necessary data. """ + @pytest.mark.asyncio async def test_post_invalid_collection(self, api_client, invalid_stac_collection): """ Test the API's response to posting an invalid STAC collection. @@ -47,6 +48,7 @@ async def test_post_invalid_collection(self, api_client, invalid_stac_collection assert response.json()["detail"] == "Validation Error" assert response.status_code == 422 + @pytest.mark.asyncio async def test_post_valid_collection(self, api_client, valid_stac_collection): """ Test the API's response to posting a valid STAC collection. @@ -58,6 +60,7 @@ async def test_post_valid_collection(self, api_client, valid_stac_collection): ) assert response.status_code == 201 + @pytest.mark.asyncio async def test_post_invalid_item(self, api_client, invalid_stac_item): """ Test the API's response to posting an invalid STAC item. @@ -72,6 +75,7 @@ async def test_post_invalid_item(self, api_client, invalid_stac_item): assert response.json()["detail"] == "Validation Error" assert response.status_code == 422 + @pytest.mark.asyncio async def test_post_valid_item(self, api_client, valid_stac_item, collection_in_db): """ Test the API's response to posting a valid STAC item. @@ -84,6 +88,7 @@ async def test_post_valid_item(self, api_client, valid_stac_item, collection_in_ ) assert response.status_code == 201 + @pytest.mark.asyncio async def test_post_invalid_bulk_items(self, api_client, invalid_stac_item): """ Test the API's response to posting invalid bulk STAC items. @@ -99,6 +104,7 @@ async def test_post_invalid_bulk_items(self, api_client, invalid_stac_item): ) assert response.status_code == 422 + @pytest.mark.asyncio async def test_post_valid_bulk_items( self, api_client, valid_stac_item, collection_in_db ): @@ -116,6 +122,7 @@ async def test_post_valid_bulk_items( ) assert response.status_code == 200 + @pytest.mark.asyncio async def test_get_collection_by_id(self, api_client, collection_in_db): """ Test searching for a specific collection by its ID. @@ -134,6 +141,7 @@ async def test_get_collection_by_id(self, api_client, collection_in_db): assert response_data["collections"][0]["id"] == collection_id + @pytest.mark.asyncio async def test_collection_freetext_search_by_title( self, api_client, collection_in_db ): @@ -158,6 +166,7 @@ async def test_collection_freetext_search_by_title( returned_ids = [col["id"] for col in response_data["collections"]] assert collection_id in returned_ids + @pytest.mark.asyncio async def test_get_collections_by_tenant(self, api_client, collection_in_db): """ Test searching for a specific collection by its ID. @@ -175,6 +184,7 @@ async def test_get_collections_by_tenant(self, api_client, collection_in_db): assert response_data["collections"][0]["id"] == collection_id + @pytest.mark.asyncio async def test_tenant_landing_page_customization(self, api_client): """ Test that tenant landing page is properly customized for tenant @@ -185,16 +195,30 @@ async def test_tenant_landing_page_customization(self, api_client): landing_page = response.json() assert "FAKE-TENANT" in landing_page["title"] + excluded_rels = [ + "self", + "root", + "service-desc", + "service-doc", + "conformance", + ] for link in landing_page.get("links", []): - if link.get("rel") not in [ - "self", - "root", - "service-desc", - "service-doc", - "conformance", - ]: - assert "/fake-tenant/" in link["href"] - + rel = link.get("rel") + href = link.get("href", "") + + if rel in excluded_rels: + assert ( + "/fake-tenant/" not in href + ), f"Excluded rel '{rel}' incorrectly contains tenant: {href}" + print(f"Excluded rel '{rel}' correctly has no tenant: {href}") + else: + if href.startswith("/api/stac") or "api/stac" in href: + assert ( + "/fake-tenant/" in href + ), f"Included rel '{rel}' does not have tenant: {href}" + print(f"Included rel '{rel}' correctly has tenant: {href}") + + @pytest.mark.asyncio async def test_get_collection_by_id_with_tenant(self, api_client, collection_in_db): """ Test searching for a specific collection by its ID and tenant @@ -213,6 +237,7 @@ async def test_get_collection_by_id_with_tenant(self, api_client, collection_in_ assert response_data["collections"][0]["id"] == collection_id + @pytest.mark.asyncio async def test_tenant_validation_error(self, api_client, collection_in_db): """ Test that accessing wrong tenant's collection returns 404 @@ -223,6 +248,7 @@ async def test_tenant_validation_error(self, api_client, collection_in_db): response = await api_client.get(f"/fake-tenant-2/collections/{collection_id}") assert response.status_code == 404 + @pytest.mark.asyncio async def test_invalid_tenant_format(self, api_client): """ Test handling of invalid tenant formats @@ -232,6 +258,7 @@ async def test_invalid_tenant_format(self, api_client): assert response.status_code in [400, 404] + @pytest.mark.asyncio async def test_missing_tenant_parameter(self, api_client): """ Test behavior when tenant parameter is not supplied in route path From d73667138c7dda1d89b07452e2acbae59305637b Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Wed, 17 Sep 2025 08:58:18 -0700 Subject: [PATCH 20/33] fix: update landing page logic to exclude queryables --- stac_api/runtime/src/tenant_client.py | 4 ++-- stac_api/runtime/tests/test_extensions.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/stac_api/runtime/src/tenant_client.py b/stac_api/runtime/src/tenant_client.py index 7bcf1b6a..19a7fff0 100644 --- a/stac_api/runtime/src/tenant_client.py +++ b/stac_api/runtime/src/tenant_client.py @@ -298,6 +298,7 @@ def _customize_landing_page_for_tenant( logger.info("Inspecting links to inject tenant...") if "href" in link: href = link["href"] + rel = link.get("rel") skip_rels = [ "self", @@ -306,8 +307,7 @@ def _customize_landing_page_for_tenant( "service-doc", "conformance", ] - if link.get("rel") in skip_rels: - logger.info(f"Skipping link with rel {link.get('rel')}") + if rel in skip_rels or "queryables" in rel: continue if href.startswith("http"): diff --git a/stac_api/runtime/tests/test_extensions.py b/stac_api/runtime/tests/test_extensions.py index 5e187b46..b301aa57 100644 --- a/stac_api/runtime/tests/test_extensions.py +++ b/stac_api/runtime/tests/test_extensions.py @@ -201,6 +201,7 @@ async def test_tenant_landing_page_customization(self, api_client): "service-desc", "service-doc", "conformance", + "queryables", ] for link in landing_page.get("links", []): rel = link.get("rel") From 8eb8573834bfaf7f112714592164306f28078212 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Wed, 17 Sep 2025 09:15:47 -0700 Subject: [PATCH 21/33] fix: add valid_stac_collection_with_tenant to params --- stac_api/runtime/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stac_api/runtime/tests/conftest.py b/stac_api/runtime/tests/conftest.py index 04c78a55..c4108e92 100644 --- a/stac_api/runtime/tests/conftest.py +++ b/stac_api/runtime/tests/conftest.py @@ -429,7 +429,7 @@ def invalid_stac_item(): @pytest_asyncio.fixture -async def collection_in_db(api_client, valid_stac_collection): +async def collection_in_db(api_client, valid_stac_collection, valid_stac_collection_with_tenant): """ Fixture to ensure a valid STAC collection exists in the database. From c79ab59deca82740b6233588000de270d0335b2d Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Wed, 17 Sep 2025 09:28:42 -0700 Subject: [PATCH 22/33] fix: lint errors --- stac_api/runtime/tests/conftest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stac_api/runtime/tests/conftest.py b/stac_api/runtime/tests/conftest.py index c4108e92..657f2486 100644 --- a/stac_api/runtime/tests/conftest.py +++ b/stac_api/runtime/tests/conftest.py @@ -429,7 +429,9 @@ def invalid_stac_item(): @pytest_asyncio.fixture -async def collection_in_db(api_client, valid_stac_collection, valid_stac_collection_with_tenant): +async def collection_in_db( + api_client, valid_stac_collection, valid_stac_collection_with_tenant +): """ Fixture to ensure a valid STAC collection exists in the database. From a08e553ca31eb327cefda31236a2537537243f07 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Wed, 17 Sep 2025 09:49:49 -0700 Subject: [PATCH 23/33] fix: update collection_in_db to yield dictionary --- stac_api/runtime/tests/conftest.py | 5 ++++- stac_api/runtime/tests/test_extensions.py | 10 +++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/stac_api/runtime/tests/conftest.py b/stac_api/runtime/tests/conftest.py index 657f2486..74be872e 100644 --- a/stac_api/runtime/tests/conftest.py +++ b/stac_api/runtime/tests/conftest.py @@ -452,4 +452,7 @@ async def collection_in_db( assert collection_response.status_code in [201, 409] assert collection_with_tenant_response.status_code in [201, 409] - yield [valid_stac_collection["id"], valid_stac_collection_with_tenant["id"]] + yield { + "regular_collection": valid_stac_collection["id"], + "tenant_collection": valid_stac_collection_with_tenant["id"], + } diff --git a/stac_api/runtime/tests/test_extensions.py b/stac_api/runtime/tests/test_extensions.py index b301aa57..dbcb7552 100644 --- a/stac_api/runtime/tests/test_extensions.py +++ b/stac_api/runtime/tests/test_extensions.py @@ -128,7 +128,7 @@ async def test_get_collection_by_id(self, api_client, collection_in_db): Test searching for a specific collection by its ID. """ # The `collection_in_db` fixture ensures the collection exists and provides its ID. - collection_id = collection_in_db[0] + collection_id = collection_in_db["regular_collection"] # Perform a GET request to the /collections endpoint with an "ids" query response = await api_client.get( @@ -150,7 +150,7 @@ async def test_collection_freetext_search_by_title( """ # The `collection_in_db` fixture ensures the collection exists. - collection_id = collection_in_db[0] + collection_id = collection_in_db["regular_collection"] # Use a unique word from the collection's title for the query. search_term = "precipitation" @@ -171,7 +171,7 @@ async def test_get_collections_by_tenant(self, api_client, collection_in_db): """ Test searching for a specific collection by its ID. """ - collection_id = collection_in_db[1] + collection_id = collection_in_db["tenant_collection"] # Perform a GET request to the /collections endpoint with a tenant response = await api_client.get( @@ -225,7 +225,7 @@ async def test_get_collection_by_id_with_tenant(self, api_client, collection_in_ Test searching for a specific collection by its ID and tenant """ # The `collection_in_db` fixture ensures the collection exists and provides its ID. - collection_id = collection_in_db[1] + collection_id = collection_in_db["regular_collection"] # Perform a GET request to the /fake-tenant/collections endpoint with an "ids" query response = await api_client.get( @@ -243,7 +243,7 @@ async def test_tenant_validation_error(self, api_client, collection_in_db): """ Test that accessing wrong tenant's collection returns 404 """ - collection_id = collection_in_db[1] + collection_id = collection_in_db["tenant_collection"] # Try to access unexistent tenant for collection that exists in fake-tenant response = await api_client.get(f"/fake-tenant-2/collections/{collection_id}") From 6bdfc18e49301f5dcdec7c62f823b84e72492f90 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:39:12 -0700 Subject: [PATCH 24/33] fix: update tests to accomodate for additional tenant collection --- stac_api/runtime/tests/test_extensions.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/stac_api/runtime/tests/test_extensions.py b/stac_api/runtime/tests/test_extensions.py index dbcb7552..196fd7c6 100644 --- a/stac_api/runtime/tests/test_extensions.py +++ b/stac_api/runtime/tests/test_extensions.py @@ -82,7 +82,10 @@ async def test_post_valid_item(self, api_client, valid_stac_item, collection_in_ Asserts that the response status code is 200. """ - collection_id = valid_stac_item["collection"] + collection_id = collection_in_db["regular_collection"] + item_data = valid_stac_item.copy() + item_data["collection"] = collection_id + response = await api_client.post( items_endpoint.format(collection_id), json=valid_stac_item ) @@ -225,7 +228,7 @@ async def test_get_collection_by_id_with_tenant(self, api_client, collection_in_ Test searching for a specific collection by its ID and tenant """ # The `collection_in_db` fixture ensures the collection exists and provides its ID. - collection_id = collection_in_db["regular_collection"] + collection_id = collection_in_db["tenant_collection"] # Perform a GET request to the /fake-tenant/collections endpoint with an "ids" query response = await api_client.get( @@ -255,9 +258,13 @@ async def test_invalid_tenant_format(self, api_client): Test handling of invalid tenant formats """ + # Non existent tenant should just show no collections response = await api_client.get("/invalid-tenant-format/collections") - assert response.status_code in [400, 404] + assert response.status_code in [200, 404] + if response.status_code == 200: + response_data = response.json() + assert response_data["collections"] == [] @pytest.mark.asyncio async def test_missing_tenant_parameter(self, api_client): From 76c2f47b2b6c579c0302a8ae9cc0605f73f6b391 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:19:40 -0700 Subject: [PATCH 25/33] fix: update exception handling in middleware --- stac_api/runtime/src/tenant_middleware.py | 2 +- stac_api/runtime/tests/test_extensions.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/stac_api/runtime/src/tenant_middleware.py b/stac_api/runtime/src/tenant_middleware.py index 42009de6..be4d40f5 100644 --- a/stac_api/runtime/src/tenant_middleware.py +++ b/stac_api/runtime/src/tenant_middleware.py @@ -76,7 +76,7 @@ async def dispatch(self, request: Request, call_next): except Exception as e: logger.error(f"Tenant middleware error: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") + raise def _extract_tenant(self, request: Request) -> Optional[str]: """Extracts the tenant identifier from the URL""" diff --git a/stac_api/runtime/tests/test_extensions.py b/stac_api/runtime/tests/test_extensions.py index 196fd7c6..4782eb13 100644 --- a/stac_api/runtime/tests/test_extensions.py +++ b/stac_api/runtime/tests/test_extensions.py @@ -132,6 +132,7 @@ async def test_get_collection_by_id(self, api_client, collection_in_db): """ # The `collection_in_db` fixture ensures the collection exists and provides its ID. collection_id = collection_in_db["regular_collection"] + print(f"collection_id in test_get_collection_by_id is {collection_id}") # Perform a GET request to the /collections endpoint with an "ids" query response = await api_client.get( From 07e1115dbba2bae88004190ce7a66cdb7c3b1281 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:32:56 -0700 Subject: [PATCH 26/33] fix: update tests, try disabling ssl verifications --- stac_api/runtime/tests/conftest.py | 3 +++ stac_api/runtime/tests/test_extensions.py | 8 ++------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/stac_api/runtime/tests/conftest.py b/stac_api/runtime/tests/conftest.py index 74be872e..e7c1cb0e 100644 --- a/stac_api/runtime/tests/conftest.py +++ b/stac_api/runtime/tests/conftest.py @@ -286,6 +286,9 @@ def test_environ(): os.environ["POSTGRES_HOST_WRITER"] = "0.0.0.0" os.environ["POSTGRES_PORT"] = "5439" + import ssl + ssl._create_default_https_context = ssl._create_unverified_context + def override_validated_token(): """ diff --git a/stac_api/runtime/tests/test_extensions.py b/stac_api/runtime/tests/test_extensions.py index 4782eb13..f8b0d880 100644 --- a/stac_api/runtime/tests/test_extensions.py +++ b/stac_api/runtime/tests/test_extensions.py @@ -132,18 +132,14 @@ async def test_get_collection_by_id(self, api_client, collection_in_db): """ # The `collection_in_db` fixture ensures the collection exists and provides its ID. collection_id = collection_in_db["regular_collection"] - print(f"collection_id in test_get_collection_by_id is {collection_id}") - # Perform a GET request to the /collections endpoint with an "ids" query - response = await api_client.get( - collections_endpoint, params={"ids": collection_id} - ) + response = await api_client.get(f"{collections_endpoint}/{collection_id}") assert response.status_code == 200 response_data = response.json() - assert response_data["collections"][0]["id"] == collection_id + assert response_data["id"] == collection_id @pytest.mark.asyncio async def test_collection_freetext_search_by_title( From 68d2fbff0d192b2319f3aa3c65305f1fe7c38d59 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:43:37 -0700 Subject: [PATCH 27/33] fix: formatting --- stac_api/runtime/tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stac_api/runtime/tests/conftest.py b/stac_api/runtime/tests/conftest.py index e7c1cb0e..641de3f2 100644 --- a/stac_api/runtime/tests/conftest.py +++ b/stac_api/runtime/tests/conftest.py @@ -287,6 +287,7 @@ def test_environ(): os.environ["POSTGRES_PORT"] = "5439" import ssl + ssl._create_default_https_context = ssl._create_unverified_context From 3595fc52a37c79860394e6998e937f0fdba3b14b Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Wed, 17 Sep 2025 13:19:45 -0700 Subject: [PATCH 28/33] fix: try removing proj:projjson in testing --- stac_api/runtime/src/validation.py | 4 +++ stac_api/runtime/tests/conftest.py | 35 ----------------------- stac_api/runtime/tests/test_extensions.py | 9 ++++-- 3 files changed, 10 insertions(+), 38 deletions(-) diff --git a/stac_api/runtime/src/validation.py b/stac_api/runtime/src/validation.py index b50d5c3a..79990880 100644 --- a/stac_api/runtime/src/validation.py +++ b/stac_api/runtime/src/validation.py @@ -2,6 +2,7 @@ import json import re +import ssl from typing import Dict from pydantic import BaseModel, Field @@ -13,6 +14,9 @@ from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware +# Disable SSL verification for external schema fetching +ssl._create_default_https_context = ssl._create_unverified_context + class BulkItems(BaseModel): """Validation model for bulk-items endpoint request""" diff --git a/stac_api/runtime/tests/conftest.py b/stac_api/runtime/tests/conftest.py index 641de3f2..a4dbe1ec 100644 --- a/stac_api/runtime/tests/conftest.py +++ b/stac_api/runtime/tests/conftest.py @@ -198,38 +198,6 @@ ] ], }, - "proj:projjson": { - "id": {"code": 4326, "authority": "EPSG"}, - "name": "WGS 84", - "type": "GeographicCRS", - "datum": { - "name": "World Geodetic System 1984", - "type": "GeodeticReferenceFrame", - "ellipsoid": { - "name": "WGS 84", - "semi_major_axis": 6378137, - "inverse_flattening": 298.257223563, - }, - }, - "$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", - "coordinate_system": { - "axis": [ - { - "name": "Geodetic latitude", - "unit": "degree", - "direction": "north", - "abbreviation": "Lat", - }, - { - "name": "Geodetic longitude", - "unit": "degree", - "direction": "east", - "abbreviation": "Lon", - }, - ], - "subtype": "ellipsoidal", - }, - }, "proj:transform": [0.1, 0.0, -180.0, 0.0, -0.1, 90.0, 0.0, 0.0, 1.0], }, "rendered_preview": { @@ -286,9 +254,6 @@ def test_environ(): os.environ["POSTGRES_HOST_WRITER"] = "0.0.0.0" os.environ["POSTGRES_PORT"] = "5439" - import ssl - - ssl._create_default_https_context = ssl._create_unverified_context def override_validated_token(): diff --git a/stac_api/runtime/tests/test_extensions.py b/stac_api/runtime/tests/test_extensions.py index f8b0d880..1a232b13 100644 --- a/stac_api/runtime/tests/test_extensions.py +++ b/stac_api/runtime/tests/test_extensions.py @@ -58,7 +58,7 @@ async def test_post_valid_collection(self, api_client, valid_stac_collection): response = await api_client.post( collections_endpoint, json=valid_stac_collection ) - assert response.status_code == 201 + assert response.status_code in [201, 409] @pytest.mark.asyncio async def test_post_invalid_item(self, api_client, invalid_stac_item): @@ -89,7 +89,7 @@ async def test_post_valid_item(self, api_client, valid_stac_item, collection_in_ response = await api_client.post( items_endpoint.format(collection_id), json=valid_stac_item ) - assert response.status_code == 201 + assert response.status_code in [201, 409] # 201 for new, 409 for existing @pytest.mark.asyncio async def test_post_invalid_bulk_items(self, api_client, invalid_stac_item): @@ -207,7 +207,10 @@ async def test_tenant_landing_page_customization(self, api_client): rel = link.get("rel") href = link.get("href", "") - if rel in excluded_rels: + # Check if rel should be excluded (exact match or contains "queryables") + should_exclude = rel in excluded_rels or "queryables" in rel + + if should_exclude: assert ( "/fake-tenant/" not in href ), f"Excluded rel '{rel}' incorrectly contains tenant: {href}" From 6283a51cee000ced04c4dd40962ff86da13b9615 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Wed, 17 Sep 2025 13:23:58 -0700 Subject: [PATCH 29/33] fix: formatting --- stac_api/runtime/tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/stac_api/runtime/tests/conftest.py b/stac_api/runtime/tests/conftest.py index a4dbe1ec..4b0463e0 100644 --- a/stac_api/runtime/tests/conftest.py +++ b/stac_api/runtime/tests/conftest.py @@ -255,7 +255,6 @@ def test_environ(): os.environ["POSTGRES_PORT"] = "5439" - def override_validated_token(): """ Mock function to override validated token dependency. From 0ae96094bad803e4608b6af8dcba180d090a1f35 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Wed, 17 Sep 2025 13:33:37 -0700 Subject: [PATCH 30/33] fix: see if ssl disabling is necessary --- stac_api/runtime/src/validation.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/stac_api/runtime/src/validation.py b/stac_api/runtime/src/validation.py index 79990880..b50d5c3a 100644 --- a/stac_api/runtime/src/validation.py +++ b/stac_api/runtime/src/validation.py @@ -2,7 +2,6 @@ import json import re -import ssl from typing import Dict from pydantic import BaseModel, Field @@ -14,9 +13,6 @@ from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware -# Disable SSL verification for external schema fetching -ssl._create_default_https_context = ssl._create_unverified_context - class BulkItems(BaseModel): """Validation model for bulk-items endpoint request""" From e642dc36ce20a81e046598d6e561084c0c6beb3c Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:05:54 -0700 Subject: [PATCH 31/33] fix: remove custom landing page, update tests, remove trailing slashes in evaluation --- stac_api/runtime/src/tenant_middleware.py | 52 +++++++++++++++++++++ stac_api/runtime/src/tenant_routes.py | 55 +---------------------- stac_api/runtime/tests/test_extensions.py | 39 +--------------- 3 files changed, 54 insertions(+), 92 deletions(-) diff --git a/stac_api/runtime/src/tenant_middleware.py b/stac_api/runtime/src/tenant_middleware.py index be4d40f5..f5d9519c 100644 --- a/stac_api/runtime/src/tenant_middleware.py +++ b/stac_api/runtime/src/tenant_middleware.py @@ -31,7 +31,11 @@ async def dispatch(self, request: Request, call_next): """Processes incoming requests and extracts the tenant identifier from the URL""" try: + if self._should_skip_tenant_processing(request): + return await call_next(request) + tenant = self._extract_tenant(request) + logger.info(f"Extracted tenant is {tenant}") tenant_context = ( TenantContext( @@ -78,9 +82,57 @@ async def dispatch(self, request: Request, call_next): logger.error(f"Tenant middleware error: {str(e)}") raise + def _should_skip_tenant_processing(self, request: Request) -> bool: + """Check if tenant processing should be skipped for this request""" + path = request.url.path + logger.info(f"Tenant middleware processing path: {path}") + + # handles both local (no prefix) and production (/api/stac/ prefix) environments + if path.startswith("/api/stac/"): + if path == "/api/stac/" or path == "/api/stac": + logger.info(f"Skipping tenant processing - root STAC API: {path}") + return True + + path_parts = path.replace("/api/stac/", "").split("/") + else: + path_parts = path.lstrip("/").split("/") + + logger.info(f"Path parts: {path_parts}") + + if not path_parts or not path_parts[0]: + logger.info(f"Skipping tenant processing - empty path parts: {path}") + return True + + standard_endpoints = { + "collections", + "conformance", + "search", + "queryables", + "openapi.json", + "docs", + "favicon.ico", + "health", + "ping", + } + + first_part = path_parts[0].rstrip("/") + logger.info( + f"First part: '{path_parts[0]}', stripped: '{first_part}', in standard_endpoints: {first_part in standard_endpoints}" + ) + + if first_part in standard_endpoints: + logger.info( + f"Skipping tenant processing for standard endpoint: {first_part}" + ) + return True + + logger.info(f"Processing as tenant: {first_part}") + return False + def _extract_tenant(self, request: Request) -> Optional[str]: """Extracts the tenant identifier from the URL""" path = request.url.path + logger.info(f"Extracting tenant from request path {path}") if path.startswith("/api/stac/"): path_parts = path.replace("/api/stac/", "").split("/") if path_parts and path_parts[0]: diff --git a/stac_api/runtime/src/tenant_routes.py b/stac_api/runtime/src/tenant_routes.py index 3b519799..2c1f02d0 100644 --- a/stac_api/runtime/src/tenant_routes.py +++ b/stac_api/runtime/src/tenant_routes.py @@ -1,10 +1,9 @@ """ Tenant Route Handler """ import json import logging -from typing import Any, Dict, Optional +from typing import Dict, Optional from fastapi import APIRouter, HTTPException, Path, Query, Request -from fastapi.responses import ORJSONResponse from stac_fastapi.types.stac import Item, ItemCollection from .tenant_client import TenantAwareVedaCrudClient @@ -20,24 +19,6 @@ def __init__(self, client: TenantAwareVedaCrudClient): """Initializes tenant-aware route handler""" self.client = client - async def get_tenant_collections( - self, - request: Request, - tenant: str = Path(..., description="Tenant identifier"), - ) -> Dict[str, Any]: - """Get all collections belonging to a tenant""" - - logger.info(f"Getting collections for tenant: {tenant}") - - try: - collections = await self.client.get_tenant_collections( - request, tenant=tenant - ) - return collections - except Exception as e: - logger.error(f"Error getting collections for tenant {tenant}: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") - async def get_tenant_collection( self, request: Request, @@ -192,24 +173,6 @@ async def post_tenant_search( logger.error(f"Error performing POST search for tenant {tenant}: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error") - async def get_tenant_landing_page( - self, - request: Request, - tenant: str = Path(..., description="Tenant identifier"), - ) -> ORJSONResponse: - """Get landing page for a specific tenant""" - - logger.info(f"Getting landing page for tenant: {tenant}") - - try: - landing_page = await self.client.landing_page(request, tenant=tenant) - return ORJSONResponse(content=landing_page) - except HTTPException: - raise - except Exception as e: - logger.error(f"Error getting landing page for tenant {tenant}: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") - def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: """Create tenant-specific router""" @@ -219,14 +182,6 @@ def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: logger.info("Creating tenant router with routes") - router.add_api_route( - "/{tenant}/collections", - handler.get_tenant_collections, - methods=["GET"], - summary="Get collections for tenant", - description="Retrieve all collections for a specific tenant", - ) - router.add_api_route( "/{tenant}/collections/{collection_id}", handler.get_tenant_collection, @@ -268,13 +223,5 @@ def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: description="Search items for a tenant using POST method", ) - router.add_api_route( - "/{tenant}/", - handler.get_tenant_landing_page, - methods=["GET"], - summary="Get tenant landing page", - description="Retrieve landing page for a specific tenant", - ) - logger.info(f"Created tenant router with {len(router.routes)} routes") return router diff --git a/stac_api/runtime/tests/test_extensions.py b/stac_api/runtime/tests/test_extensions.py index 1a232b13..0aa68ff2 100644 --- a/stac_api/runtime/tests/test_extensions.py +++ b/stac_api/runtime/tests/test_extensions.py @@ -16,6 +16,7 @@ - /Collections search by id and free text search """ + import pytest collections_endpoint = "/collections" @@ -184,44 +185,6 @@ async def test_get_collections_by_tenant(self, api_client, collection_in_db): assert response_data["collections"][0]["id"] == collection_id - @pytest.mark.asyncio - async def test_tenant_landing_page_customization(self, api_client): - """ - Test that tenant landing page is properly customized for tenant - """ - response = await api_client.get("/fake-tenant/") - assert response.status_code == 200 - - landing_page = response.json() - assert "FAKE-TENANT" in landing_page["title"] - - excluded_rels = [ - "self", - "root", - "service-desc", - "service-doc", - "conformance", - "queryables", - ] - for link in landing_page.get("links", []): - rel = link.get("rel") - href = link.get("href", "") - - # Check if rel should be excluded (exact match or contains "queryables") - should_exclude = rel in excluded_rels or "queryables" in rel - - if should_exclude: - assert ( - "/fake-tenant/" not in href - ), f"Excluded rel '{rel}' incorrectly contains tenant: {href}" - print(f"Excluded rel '{rel}' correctly has no tenant: {href}") - else: - if href.startswith("/api/stac") or "api/stac" in href: - assert ( - "/fake-tenant/" in href - ), f"Included rel '{rel}' does not have tenant: {href}" - print(f"Included rel '{rel}' correctly has tenant: {href}") - @pytest.mark.asyncio async def test_get_collection_by_id_with_tenant(self, api_client, collection_in_db): """ From 74d37586f4325b96f2cdb4d26abd359a70acc91e Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:32:06 -0700 Subject: [PATCH 32/33] fix: add back import, add route to get collections, add another condition for skipping tenant processing --- .../data/noaa-emergency-response.json | 3 +++ stac_api/runtime/src/tenant_middleware.py | 5 ++++ stac_api/runtime/src/tenant_routes.py | 27 ++++++++++++++++++- 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/.github/workflows/data/noaa-emergency-response.json b/.github/workflows/data/noaa-emergency-response.json index dd9aabd4..5850e5ee 100644 --- a/.github/workflows/data/noaa-emergency-response.json +++ b/.github/workflows/data/noaa-emergency-response.json @@ -3,6 +3,9 @@ "title": "NOAA Emergency Response Imagery", "description": "NOAA Emergency Response Imagery hosted on AWS Public Dataset.", "stac_version": "1.0.0", + "properties": { + "tenant": "test-tenant" + }, "license": "public-domain", "links": [], "extent": { diff --git a/stac_api/runtime/src/tenant_middleware.py b/stac_api/runtime/src/tenant_middleware.py index f5d9519c..f05613c9 100644 --- a/stac_api/runtime/src/tenant_middleware.py +++ b/stac_api/runtime/src/tenant_middleware.py @@ -120,6 +120,11 @@ def _should_skip_tenant_processing(self, request: Request) -> bool: f"First part: '{path_parts[0]}', stripped: '{first_part}', in standard_endpoints: {first_part in standard_endpoints}" ) + # if the path is exactly a standard endpoint with trailing slash, skip tenant processing + if len(path_parts) == 2 and path_parts[1] == "" and first_part in standard_endpoints: + logger.info(f"Skipping tenant processing for standard endpoint with trailing slash: {first_part}/") + return True + if first_part in standard_endpoints: logger.info( f"Skipping tenant processing for standard endpoint: {first_part}" diff --git a/stac_api/runtime/src/tenant_routes.py b/stac_api/runtime/src/tenant_routes.py index 2c1f02d0..07b4a79f 100644 --- a/stac_api/runtime/src/tenant_routes.py +++ b/stac_api/runtime/src/tenant_routes.py @@ -1,7 +1,7 @@ """ Tenant Route Handler """ import json import logging -from typing import Dict, Optional +from typing import Any, Dict, Optional from fastapi import APIRouter, HTTPException, Path, Query, Request from stac_fastapi.types.stac import Item, ItemCollection @@ -19,6 +19,23 @@ def __init__(self, client: TenantAwareVedaCrudClient): """Initializes tenant-aware route handler""" self.client = client + async def get_tenant_collections( + self, + request: Request, + tenant: str = Path(..., description="Tenant identifier"), + ) -> Dict[str, Any]: + """Get all collections belonging to a tenant""" + logger.info(f"Getting collections for tenant: {tenant}") + + try: + collections = await self.client.get_tenant_collections( + request, tenant=tenant + ) + return collections + except Exception as e: + logger.error(f"Error getting collections for tenant {tenant}: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error") + async def get_tenant_collection( self, request: Request, @@ -182,6 +199,14 @@ def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter: logger.info("Creating tenant router with routes") + router.add_api_route( + "/{tenant}/collections", + handler.get_tenant_collections, + methods=["GET"], + summary="Get collections for tenant", + description="Retrieve all collections for a specific tenant", + ) + router.add_api_route( "/{tenant}/collections/{collection_id}", handler.get_tenant_collection, From e4ccf31829f375069262e83f68f0f6e799fd86e0 Mon Sep 17 00:00:00 2001 From: Jennifer Tran <12633533+botanical@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:54:27 -0700 Subject: [PATCH 33/33] fix: formatting --- stac_api/runtime/src/tenant_middleware.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/stac_api/runtime/src/tenant_middleware.py b/stac_api/runtime/src/tenant_middleware.py index f05613c9..937ee246 100644 --- a/stac_api/runtime/src/tenant_middleware.py +++ b/stac_api/runtime/src/tenant_middleware.py @@ -121,8 +121,14 @@ def _should_skip_tenant_processing(self, request: Request) -> bool: ) # if the path is exactly a standard endpoint with trailing slash, skip tenant processing - if len(path_parts) == 2 and path_parts[1] == "" and first_part in standard_endpoints: - logger.info(f"Skipping tenant processing for standard endpoint with trailing slash: {first_part}/") + if ( + len(path_parts) == 2 + and path_parts[1] == "" + and first_part in standard_endpoints + ): + logger.info( + f"Skipping tenant processing for standard endpoint with trailing slash: {first_part}/" + ) return True if first_part in standard_endpoints: