diff --git a/.env b/.env index 3d90b4c3..80a16628 100644 --- a/.env +++ b/.env @@ -44,3 +44,28 @@ AIOD_NGINX_PORT=80 #DATA STORAGE DATA_PATH=./data BACKUP_PATH=./data/backups + +#PROMETHEUS and GRAFANA +AIOD_PROMETHEUS_PORT=9090 +AIOD_GRAFANA_PORT=3000 + +# MONITORING STORAGE +PROMETHEUS_DATA_PATH=./data/prometheus +GRAFANA_DATA_PATH=./data/grafana + +# PROMETHEUS RETENTION +PROMETHEUS_RETENTION=7d + +# GRAFANA DEMOS AUTH +GF_AUTH_ANONYMOUS_ENABLED=true +GF_AUTH_ANONYMOUS_ORG_ROLE=Viewer +GF_SECURITY_ADMIN_PASSWORD=admin + +# GRAFANA +GRAFANA_PROMETHEUS_URL=http://prometheus:9090 + +GRAFANA_MYSQL_HOST=sqlserver +GRAFANA_MYSQL_PORT=3306 +GRAFANA_MYSQL_DB=aiod +GRAFANA_MYSQL_USER=root +GRAFANA_MYSQL_PASSWORD=${MYSQL_ROOT_PASSWORD} diff --git a/docker-compose.dev.yaml b/docker-compose.dev.yaml index 1d3d84e1..156fa520 100644 --- a/docker-compose.dev.yaml +++ b/docker-compose.dev.yaml @@ -58,3 +58,25 @@ services: stdin_open: true volumes: - ./src:/app:ro + + prometheus: + volumes: + - ${PROMETHEUS_DATA_PATH:-./data/prometheus}:/prometheus + command: + - --config.file=/etc/prometheus/prometheus.yml + - --storage.tsdb.path=/prometheus + - --storage.tsdb.retention.time=${PROMETHEUS_RETENTION:-7d} + + grafana: + environment: + - GF_AUTH_ANONYMOUS_ENABLED=${GF_AUTH_ANONYMOUS_ENABLED:-false} + - GF_AUTH_ANONYMOUS_ORG_ROLE=${GF_AUTH_ANONYMOUS_ORG_ROLE:-Viewer} + - GF_SECURITY_ADMIN_PASSWORD=${GF_SECURITY_ADMIN_PASSWORD:-admin} + - GRAFANA_PROMETHEUS_URL=${GRAFANA_PROMETHEUS_URL:-http://prometheus:9090} + - GRAFANA_MYSQL_HOST=${GRAFANA_MYSQL_HOST:-sqlserver} + - GRAFANA_MYSQL_PORT=${GRAFANA_MYSQL_PORT:-3306} + - GRAFANA_MYSQL_DB=${GRAFANA_MYSQL_DB:-aiod} + - GRAFANA_MYSQL_USER=${GRAFANA_MYSQL_USER:-root} + - GRAFANA_MYSQL_PASSWORD=${MYSQL_ROOT_PASSWORD} + volumes: + - ${GRAFANA_DATA_PATH:-./data/grafana}:/var/lib/grafana diff --git a/docker-compose.yaml b/docker-compose.yaml index fe21236c..3c944d78 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -237,6 +237,28 @@ services: es_logstash_setup: condition: service_completed_successfully + prometheus: + profiles: ["monitoring"] + image: prom/prometheus:latest + volumes: + - ./prometheus.yml:/etc/prometheus/prometheus.yml:ro + ports: + - "${AIOD_PROMETHEUS_PORT:-9090}:9090" + restart: unless-stopped + + grafana: + profiles: ["monitoring"] + image: grafana/grafana:latest + depends_on: + - prometheus + volumes: + - ./grafana/provisioning/datasources:/etc/grafana/provisioning/datasources:ro + - ./grafana/provisioning/dashboards:/etc/grafana/provisioning/dashboards:ro + - ./grafana/dashboards:/etc/grafana/dashboards:ro + ports: + - "${AIOD_GRAFANA_PORT:-3000}:3000" + restart: unless-stopped + taxonomy: profiles: ["taxonomy"] container_name: taxonomy diff --git a/docs/hosting/metrics.md b/docs/hosting/metrics.md new file mode 100644 index 00000000..cfb728ad --- /dev/null +++ b/docs/hosting/metrics.md @@ -0,0 +1,198 @@ +# Metrics & Monitoring + +## Overview + +This adds two kinds of observability to the REST API: + +* **Operational metrics (Prometheus):** requests/second, latencies, error rates, exposed at **`/metrics`** and scraped by Prometheus; visualized in Grafana. +* **Product usage (MySQL):** the middleware writes one row per “asset-shaped” request to **`asset_access_log`** so we can query **top assets** (popularity) and build dashboards. Returned via **`/stats/top/{resource_type}`**. + +Low-coupling design: a small middleware observes the path and logs access; routers are unchanged. Path parsing is centralized to handle version prefixes. + +--- + +## Components + +* **apiserver** — FastAPI app exposing: + + * **`/metrics`** (Prometheus exposition via `prometheus_fastapi_instrumentator`) + * **`/stats/top/{resource_type}`** (JSON; success hits only) +* **MySQL** — table `asset_access_log` stores per-request asset hits +* **Prometheus** — scrapes apiserver’s `/metrics` +* **Grafana** — visualizes Prometheus (traffic) + MySQL (popularity) + +--- + +## Endpoints (apiserver) + +* **GET `/metrics`** + Exposes Prometheus metrics. Example series: `http_requests_total`, `http_request_duration_seconds`, process/python metrics, etc. + +* **GET `/stats/top/{resource_type}?limit=10`** + Returns an array of objects: + + ```json + [ + { "asset_id": "data_p7v02a70CbBGKk29T8przBjf", "hits": 42 }, + { "asset_id": "data_g8912mLHg8i2hsJblKu6G78i", "hits": 17 } + ] + ``` + + * Reports only successful requests (status code 200). + * `resource_type` is something like `datasets`, `models`, etc. + +--- + +## What gets logged (middleware) + +“Asset-shaped” paths are logged after the response completes, i.e., any endpoint starting with e.g., `/datasets`, `/models`, including `/assets`. Access to other endpoints, such as `/metrics` or `/docs` do not get logged by the middleware. This also works if the API is deployed with a path prefix, and access is captured regardless of which version of the API is used (e.g., `/v2` or latest). The middleware does *not* log *who* accessed the log in any way (though the webserver itself does log incoming requests, these are not stored to the database). + +--- + +## Table schema: `asset_access_log` + +* `id` (PK) +* `asset_id` (string) — the identifier of the asset, e.g., `data_f8aa9...`. +* `resource_type` (string) — e.g. `datasets`, `models`, etc. +* `status` (int) — HTTP status code from the response +* `accessed_at` (UTC timestamp, indexed) + +--- + +## Where the code lives + +* Middleware: **`src/middleware/access_log.py`** +* Path parsing (version/deployment prefixes): **`src/middleware/path_parse.py`** +* Top-assets router: **`src/routers/access_stats_router.py`** +* Wiring (include router, add middleware, expose /metrics): **`src/main.py`** + +--- + +## Run it + +Start the API + monitoring stack (Prometheus, Grafana): + +```bash +# helper +scripts/up.sh monitoring + +# or directly +docker compose --env-file=.env --env-file=override.env \ + -f docker-compose.yaml -f docker-compose.dev.yaml \ + --profile monitoring up -d +``` + +Open: + +* API Docs: `http://localhost:8000/docs` +* Metrics: `http://localhost:8000/metrics` +* Prometheus: `http://localhost:${PROMETHEUS_HOST_PORT:-9090}` +* Grafana: `http://localhost:${GRAFANA_HOST_PORT:-3000}` + +Generate some traffic: + +```bash +curl -s http://localhost:8000/datasets/abc >/dev/null +curl -s http://localhost:8000/datasets/v1/1 >/dev/null +curl -s http://localhost:8000/v2/models/bert >/dev/null +``` + +Check top assets (datasets): + +```bash +curl -s "http://localhost:8000/stats/top/datasets?limit=5" | jq . +``` + +--- + +## Grafana: quick setup + +Configure two data sources: + +1. **Prometheus** + + * URL: `http://prometheus:9090` + +2. **MySQL** (popularity) + + * Host: `sqlserver` + * Port: `3306` + * Database: `aiod` + * User/password: from `.env` + +**PromQL (traffic/latency examples):** + +```promql +# Requests per endpoint (1m rate) +sum by (handler) (rate(http_requests_total[1m])) + +# P95 latency by handler (5m window) +histogram_quantile( + 0.95, + sum by (le, handler) (rate(http_request_duration_seconds_bucket[5m])) +) + +# Error rate (4xx/5xx) per endpoint +sum by (handler) (rate(http_requests_total{status=~"4..|5.."}[5m])) +``` + +**MySQL (popularity examples):** + +```sql +-- Top datasets (all time) +SELECT asset_id AS asset, COUNT(*) AS hits +FROM asset_access_log +WHERE resource_type='datasets' AND status=200 +GROUP BY asset +ORDER BY hits DESC +LIMIT 10; + +-- All assets by type +SELECT resource_type AS type, asset_id AS asset, COUNT(*) AS hits +FROM asset_access_log +WHERE status=200 +GROUP BY type, asset +ORDER BY hits DESC; + +-- Top assets last 24h +SELECT resource_type AS type, asset_id AS asset, COUNT(*) AS hits +FROM asset_access_log +WHERE status=200 AND accessed_at >= NOW() - INTERVAL 1 DAY +GROUP BY type, asset +ORDER BY hits DESC +LIMIT 20; +``` + +(Optional) Provision defaults in repo: + +``` +grafana/provisioning/datasources/datasources.yml +grafana/provisioning/dashboards/dashboards.yml +grafana/provisioning/dashboards/aiod-metrics.json +``` + +--- + +## Tests + +Focused middleware tests live under `src/tests/middleware/`: + +```bash +PYTHONPATH=src pytest -q \ + src/tests/middleware/test_path_parse.py \ + src/tests/middleware/test_access_log_middleware.py +``` + +They cover: + +* Path parsing of `/datasets/abc`, `/datasets/v1/1`, `/v2/models/bert`, etc. +* That asset hits are written for 200s and 404s. +* That excluded paths (e.g., `/metrics`) are ignored. + +--- + +## Which service exposes `/stats`? + +The **apiserver** (REST API) exposes `/stats/top/{resource_type}`. It’s mounted with the other routers in `src/main.py`. + +--- diff --git a/grafana/dashboards/AIOD-API-Metrics.json b/grafana/dashboards/AIOD-API-Metrics.json new file mode 100644 index 00000000..f7f4dba4 --- /dev/null +++ b/grafana/dashboards/AIOD-API-Metrics.json @@ -0,0 +1,48 @@ +{ + "title": "AIoD API Metrics", + "uid": "aiod-api-metrics", + "timezone": "browser", + "schemaVersion": 38, + "version": 2, + "refresh": "5s", + "panels": [ + { + "type": "timeseries", + "title": "Requests per endpoint", + "datasource": { "type": "prometheus", "uid": "prometheus" }, + "targets": [ + { + "expr": "sum by (handler) (rate(http_requests_total[1m]))", + "legendFormat": "{{handler}}" + } + ], + "gridPos": { "x": 0, "y": 0, "w": 24, "h": 9 } + }, + { + "type": "table", + "title": "Top assets per type (top 20 each)", + "datasource": { "type": "mysql", "uid": "mysql" }, + "targets": [ + { + "format": "table", + "rawSql": "WITH ranked AS (\n SELECT\n resource_type,\n asset_id,\n COUNT(*) AS hits,\n ROW_NUMBER() OVER (\n PARTITION BY resource_type\n ORDER BY COUNT(*) DESC\n ) AS r\n FROM asset_access_log\n WHERE status = 200\n GROUP BY resource_type, asset_id\n)\nSELECT\n resource_type AS type,\n asset_id AS asset,\n hits\nFROM ranked\nWHERE r <= 20\nORDER BY type, hits DESC;" + } + ], + "gridPos": { "x": 0, "y": 9, "w": 24, "h": 9 }, + "options": { "showHeader": true } + }, + { + "type": "table", + "title": "Top assets overall (top 20)", + "datasource": { "type": "mysql", "uid": "mysql" }, + "targets": [ + { + "format": "table", + "rawSql": "SELECT\n CONCAT(resource_type, '/', asset_id) AS identifier,\n COUNT(*) AS hits\nFROM asset_access_log\nWHERE status = 200\nGROUP BY resource_type, asset_id\nORDER BY hits DESC\nLIMIT 20;" + } + ], + "gridPos": { "x": 0, "y": 18, "w": 24, "h": 8 }, + "options": { "showHeader": true } + } + ] +} diff --git a/grafana/provisioning/dashboards/dashboards.yml b/grafana/provisioning/dashboards/dashboards.yml new file mode 100644 index 00000000..ee1b9403 --- /dev/null +++ b/grafana/provisioning/dashboards/dashboards.yml @@ -0,0 +1,10 @@ +apiVersion: 1 +providers: + - name: default + orgId: 1 + folder: "" + type: file + disableDeletion: false + updateIntervalSeconds: 10 + options: + path: /etc/grafana/dashboards diff --git a/grafana/provisioning/datasources/mysql.yml b/grafana/provisioning/datasources/mysql.yml new file mode 100644 index 00000000..eb7dca03 --- /dev/null +++ b/grafana/provisioning/datasources/mysql.yml @@ -0,0 +1,13 @@ +apiVersion: 1 +datasources: + - uid: mysql + name: API MySQL + type: mysql + access: proxy + url: ${GRAFANA_MYSQL_HOST}:${GRAFANA_MYSQL_PORT} + database: ${GRAFANA_MYSQL_DB} + user: ${GRAFANA_MYSQL_USER} + secureJsonData: + password: ${GRAFANA_MYSQL_PASSWORD} + isDefault: false + editable: false diff --git a/grafana/provisioning/datasources/prometheus.yml b/grafana/provisioning/datasources/prometheus.yml new file mode 100644 index 00000000..f0596276 --- /dev/null +++ b/grafana/provisioning/datasources/prometheus.yml @@ -0,0 +1,9 @@ +apiVersion: 1 +datasources: + - uid: prometheus + name: Prometheus + type: prometheus + access: proxy + url: ${GRAFANA_PROMETHEUS_URL} + isDefault: true + editable: false diff --git a/mkdocs.yaml b/mkdocs.yaml index a024b936..eff98ff5 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -46,6 +46,7 @@ nav: - 'Authentication': hosting/authentication.md - 'Connectors': hosting/connectors.md - 'Synchronization': hosting/synchronization.md + - 'Monitoring': hosting/metrics.md - 'Developer Resources': - developer/index.md - 'Authentication': developer/authentication.md diff --git a/prometheus.yml b/prometheus.yml new file mode 100644 index 00000000..bdb78c39 --- /dev/null +++ b/prometheus.yml @@ -0,0 +1,7 @@ +global: + scrape_interval: 1s +scrape_configs: + - job_name: 'aiod_rest_api' + metrics_path: /metrics + static_configs: + - targets: ['app:8000'] diff --git a/pyproject.toml b/pyproject.toml index b6147a3a..51d02867 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "mysql-connector-python==9.1.0", "elasticsearch==8.16.0", "jinja2==3.1.4", + "prometheus-fastapi-instrumentator==6.1.0", ] readme = "README.md" diff --git a/src/database/model/access/__init__.py b/src/database/model/access/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/database/model/access/access_log.py b/src/database/model/access/access_log.py new file mode 100644 index 00000000..dde5e12a --- /dev/null +++ b/src/database/model/access/access_log.py @@ -0,0 +1,17 @@ +from datetime import datetime, UTC +from sqlmodel import SQLModel, Field + +from database.model.field_length import IDENTIFIER_LENGTH + + +class AssetAccessLog(SQLModel, table=True): # type: ignore[call-arg] + __tablename__ = "asset_access_log" + id: int | None = Field(default=None, primary_key=True) + asset_id: str = Field( + max_length=IDENTIFIER_LENGTH, schema_extra=dict(examples=["data_p7v02a70CbBGKk29T8przBjf"]) + ) + resource_type: str = Field( + max_length=IDENTIFIER_LENGTH, schema_extra=dict(examples=["Datasets", "Models"]) + ) + status: int = Field(description="HTTP Status code of the request.") + accessed_at: datetime = Field(default_factory=lambda: datetime.now(UTC), index=True) diff --git a/src/main.py b/src/main.py index a68af59d..87a91b9b 100644 --- a/src/main.py +++ b/src/main.py @@ -9,7 +9,7 @@ import logging from pathlib import Path -import pkg_resources +from importlib.metadata import version as pkg_version, PackageNotFoundError import uvicorn from fastapi import Depends, FastAPI, HTTPException from fastapi.responses import HTMLResponse @@ -43,6 +43,9 @@ bookmark_router, asset_router, ) +from prometheus_fastapi_instrumentator import Instrumentator +from middleware.access_log import AccessLogMiddleware +from routers.access_stats_router import create as create_access_stats_router from versioning import ( versions, add_version_to_openapi, @@ -97,6 +100,8 @@ def counts() -> dict: ): app.include_router(router.create(url_prefix, version)) + app.include_router(create_access_stats_router(url_prefix)) + def create_app() -> FastAPI: """Create the FastAPI application, complete with routes.""" @@ -120,8 +125,11 @@ def create_app() -> FastAPI: raise ValueError(f"dev.taxonomy must be a path to a file, but is {taxonomy_path!r}.") synchronize_taxonomy_from_file(taxonomy_file) - pyproject_toml = pkg_resources.get_distribution("aiod_metadata_catalogue") - app = build_app(url_prefix=DEV_CONFIG.get("url_prefix", ""), version=pyproject_toml.version) + try: + dist_version = pkg_version("aiod_metadata_catalogue") + except PackageNotFoundError: + dist_version = "dev" + app = build_app(url_prefix=DEV_CONFIG.get("url_prefix", ""), version=dist_version) return app @@ -149,23 +157,34 @@ def build_app(*, url_prefix: str = "", version: str = "dev"): version="latest", **kwargs, ) - add_routes(main_app, version=Version.LATEST) - main_app.add_exception_handler(HTTPException, http_exception_handler) - add_version_to_openapi(main_app, root_path=url_prefix) - - for version, info in versions.items(): - if info.retired: - continue - app = FastAPI( - title=f"AIoD Metadata Catalogue {version}", - version=f"{version}", - **kwargs, + versioned_apps = [ + ( + FastAPI( + title=f"AIoD Metadata Catalogue {version}", + version=f"{version}", + **kwargs, + ), + version, ) + for version, info in versions.items() + if not info.retired + ] + for app, version in [(main_app, Version.LATEST)] + versioned_apps: add_routes(app, version=version) app.add_exception_handler(HTTPException, http_exception_handler) add_deprecation_and_sunset_middleware(app) add_version_to_openapi(app, root_path=url_prefix) - main_app.mount(f"/{version}", app) + + Instrumentator().instrument(main_app).expose( + main_app, endpoint="/metrics", include_in_schema=False + ) + # Since all traffic goes through the main app, this middleware only + # needs to be registered with the main app and not the mounted apps. + main_app.add_middleware(AccessLogMiddleware) + + for app, _ in versioned_apps: + main_app.mount(f"/{app.version}", app) + return main_app diff --git a/src/middleware/__init__.py b/src/middleware/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/middleware/access_log.py b/src/middleware/access_log.py new file mode 100644 index 00000000..edf744ac --- /dev/null +++ b/src/middleware/access_log.py @@ -0,0 +1,28 @@ +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +from database.session import DbSession +from database.model.access.access_log import AssetAccessLog +from middleware.path_parse import parse_asset_from_path + + +class AccessLogMiddleware(BaseHTTPMiddleware): + """Write one AssetAccessLog row for any asset route.""" + + async def dispatch(self, request: Request, call_next): + response: Response = await call_next(request) + + parsed = parse_asset_from_path(request.url.path) + if parsed: + resource_type, asset_id = parsed + entry = AssetAccessLog( + asset_id=asset_id, + resource_type=resource_type, + status=response.status_code, + ) + with DbSession() as sess: + sess.add(entry) + sess.commit() + + return response diff --git a/src/middleware/path_parse.py b/src/middleware/path_parse.py new file mode 100644 index 00000000..6a1c0d95 --- /dev/null +++ b/src/middleware/path_parse.py @@ -0,0 +1,52 @@ +import re +from typing import Optional + +from config import DEV_CONFIG +from routers.resource_routers import versioned_routers +from versioning import Version + +_asset_abbreviation_to_plural = { + r.resource_class.__abbreviation__: r.resource_name_plural + for v in Version + for r in versioned_routers[v] +} + +ADDITIONAL_INCLUDES = { + "assets", + "agents", + "ai_assets", + "ai_resources", +} +INCLUDE = set(_asset_abbreviation_to_plural.values()) | ADDITIONAL_INCLUDES + + +def parse_asset_from_path(path: str) -> Optional[tuple[str, str]]: + """If the path represents direct asset access, return the asset type and identifier. + + Direct asset access means that an asset is requested by its identifier, + either through the resource type's router or a different one (like the generic one). + If the path does not represent direct access, returns None. + + Examples of direct access requests: + /v2/datasets/123 -> ("datasets", "123") + /assets/datasets/123 -> ("datasets", "123") + /aiod-api/v10/models/bert -> ("models", "bert") + + Access to endpoints like `/stats`, `/metrics`, `/docs` and so on return None. + """ + prefix = f"/{DEV_CONFIG.get('url_prefix', '')}" + path = path.removeprefix(prefix).strip("/") + + version_match = r"v\d+" + asset_type_match = "|".join(f"(?:{asset_type})" for asset_type in INCLUDE) + identifier_match = "\w{3,4}_[a-zA-Z0-9]{24}" + path_match = f"({version_match})?/?({asset_type_match})/({identifier_match})" + if (match := re.match(path_match, path)) is None: + return None + + version, asset_type, identifier = match.groups() + if asset_type in ADDITIONAL_INCLUDES: + prefix, _ = identifier.split("_") + asset_type = _asset_abbreviation_to_plural[prefix] + + return asset_type, identifier diff --git a/src/routers/access_stats_router.py b/src/routers/access_stats_router.py new file mode 100644 index 00000000..1c0fc066 --- /dev/null +++ b/src/routers/access_stats_router.py @@ -0,0 +1,59 @@ +from fastapi import APIRouter, Query +from sqlalchemy import func +from sqlmodel import select + +from database.session import DbSession +from database.model.access.access_log import AssetAccessLog + + +def create(url_prefix: str = "") -> APIRouter: + router = APIRouter(prefix=f"{url_prefix}/stats", tags=["stats"]) + + @router.get("/top/{resource_type}") + def top_assets(resource_type: str, limit: int = Query(10, ge=1, le=1000)): + stmt = ( + select( + AssetAccessLog.asset_id, + func.count().label("hits"), + ) + .where( + AssetAccessLog.resource_type == resource_type, + AssetAccessLog.status == 200, + ) + .group_by(AssetAccessLog.asset_id) + .order_by(func.count().desc()) + .limit(limit) + ) + with DbSession() as s: + rows = s.exec(stmt).all() + + return [{"asset_id": r[0], "hits": int(r[1])} for r in rows] + + @router.get("/top/all") + def top_all(limit_per_type: int = Query(20, ge=1, le=1000)): + ranked = ( + select( + AssetAccessLog.resource_type, + AssetAccessLog.asset_id, + func.count().label("hits"), + func.row_number() + .over( + partition_by=AssetAccessLog.resource_type, + order_by=func.count().desc(), + ) + .label("rnk"), + ) + .where(AssetAccessLog.status == 200) + .group_by(AssetAccessLog.resource_type, AssetAccessLog.asset_id) + ).subquery("ranked") + + stmt = select(ranked.c.resource_type, ranked.c.asset_id, ranked.c.hits).where( + ranked.c.rnk <= limit_per_type + ) + + with DbSession() as s: + rows = s.exec(stmt).all() + + return [{"type": r[0], "asset": r[1], "hits": int(r[2])} for r in rows] + + return router diff --git a/src/tests/middleware/test_access_log_middleware.py b/src/tests/middleware/test_access_log_middleware.py new file mode 100644 index 00000000..f2dc3e71 --- /dev/null +++ b/src/tests/middleware/test_access_log_middleware.py @@ -0,0 +1,69 @@ +from fastapi import FastAPI +from starlette.testclient import TestClient +import pytest + +from middleware.access_log import AccessLogMiddleware + +# simple in-memory recorder instead of touching a real DB +class _FakeSession: + def __init__(self, store): self._store = store + def __enter__(self): return self + def __exit__(self, *a): return False + def add(self, entry): self._store.append(entry) + def commit(self): pass + + +def test_middleware_logs_asset_hit(monkeypatch): + app = FastAPI() + app.add_middleware(AccessLogMiddleware) + + @app.get("/datasets/data_foobar12foobar12foobar12") + def _ok(): return {"ok": True} + + written = [] + import middleware.access_log as m + monkeypatch.setattr(m, "DbSession", lambda: _FakeSession(written), raising=True) + + client = TestClient(app) + r = client.get("/datasets/data_foobar12foobar12foobar12") + assert r.status_code == 200 + + assert len(written) == 1 + entry = written[0] + assert entry.resource_type == "datasets" + assert entry.asset_id == "data_foobar12foobar12foobar12" + assert entry.status == 200 + + +def test_middleware_logs_404_asset(monkeypatch): + app = FastAPI() + app.add_middleware(AccessLogMiddleware) + + written = [] + import middleware.access_log as m + monkeypatch.setattr(m, "DbSession", lambda: _FakeSession(written), raising=True) + + client = TestClient(app) + r = client.get("/v2/ml_models/mdl_bertbertbertbertbertbert") + assert r.status_code == 404 + + assert len(written) == 1 + entry = written[0] + assert entry.resource_type == "ml_models" + assert entry.asset_id == "mdl_bertbertbertbertbertbert" + assert entry.status == 404 + +def test_middleware_ignores_non_asset(monkeypatch): + app = FastAPI() + app.add_middleware(AccessLogMiddleware) + + @app.get("/metrics") + def _metrics(): return "ok" + + written = [] + import middleware.access_log as m + monkeypatch.setattr(m, "DbSession", lambda: _FakeSession(written), raising=True) + + client = TestClient(app) + assert client.get("/metrics").status_code == 200 + assert written == [] diff --git a/src/tests/middleware/test_path_parse.py b/src/tests/middleware/test_path_parse.py new file mode 100644 index 00000000..5b52100d --- /dev/null +++ b/src/tests/middleware/test_path_parse.py @@ -0,0 +1,33 @@ +import pytest +from middleware.path_parse import parse_asset_from_path + +@pytest.mark.parametrize("path,expected", [ + # typed routes (asset_id has no API version prefix and no type prefix) + ("/datasets/data_foobar12foobar12foobar12", ("datasets", "data_foobar12foobar12foobar12")), + ("/datasets/data_foobar12foobar12foobar12/", ("datasets", "data_foobar12foobar12foobar12")), + ("/v2/datasets/data_foobar12foobar12foobar12", ("datasets", "data_foobar12foobar12foobar12")), + + # generic asset routes + ("/assets/data_foobar12foobar12foobar12", ("datasets", "data_foobar12foobar12foobar12")), + ("/assets/proj_foobar12foobar12foobar12", ("projects", "proj_foobar12foobar12foobar12")), + + # non-asset / excluded + ("/metrics", None), + ("/docs", None), + ("/v2/docs", None), + ("/counts/v1", None), + ("/", None), +]) +def test_parse_asset_from_path(path, expected): + assert parse_asset_from_path(path) == expected + + +@pytest.mark.parametrize("path,expected", [ + ("/aiod-api/v10/ml_models/mdl_foobar12foobar12foobar12", ("ml_models", "mdl_foobar12foobar12foobar12")), + ("/aiod-api/ml_models/mdl_foobar12foobar12foobar12", ("ml_models", "mdl_foobar12foobar12foobar12")), + ("/aiod-api/assets/mdl_foobar12foobar12foobar12", ("ml_models", "mdl_foobar12foobar12foobar12")), +]) +def test_parse_asset_from_path_with_prefix(path, expected): + from config import DEV_CONFIG + DEV_CONFIG['url_prefix'] = 'aiod-api' + assert parse_asset_from_path(path) == expected