|
1 | 1 | import logging |
| 2 | +import os |
2 | 3 | from contextlib import asynccontextmanager |
3 | 4 |
|
4 | 5 | import urllib3 |
5 | 6 | from fastapi import FastAPI |
| 7 | +from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess |
6 | 8 |
|
7 | 9 | from cogstack_model_gateway.common.config import load_config |
8 | 10 | from cogstack_model_gateway.common.db import DatabaseManager |
9 | 11 | from cogstack_model_gateway.common.logging import configure_logging |
10 | 12 | from cogstack_model_gateway.common.object_store import ObjectStoreManager |
11 | 13 | from cogstack_model_gateway.common.queue import QueueManager |
12 | 14 | from cogstack_model_gateway.common.tasks import TaskManager |
| 15 | +from cogstack_model_gateway.gateway.prometheus.metrics import gateway_requests_total |
13 | 16 | from cogstack_model_gateway.gateway.routers import models, tasks |
14 | 17 |
|
15 | 18 | log = logging.getLogger("cmg.gateway") |
16 | 19 |
|
17 | 20 |
|
| 21 | +def make_metrics_app(): |
| 22 | + """Create a registry for each process and aggregate metrics with MultiProcessCollector.""" |
| 23 | + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: |
| 24 | + raise RuntimeError( |
| 25 | + "Environment variable PROMETHEUS_MULTIPROC_DIR is not set. Please set it to a directory" |
| 26 | + " where Prometheus can store metrics data for multiprocess mode, e.g. /tmp/prometheus/" |
| 27 | + ) |
| 28 | + os.makedirs(os.environ["PROMETHEUS_MULTIPROC_DIR"], exist_ok=True) |
| 29 | + registry = CollectorRegistry() |
| 30 | + multiprocess.MultiProcessCollector(registry) |
| 31 | + return make_asgi_app(registry=registry) |
| 32 | + |
| 33 | + |
18 | 34 | @asynccontextmanager |
19 | 35 | async def lifespan(app: FastAPI): |
20 | 36 | """Setup gateway and initialize database, object store, queue, and task manager connections.""" |
@@ -73,6 +89,15 @@ async def lifespan(app: FastAPI): |
73 | 89 | app.include_router(models.router) |
74 | 90 | app.include_router(tasks.router) |
75 | 91 |
|
| 92 | +app.mount("/metrics", make_metrics_app()) |
| 93 | + |
| 94 | + |
| 95 | +@app.middleware("http") |
| 96 | +async def prometheus_request_counter(request, call_next): |
| 97 | + response = await call_next(request) |
| 98 | + gateway_requests_total.labels(method=request.method, endpoint=request.url.path).inc() |
| 99 | + return response |
| 100 | + |
76 | 101 |
|
77 | 102 | @app.get("/") |
78 | 103 | async def root(): |
|
0 commit comments