Skip to content

Commit 9e4a3a3

Browse files
authored
feat: add distributed tracing, prometheus and grafana (#4)
Adds the OpenTelemetry SDK for distributed tracing and ties it into logging by extending the log configuration with the trace ID and whether the trace was sampled. The docker-compose file has grown considerably, since the whole Jaeger ecosystem was imported; it needs some time to load Cassandra and come up. Locally, one can still use `make serve` and avoid the tracing ecosystem. * chore(ci): do not upload to codecov on test
1 parent 1c50f4c commit 9e4a3a3

File tree

15 files changed

+800
-755
lines changed

15 files changed

+800
-755
lines changed

.github/workflows/test.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,3 @@ jobs:
4848

4949
- name: Run tests
5050
run: make test
51-
52-
- name: Upload coverage reports to Codecov
53-
if: inputs.upload-coverage
54-
uses: codecov/codecov-action@v3
55-
env:
56-
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

README.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,48 @@ The Docker setup mounts the local `./models` directory to `/models` inside the c
169169

170170
If no models are found when starting the container, you'll be prompted to download the small model automatically.
171171

172+
## Distributed Tracing
173+
174+
Babeltron supports distributed tracing with OpenTelemetry and Jaeger. The application is configured to send traces to the OpenTelemetry Collector, which forwards them to Jaeger.
175+
176+
### Configuration
177+
178+
Tracing can be configured using the following environment variables:
179+
180+
- `OTLP_MODE`: The OpenTelemetry protocol mode (`otlp-grpc` or `otlp-http`)
181+
- `OTEL_SERVICE_NAME`: The name of the service in traces (default: `babeltron`)
182+
- `OTLP_GRPC_ENDPOINT`: The endpoint for the OpenTelemetry Collector using gRPC (default: `otel-collector:4317`)
183+
- `OTLP_HTTP_ENDPOINT`: The endpoint for the OpenTelemetry Collector using HTTP (default: `http://otel-collector:4318/v1/traces`)
184+
185+
### Accessing Jaeger UI
186+
187+
When running with Docker Compose, you can access the Jaeger UI at:
188+
189+
```
190+
http://localhost:16686
191+
```
192+
193+
### Tracing Features
194+
195+
The distributed tracing implementation provides insights into:
196+
197+
- Request flow through the API
198+
- Detailed timing of translation steps:
199+
- Tokenization
200+
- Model inference
201+
- Decoding
202+
- Error details and context
203+
- Cross-service communication
204+
205+
### Disabling Tracing
206+
207+
To disable tracing, set the `OTLP_GRPC_ENDPOINT` environment variable to `disabled`:
208+
209+
```yaml
210+
environment:
211+
- OTLP_GRPC_ENDPOINT=disabled
212+
```
213+
172214
## Contributing
173215
174216
Install pre-commit hooks with `make pre-commit-install` and refer to the [CONTRIBUTING.md](docs/CONTRIBUTING.md) file for more information.

babeltron/app/main.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
1-
import os
2-
from contextlib import asynccontextmanager
31
from importlib.metadata import version
4-
from typing import AsyncIterator
52

63
from fastapi import FastAPI, Response
74
from fastapi.middleware.cors import CORSMiddleware
8-
from fastapi_cache import FastAPICache
9-
from fastapi_cache.backends.inmemory import InMemoryBackend
10-
from fastapi_cache.backends.redis import RedisBackend
11-
from redis import asyncio as aioredis
125

136
from babeltron.app.monitoring import PrometheusMiddleware, metrics_endpoint
7+
from babeltron.app.tracing import setup_jaeger
148
from babeltron.app.utils import include_routers
159

1610
try:
@@ -19,23 +13,6 @@
1913
__version__ = "0.1.0-dev"
2014

2115

22-
@asynccontextmanager
23-
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
24-
cache_url = os.environ.get("CACHE_URL", "")
25-
26-
if cache_url.startswith("in-memory"):
27-
FastAPICache.init(InMemoryBackend(), prefix="babeltron")
28-
print("Using in-memory cache")
29-
elif cache_url.startswith("redis"):
30-
redis = aioredis.from_url(cache_url)
31-
FastAPICache.init(RedisBackend(redis), prefix="babeltron")
32-
print("Using Redis cache")
33-
else:
34-
print("No cache_url provided, not using cache")
35-
36-
yield
37-
38-
3916
app = FastAPI(
4017
title="Babeltron Translation API",
4118
description="API for machine translation using NLLB models",
@@ -52,7 +29,6 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
5229
docs_url="/docs",
5330
redoc_url="/redoc",
5431
openapi_url="/openapi.json",
55-
lifespan=lifespan,
5632
)
5733

5834
# Configure CORS
@@ -64,6 +40,9 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
6440
allow_headers=["*"], # Allows all headers
6541
)
6642

