Skip to content

Commit 03db106

Browse files
declark1, PRATIBHA-Moogi, and joerunde
authored
feat: add OpenTelemetry tracing support to router (#55)
This PR enables OpenTelemetry tracing support at the router level for the `prefill`, `generate`, and `generate_stream` methods.

Signed-off-by: Daniel Clark <[email protected]>
Signed-off-by: Joe Runde <[email protected]>
Co-authored-by: PRATIBHA MOOGI <[email protected]>
Co-authored-by: Daniel Clark <[email protected]>
Co-authored-by: Joe Runde <[email protected]>
1 parent 8abb71d commit 03db106

File tree

8 files changed

+474
-85
lines changed

8 files changed

+474
-85
lines changed

Cargo.lock

Lines changed: 341 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

launcher/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ ctrlc = { version = "3.4.2", features = ["termination"] }
1111
nix = { version = "0.28.0", features = ["process", "signal"] }
1212
serde_json = "^1.0.114"
1313
tracing = "0.1.40"
14-
tracing-subscriber = { version = "0.3.18", features = ["json"] }
15-
uuid = { version = "1.7.0", features = ["v4", "fast-rng"] }
14+
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
15+
uuid = { version = "1.7.0", features = ["v4", "fast-rng"] }

launcher/src/main.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ struct Args {
8989
// Default for default_include_stop_seqs is true for now, for backwards compatibility
9090
#[clap(default_value = "true", long, env, action = clap::ArgAction::Set)]
9191
default_include_stop_seqs: bool,
92+
#[clap(long, env)]
93+
otlp_endpoint: Option<String>,
9294
}
9395

9496
fn main() -> ExitCode {
@@ -107,7 +109,6 @@ fn main() -> ExitCode {
107109

108110
// Pattern match configuration
109111
let args = Args::parse();
110-
111112
if args.json_output {
112113
tracing_subscriber::fmt()
113114
.json()
@@ -326,6 +327,12 @@ fn main() -> ExitCode {
326327
argv.push("--json-output".to_string());
327328
}
328329

330+
// OpenTelemetry
331+
if let Some(otlp_endpoint) = args.otlp_endpoint {
332+
argv.push("--otlp-endpoint".to_string());
333+
argv.push(otlp_endpoint);
334+
}
335+
329336
if args.output_special_tokens {
330337
argv.push("--output-special-tokens".into());
331338
}
@@ -400,6 +407,7 @@ fn main() -> ExitCode {
400407
};
401408
}
402409

410+
// Graceful termination
403411
terminate_gracefully(&mut webserver, shutdown.clone(), shutdown_receiver);
404412

405413
exit_code

router/Cargo.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ path = "src/main.rs"
1515

1616
[dependencies]
1717
axum = { version = "0.6.20", features = ["json"] }
18+
axum-tracing-opentelemetry = "0.10.0"
1819
text-generation-client = { path = "client" }
1920
clap = { version = "^4.5.2", features = ["derive", "env"] }
2021
futures = "^0.3.30"
@@ -38,12 +39,15 @@ tokio = { version = "1.36.0", features = ["rt", "rt-multi-thread", "parking_lot"
3839
tokio-rustls = "^0.25.0"
3940
rustls = "0.22.2"
4041
tracing = "^0.1.40"
41-
tracing-subscriber = { version = "0.3.18", features = ["json"] }
4242
prost = "^0.12.3"
4343
tonic = { version = "^0.11.0", features = ["tls"] }
44+
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
45+
tracing-opentelemetry = "0.19.0"
4446
tokio-stream ="^0.1.14"
4547
unicode-segmentation = "^1.11.0"
4648
unicode-truncate = "^0.2.0"
49+
opentelemetry = { version = "0.19.0", features = ["rt-tokio"] }
50+
opentelemetry-otlp = "0.12.0"
4751

4852
[build-dependencies]
49-
tonic-build = "^0.11.0"
53+
tonic-build = "^0.11.0"

router/client/src/client.rs

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,12 @@ impl Client {
5252
})
5353
}
5454

55-
/// Returns a list of uris or unix sockets of all shards
56-
#[instrument(skip(self))]
55+
// Returns a list of uris or unix sockets of all shards
56+
//#[instrument(skip(self))]
57+
// Below function is a method only used once during pod startup and not tied to any external requests/transactions, disabling otel instrumentation
5758
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
5859
let request = tonic::Request::new(ServiceDiscoveryRequest {});
59-
let response = self
60-
.stub
61-
.service_discovery(request)
62-
.instrument(info_span!("service_discovery"))
63-
.await?;
60+
let response = self.stub.service_discovery(request).await?;
6461
let urls = response
6562
.into_inner()
6663
.urls
@@ -75,26 +72,20 @@ impl Client {
7572
}
7673

7774
/// Clear the past generations cache
78-
#[instrument(skip(self))]
75+
//#[instrument(skip(self))]
76+
//Below function is a method only used once during pod startup and not tied to any external requests/transactions, disabling otel instrumentation
7977
pub async fn clear_cache(&mut self) -> Result<()> {
8078
let request = tonic::Request::new(ClearCacheRequest {});
81-
self.stub
82-
.clear_cache(request)
83-
.instrument(info_span!("clear_cache"))
84-
.await?;
79+
self.stub.clear_cache(request).await?;
8580
Ok(())
8681
}
8782

8883
/// Get shard model info
89-
#[instrument(skip(self))]
84+
// Below function is a method only used once during pod startup and not tied to any external requests/transactions, disabling otel instrumentation
85+
//#[instrument(skip(self))]
9086
pub async fn model_info(&mut self) -> Result<(ModelType, u32, bool, MemoryScalingModel)> {
9187
let request = tonic::Request::new(ModelInfoRequest {});
92-
let response = self
93-
.stub
94-
.model_info(request)
95-
.instrument(info_span!("model_info"))
96-
.await?
97-
.into_inner();
88+
let response = self.stub.model_info(request).await?.into_inner();
9889
ModelType::try_from(response.model_type)
9990
.map(|mt| {
10091
(
@@ -108,7 +99,7 @@ impl Client {
10899
}
109100

110101
/// Get model health
111-
#[instrument(skip(self))]
102+
//#[instrument(skip(self))]
112103
pub async fn health(&mut self) -> Result<HealthResponse> {
113104
let request = tonic::Request::new(HealthRequest {});
114105
let response = self.stub.health(request).await?.into_inner();
@@ -120,20 +111,15 @@ impl Client {
120111
pub async fn prefix_lookup(&mut self, prefix_id: String) -> Result<u32> {
121112
let mut request = tonic::Request::new(PrefixLookupRequest { prefix_id });
122113
request.set_timeout(PREFIX_LOOKUP_TIMEOUT);
123-
let response = self
124-
.stub
125-
.prefix_lookup(request)
126-
.instrument(info_span!("prefix_lookup"))
127-
.await?
128-
.into_inner();
114+
let response = self.stub.prefix_lookup(request).await?.into_inner();
129115
Ok(response.prefix_length)
130116
}
131117

132118
/// Generate one token for each request in the given batch
133119
///
134120
/// Returns first generated token for each request in the batch, id of the next cached batch,
135121
/// and input token info if requested
136-
#[instrument(skip(self))]
122+
#[instrument(skip_all, fields(batch_id = &batch.id))]
137123
pub async fn prefill(
138124
&mut self,
139125
batch: Batch,
@@ -143,12 +129,7 @@ impl Client {
143129
batch: Some(batch),
144130
to_prune,
145131
});
146-
let response = self
147-
.stub
148-
.prefill(request)
149-
.instrument(info_span!("generate"))
150-
.await?
151-
.into_inner();
132+
let response = self.stub.prefill(request).await?.into_inner();
152133
let result = response
153134
.result
154135
.ok_or_else(|| ClientError::Generation("Unexpected empty response".into()))?;
@@ -164,18 +145,13 @@ impl Client {
164145
/// Generate one token for each request in the given cached batch(es)
165146
///
166147
/// Returns next generated token of each request in the batches and id of the next cached batch
167-
#[instrument(skip(self))]
148+
//#[instrument(skip(self))] <You can uncomment it for getting traces at each next_token() level
168149
pub async fn next_token(
169150
&mut self,
170151
batches: Vec<CachedBatch>,
171152
) -> Result<Option<GenerateTokenResponse>> {
172153
let request = tonic::Request::new(NextTokenRequest { batches });
173-
let response = self
174-
.stub
175-
.next_token(request)
176-
.instrument(info_span!("generate_with_cache"))
177-
.await?
178-
.into_inner();
154+
let response = self.stub.next_token(request).await?.into_inner();
179155
Ok(response.result.map(|result| {
180156
(
181157
result.output_tokens,

router/src/grpc_server.rs

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use tonic::{
1111
transport::{Certificate, Identity, Server, ServerTlsConfig},
1212
Request, Response, Status,
1313
};
14-
use tracing::{info_span, instrument, Span};
14+
use tracing::{instrument, Span};
1515
use unicode_truncate::UnicodeTruncateStr;
1616

1717
use crate::{
@@ -99,11 +99,18 @@ impl GenerationService for GenerationServicer {
9999
#[instrument(
100100
skip_all,
101101
fields(
102+
request_id=tracing::field::Empty,
102103
input=?request.get_ref().requests.iter().map(|r| truncate(&r.text, 32)).collect::<Vec<Cow<'_,str>>>(),
103104
prefix_id=?request.get_ref().prefix_id,
104105
correlation_id=?request.metadata().get("x-correlation-id").map(|mv| mv.to_str().unwrap_or("<non-ascii>")).unwrap_or("<none>"),
105106
input_bytes=?request.get_ref().requests.iter().map(|r| r.text.len()).collect::<Vec<usize>>(),
106107
params=?request.get_ref().params,
108+
validation_time=tracing::field::Empty,
109+
queue_time=tracing::field::Empty,
110+
inference_time=tracing::field::Empty,
111+
time_per_token=tracing::field::Empty,
112+
total_time=tracing::field::Empty,
113+
input_toks=tracing::field::Empty,
107114
)
108115
)]
109116
async fn generate(
@@ -224,11 +231,18 @@ impl GenerationService for GenerationServicer {
224231
#[instrument(
225232
skip_all,
226233
fields(
234+
request_id=tracing::field::Empty,
227235
input=?truncate(request.get_ref().request.as_ref().map(|r| &*r.text).unwrap_or(""), 32),
228236
prefix_id=?request.get_ref().prefix_id,
229237
correlation_id=?request.metadata().get("x-correlation-id").map(|mv| mv.to_str().unwrap_or("<non-ascii>")).unwrap_or("<none>"),
230238
input_bytes=?request.get_ref().request.as_ref().map(|r| r.text.len()).unwrap_or(0),
231239
params=?request.get_ref().params,
240+
validation_time=tracing::field::Empty,
241+
queue_time=tracing::field::Empty,
242+
inference_time=tracing::field::Empty,
243+
time_per_token=tracing::field::Empty,
244+
total_time=tracing::field::Empty,
245+
input_toks=tracing::field::Empty,
232246
)
233247
)]
234248
async fn generate_stream(
@@ -416,8 +430,6 @@ fn log_response(
416430
kind_log: &str,
417431
request_id: Option<u64>,
418432
) {
419-
let span;
420-
let _enter;
421433
// Timings
422434
let total_time = Instant::now() - start_time;
423435
if let Some(times) = times.as_ref() {
@@ -429,17 +441,14 @@ fn log_response(
429441
.unwrap_or_else(|| Duration::new(0, 0));
430442

431443
// Tracing metadata
432-
span = info_span!(
433-
"",
434-
validation_time = ?validation_time,
435-
queue_time = ?queue_time,
436-
inference_time = ?inference_time,
437-
time_per_token = ?time_per_token,
438-
total_time = ?total_time,
439-
input_toks = input_tokens,
440-
request_id = request_id,
441-
);
442-
_enter = span.enter();
444+
let span = Span::current();
445+
span.record("request_id", request_id.unwrap_or_default());
446+
span.record("validation_time", format!("{validation_time:?}"));
447+
span.record("queue_time", format!("{queue_time:?}"));
448+
span.record("inference_time", format!("{inference_time:?}"));
449+
span.record("time_per_token", format!("{time_per_token:?}"));
450+
span.record("total_time", format!("{total_time:?}"));
451+
span.record("input_toks", input_tokens);
443452

444453
metrics::histogram!(
445454
"tgi_request_inference_duration",

router/src/main.rs

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,17 @@ use std::{
77

88
/// Text Generation Inference external gRPC server entrypoint
99
use clap::Parser;
10+
use opentelemetry::{
11+
global,
12+
sdk::{propagation::TraceContextPropagator, trace, trace::Sampler, Resource},
13+
KeyValue,
14+
};
15+
use opentelemetry_otlp::WithExportConfig;
1016
use text_generation_client::ShardedClient;
1117
use text_generation_router::{server, server::ServerRunArgs};
1218
use tokenizers::Tokenizer;
1319
use tracing::warn;
20+
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
1421

1522
/// App Configuration
1623
#[derive(Parser, Debug)]
@@ -50,6 +57,8 @@ struct Args {
5057
output_special_tokens: bool,
5158
#[clap(long, env)]
5259
default_include_stop_seqs: bool,
60+
#[clap(long, env)]
61+
otlp_endpoint: Option<String>,
5362
}
5463

5564
fn main() -> Result<(), std::io::Error> {
@@ -69,15 +78,6 @@ fn main() -> Result<(), std::io::Error> {
6978
// Get args
7079
let args = Args::parse();
7180

72-
if args.json_output {
73-
tracing_subscriber::fmt()
74-
.json()
75-
.with_current_span(false)
76-
.init();
77-
} else {
78-
tracing_subscriber::fmt().compact().init();
79-
}
80-
8181
// Validate args
8282
validate_args(&args);
8383

@@ -104,6 +104,7 @@ fn main() -> Result<(), std::io::Error> {
104104
.build()
105105
.unwrap()
106106
.block_on(async {
107+
init_logging(args.otlp_endpoint, args.json_output);
107108
// Instantiate sharded client from the master unix socket
108109
let mut sharded_client = ShardedClient::connect_uds(args.master_shard_uds_path)
109110
.await
@@ -202,3 +203,54 @@ fn write_termination_log(msg: &str) -> Result<(), io::Error> {
202203
writeln!(f, "{}", msg)?;
203204
Ok(())
204205
}
206+
207+
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
208+
let mut layers = Vec::new();
209+
210+
// STDOUT/STDERR layer
211+
let fmt_layer = tracing_subscriber::fmt::layer()
212+
.with_file(true)
213+
.with_line_number(true);
214+
215+
let fmt_layer = match json_output {
216+
true => fmt_layer.json().flatten_event(true).boxed(),
217+
false => fmt_layer.boxed(),
218+
};
219+
layers.push(fmt_layer);
220+
221+
// OpenTelemetry tracing layer
222+
if let Some(otlp_endpoint) = otlp_endpoint {
223+
global::set_text_map_propagator(TraceContextPropagator::new());
224+
225+
let tracer = opentelemetry_otlp::new_pipeline()
226+
.tracing()
227+
.with_exporter(
228+
opentelemetry_otlp::new_exporter()
229+
.tonic()
230+
.with_endpoint(otlp_endpoint),
231+
)
232+
.with_trace_config(
233+
trace::config()
234+
.with_resource(Resource::new(vec![KeyValue::new(
235+
"service.name",
236+
"text-generation-inference.router",
237+
)]))
238+
.with_sampler(Sampler::AlwaysOn),
239+
)
240+
.install_batch(opentelemetry::runtime::Tokio);
241+
242+
if let Ok(tracer) = tracer {
243+
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
244+
axum_tracing_opentelemetry::init_propagator().unwrap();
245+
};
246+
}
247+
248+
// Filter events with LOG_LEVEL
249+
let env_filter =
250+
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
251+
252+
tracing_subscriber::registry()
253+
.with(env_filter)
254+
.with(layers)
255+
.init();
256+
}

0 commit comments

Comments (0)