43+
# Set up Jaeger tracing
44+
setup_jaeger(app)
45+
6746
# Include all routers
6847
include_routers(app)
6948

@@ -80,4 +59,8 @@ async def metrics():
8059
if __name__ == "__main__":
8160
import uvicorn
8261

83-
uvicorn.run(app, host="0.0.0.0", port=8000)
62+
log_config = uvicorn.config.LOGGING_CONFIG
63+
log_config["formatters"]["access"][
64+
"fmt"
65+
] = "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] [trace_id=%(otelTraceID)s span_id=%(otelSpanID)s] - %(message)s"
66+
uvicorn.run(app, host="0.0.0.0", port=8000, log_config=log_config)

babeltron/app/routers/healthcheck.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ class HealthResponse(BaseModel):
1616
version: Optional[str] = None
1717

1818

19-
@router.get("/healthcheck", summary="Healthcheck")
2019
@router.get(
2120
"/healthz",
2221
summary="Check API health",
@@ -34,7 +33,6 @@ class ReadinessResponse(BaseModel):
3433
error: Optional[str] = None
3534

3635

37-
@router.get("/readiness", summary="Readiness Probe")
3836
@router.get(
3937
"/readyz",
4038
summary="Check API Readiness",

babeltron/app/routers/translate.py

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,41 @@
1+
import logging
12
import os
3+
import time
24

35
import torch
46
from fastapi import APIRouter, HTTPException, status
5-
from fastapi_cache.decorator import cache
7+
from opentelemetry import trace
68
from pydantic import BaseModel, Field
79
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
810

911
from babeltron.app.monitoring import track_dynamic_translation_metrics
10-
from babeltron.app.utils import ORJsonCoder, cache_key_builder, get_model_path
12+
from babeltron.app.utils import get_model_path
1113

1214
router = APIRouter(tags=["Translation"])
1315

1416
MODEL_COMPRESSION_ENABLED = os.environ.get(
1517
"MODEL_COMPRESSION_ENABLED", "true"
1618
).lower() in ("true", "1", "yes")
17-
CACHE_TTL_SECONDS = int(os.environ.get("CACHE_TTL_SECONDS", "3600"))
1819

1920
try:
2021
MODEL_PATH = get_model_path()
21-
print(f"Loading model from: {MODEL_PATH}")
22+
logging.info(f"Loading model from: {MODEL_PATH}")
2223
model = M2M100ForConditionalGeneration.from_pretrained(MODEL_PATH)
2324

2425
# Apply FP16 compression if enabled and supported
2526
if MODEL_COMPRESSION_ENABLED and torch.cuda.is_available():
26-
print("Applying FP16 model compression")
27+
logging.info("Applying FP16 model compression")
2728
model = model.half() # Convert to FP16 precision
2829
model = model.to("cuda") # Move to GPU
2930
elif MODEL_COMPRESSION_ENABLED:
30-
print("FP16 compression enabled but GPU not available, using CPU")
31+
logging.info("FP16 compression enabled but GPU not available, using CPU")
3132
else:
32-
print("Model compression disabled")
33+
logging.info("Model compression disabled")
3334

3435
tokenizer = M2M100Tokenizer.from_pretrained(MODEL_PATH)
35-
print("Model loaded successfully")
36+
logging.info("Model loaded successfully")
3637
except Exception as e:
37-
print(f"Error loading model: {e}")
38+
logging.error(f"Error loading model: {e}")
3839
model = None
3940
tokenizer = None
4041

@@ -77,31 +78,75 @@ class TranslationResponse(BaseModel):
7778
response_description="The translated text in the target language",
7879
status_code=status.HTTP_200_OK,
7980
)
80-
@cache(expire=CACHE_TTL_SECONDS, key_builder=cache_key_builder, coder=ORJsonCoder)
8181
@track_dynamic_translation_metrics()
8282
async def translate(request: TranslationRequest):
83+
# Get current span from context
84+
current_span = trace.get_current_span()
85+
# Add request attributes to the current span
86+
current_span.set_attribute("src_lang", request.src_lang)
87+
current_span.set_attribute("tgt_lang", request.tgt_lang)
88+
current_span.set_attribute("text_length", len(request.text))
89+
90+
logging.info(f"Translating text from {request.src_lang} to {request.tgt_lang}")
91+
8392
if model is None or tokenizer is None:
93+
current_span.set_attribute("error", "model_not_loaded")
94+
logging.error("Translation model not loaded")
8495
raise HTTPException(
8596
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
8697
detail="Translation model not loaded. Please check server logs.",
8798
)
8899

89100
try:
90-
tokenizer.src_lang = request.src_lang
91-
encoded_text = tokenizer(request.text, return_tensors="pt")
101+
tracer = trace.get_tracer(__name__)
102+
103+
with tracer.start_as_current_span("tokenization") as tokenize_span:
104+
start_time = time.time()
105+
tokenizer.src_lang = request.src_lang
106+
encoded_text = tokenizer(request.text, return_tensors="pt")
107+
tokenize_span.set_attribute(
108+
"token_count", encoded_text["input_ids"].shape[1]
109+
)
110+
tokenize_span.set_attribute(
111+
"duration_ms", (time.time() - start_time) * 1000
112+
)
92113

93-
# Move input to GPU if model is on GPU
94114
if torch.cuda.is_available() and next(model.parameters()).is_cuda:
95-
encoded_text = {k: v.to("cuda") for k, v in encoded_text.items()}
96-
97-
generated_tokens = model.generate(
98-
**encoded_text, forced_bos_token_id=tokenizer.get_lang_id(request.tgt_lang)
99-
)
100-
translation = tokenizer.batch_decode(
101-
generated_tokens, skip_special_tokens=True
102-
)[0]
115+
with tracer.start_as_current_span("move_to_gpu") as gpu_span:
116+
start_time = time.time()
117+
encoded_text = {k: v.to("cuda") for k, v in encoded_text.items()}
118+
gpu_span.set_attribute("duration_ms", (time.time() - start_time) * 1000)
119+
120+
with tracer.start_as_current_span("model_inference") as inference_span:
121+
start_time = time.time()
122+
generated_tokens = model.generate(
123+
**encoded_text,
124+
forced_bos_token_id=tokenizer.get_lang_id(request.tgt_lang),
125+
)
126+
inference_time = time.time() - start_time
127+
inference_span.set_attribute("inference_time_seconds", inference_time)
128+
inference_span.set_attribute(
129+
"output_token_count", generated_tokens.shape[1]
130+
)
131+
inference_span.set_attribute("duration_ms", inference_time * 1000)
132+
133+
with tracer.start_as_current_span("decoding") as decode_span:
134+
start_time = time.time()
135+
translation = tokenizer.batch_decode(
136+
generated_tokens, skip_special_tokens=True
137+
)[0]
138+
decode_span.set_attribute("duration_ms", (time.time() - start_time) * 1000)
139+
140+
current_span.set_attribute("translation_length", len(translation))
141+
142+
logging.info(f"Translation completed: {len(translation)} characters")
103143
return {"translation": translation}
104144
except Exception as e:
145+
current_span.record_exception(e)
146+
current_span.set_attribute("error", str(e))
147+
current_span.set_attribute("error_type", type(e).__name__)
148+
149+
logging.error(f"Error translating text: {e}", exc_info=True)
105150
raise HTTPException(
106151
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
107152
detail=f"Error during translation: {str(e)}",

babeltron/app/tracing.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import logging
2+
import os
3+
4+
from fastapi import FastAPI
5+
from opentelemetry import trace
6+
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
7+
OTLPSpanExporter as OTLPSpanExporterGRPC,
8+
)
9+
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
10+
OTLPSpanExporter as OTLPSpanExporterHTTP,
11+
)
12+
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
13+
from opentelemetry.instrumentation.logging import LoggingInstrumentor
14+
from opentelemetry.sdk.trace import TracerProvider
15+
from opentelemetry.sdk.trace.export import BatchSpanProcessor
16+
17+
# Check if we're in a test environment
18+
IN_TEST = os.environ.get("PYTEST_CURRENT_TEST") is not None
19+
20+
OTLP_MODE = os.environ.get("OTLP_MODE", "otlp-grpc")
21+
OTLP_GRPC_ENDPOINT = os.environ.get("OTLP_GRPC_ENDPOINT", "otel-collector:4317")
22+
OTLP_HTTP_ENDPOINT = os.environ.get(
23+
"OTLP_HTTP_ENDPOINT", "http://otel-collector:4318/v1/traces"
24+
)
25+
26+
27+
def setup_jaeger(app: FastAPI, log_correlation: bool = True) -> None:
    """Configure OpenTelemetry tracing for *app* and export spans via OTLP.

    Sets a global ``TracerProvider``, attaches a ``BatchSpanProcessor`` with
    either the gRPC or HTTP OTLP exporter (chosen by ``OTLP_MODE``), and
    instruments the FastAPI app. Health/metrics/docs URLs are excluded from
    tracing.

    Args:
        app: The FastAPI application to instrument.
        log_correlation: When True, instrument the logging module so log
            records carry the current trace/span IDs.

    Returns:
        None. Setup is skipped entirely under pytest (``IN_TEST``) or when
        ``OTLP_GRPC_ENDPOINT`` is set to ``disabled``.
    """
    # Skip setup if we're in a test environment
    if IN_TEST:
        logging.info("Skipping OpenTelemetry setup in test environment")
        return

    # Check if tracing is disabled
    if OTLP_GRPC_ENDPOINT.lower() == "disabled":
        logging.info("OpenTelemetry tracing is disabled")
        return

    tracer = TracerProvider()
    trace.set_tracer_provider(tracer)

    # Pick the exporter and remember which endpoint it actually targets, so
    # the log line below reports the real destination.
    if OTLP_MODE == "otlp-http":
        endpoint = OTLP_HTTP_ENDPOINT
        exporter = OTLPSpanExporterHTTP(endpoint=endpoint)
    else:
        # gRPC is the default; any unrecognized OTLP_MODE falls back to it
        # (same behavior as the previous duplicated if/else branches).
        endpoint = OTLP_GRPC_ENDPOINT
        exporter = OTLPSpanExporterGRPC(endpoint=endpoint, insecure=True)

    tracer.add_span_processor(BatchSpanProcessor(exporter))

    if log_correlation:
        LoggingInstrumentor().instrument(set_logging_format=True)

    FastAPIInstrumentor.instrument_app(
        app,
        tracer_provider=tracer,
        excluded_urls="/metrics,/healthz,/readyz,/docs,/redoc,/openapi.json",
    )

    # Fix: previously always logged the gRPC endpoint, even in otlp-http mode.
    logging.info(f"OpenTelemetry tracing enabled with endpoint: {endpoint}")

0 commit comments

Comments
 (0)