@@ -24,24 +24,28 @@ import (
2424 "github.com/envoyproxy/ai-gateway/internal/extproc/translator"
2525 "github.com/envoyproxy/ai-gateway/internal/filterapi"
2626 "github.com/envoyproxy/ai-gateway/internal/internalapi"
27+ "github.com/envoyproxy/ai-gateway/internal/metrics"
2728 tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
2829)
2930
3031// CompletionsProcessorFactory returns a factory method to instantiate the completions processor.
31- func CompletionsProcessorFactory (_ interface {} ) ProcessorFactory {
32- return func (config * processorConfig , requestHeaders map [string ]string , logger * slog.Logger , _ tracing.Tracing , isUpstreamFilter bool ) (Processor , error ) {
32+ func CompletionsProcessorFactory (cm metrics. CompletionMetrics ) ProcessorFactory {
33+ return func (config * processorConfig , requestHeaders map [string ]string , logger * slog.Logger , tracing tracing.Tracing , isUpstreamFilter bool ) (Processor , error ) {
3334 logger = logger .With ("processor" , "completions" , "isUpstreamFilter" , fmt .Sprintf ("%v" , isUpstreamFilter ))
3435 if ! isUpstreamFilter {
3536 return & completionsProcessorRouterFilter {
3637 config : config ,
38+ tracer : tracing .CompletionTracer (),
3739 requestHeaders : requestHeaders ,
3840 logger : logger ,
41+ metrics : cm ,
3942 }, nil
4043 }
4144 return & completionsProcessorUpstreamFilter {
4245 config : config ,
4346 requestHeaders : requestHeaders ,
4447 logger : logger ,
48+ metrics : cm ,
4549 }, nil
4650 }
4751}
@@ -70,9 +74,15 @@ type completionsProcessorRouterFilter struct {
7074 // forcedStreamOptionIncludeUsage is set to true if the original request is a streaming request and has the
7175 // stream_options.include_usage=false. In that case, we force the option to be true to ensure that the token usage is calculated correctly.
7276 forcedStreamOptionIncludeUsage bool
77+ // tracer is the tracer used for requests.
78+ tracer tracing.CompletionTracer
79+ // span is the tracing span for this request, created in ProcessRequestBody.
80+ span tracing.CompletionSpan
7381 // upstreamFilterCount is the number of upstream filters that have been processed.
7482 // This is used to determine if the request is a retry request.
7583 upstreamFilterCount int
84+ // metrics tracking.
85+ metrics metrics.CompletionMetrics
7686}
7787
7888// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
@@ -96,7 +106,7 @@ func (c *completionsProcessorRouterFilter) ProcessResponseBody(ctx context.Conte
96106}
97107
98108// ProcessRequestBody implements [Processor.ProcessRequestBody].
99- func (c * completionsProcessorRouterFilter ) ProcessRequestBody (_ context.Context , rawBody * extprocv3.HttpBody ) (* extprocv3.ProcessingResponse , error ) {
109+ func (c * completionsProcessorRouterFilter ) ProcessRequestBody (ctx context.Context , rawBody * extprocv3.HttpBody ) (* extprocv3.ProcessingResponse , error ) {
100110 originalModel , body , err := parseOpenAICompletionBody (rawBody )
101111 if err != nil {
102112 return nil , fmt .Errorf ("failed to parse request body: %w" , err )
@@ -132,10 +142,17 @@ func (c *completionsProcessorRouterFilter) ProcessRequestBody(_ context.Context,
132142 c .originalRequestBody = body
133143 c .originalRequestBodyRaw = rawBody .Body
134144
135- // Create a header mutation without tracing .
145+ // Tracing may need to inject headers, so create a header mutation here .
136146 headerMutation := & extprocv3.HeaderMutation {
137147 SetHeaders : additionalHeaders ,
138148 }
149+ c .span = c .tracer .StartSpanAndInjectHeaders (
150+ ctx ,
151+ c .requestHeaders ,
152+ headerMutation ,
153+ body ,
154+ rawBody .Body ,
155+ )
139156
140157 return & extprocv3.ProcessingResponse {
141158 Response : & extprocv3.ProcessingResponse_RequestBody {
@@ -173,6 +190,10 @@ type completionsProcessorUpstreamFilter struct {
173190 forcedStreamOptionIncludeUsage bool
174191 // cost is the cost of the request that is accumulated during the processing of the response.
175192 costs translator.LLMTokenUsage
193+ // span is the tracing span for this request, inherited from the router filter.
194+ span tracing.CompletionSpan
195+ // metrics tracking.
196+ metrics metrics.CompletionMetrics
176197}
177198
178199// selectTranslator selects the translator based on the output schema.
@@ -193,6 +214,22 @@ func (c *completionsProcessorUpstreamFilter) selectTranslator(out filterapi.Vers
193214// with the status CONTINUE_AND_REPLACE. This will allows Envoy to not send the request body again
194215// to the extproc.
195216func (c * completionsProcessorUpstreamFilter ) ProcessRequestHeaders (ctx context.Context , _ * corev3.HeaderMap ) (res * extprocv3.ProcessingResponse , err error ) {
217+ defer func () {
218+ if err != nil {
219+ c .metrics .RecordRequestCompletion (ctx , false , c .requestHeaders )
220+ }
221+ }()
222+ // Start tracking metrics for this request.
223+ c .metrics .StartRequest (c .requestHeaders )
224+ // Set the original model from the request body before any overrides
225+ c .metrics .SetOriginalModel (c .originalRequestBody .Model )
226+ // Set the request model for metrics from the original model or override if applied.
227+ reqModel := c .originalRequestBody .Model
228+ if override := c .requestHeaders [internalapi .ModelNameHeaderKeyDefault ]; override != "" {
229+ reqModel = override
230+ }
231+ c .metrics .SetRequestModel (reqModel )
232+
196233 headerMutation , bodyMutation , err := c .translator .RequestBody (c .originalRequestBodyRaw , c .originalRequestBody , c .onRetry )
197234 if err != nil {
198235 return nil , fmt .Errorf ("failed to transform request: %w" , err )
@@ -241,7 +278,12 @@ func (c *completionsProcessorUpstreamFilter) ProcessRequestBody(context.Context,
241278}
242279
243280// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
244- func (c * completionsProcessorUpstreamFilter ) ProcessResponseHeaders (_ context.Context , headers * corev3.HeaderMap ) (res * extprocv3.ProcessingResponse , err error ) {
281+ func (c * completionsProcessorUpstreamFilter ) ProcessResponseHeaders (ctx context.Context , headers * corev3.HeaderMap ) (res * extprocv3.ProcessingResponse , err error ) {
282+ defer func () {
283+ if err != nil {
284+ c .metrics .RecordRequestCompletion (ctx , false , c .requestHeaders )
285+ }
286+ }()
245287 c .responseHeaders = headersToMap (headers )
246288 if enc := c .responseHeaders ["content-encoding" ]; enc != "" {
247289 c .responseEncoding = enc
@@ -263,7 +305,19 @@ func (c *completionsProcessorUpstreamFilter) ProcessResponseHeaders(_ context.Co
263305}
264306
265307// ProcessResponseBody implements [Processor.ProcessResponseBody].
266- func (c * completionsProcessorUpstreamFilter ) ProcessResponseBody (_ context.Context , body * extprocv3.HttpBody ) (res * extprocv3.ProcessingResponse , err error ) {
308+ func (c * completionsProcessorUpstreamFilter ) ProcessResponseBody (ctx context.Context , body * extprocv3.HttpBody ) (res * extprocv3.ProcessingResponse , err error ) {
309+ // Track whether we need to record request completion on error.
310+ recordRequestCompletionErr := false
311+ defer func () {
312+ if err != nil || recordRequestCompletionErr {
313+ c .metrics .RecordRequestCompletion (ctx , false , c .requestHeaders )
314+ return
315+ }
316+ if body .EndOfStream {
317+ c .metrics .RecordRequestCompletion (ctx , true , c .requestHeaders )
318+ }
319+ }()
320+
267321 // Decompress the body if needed using common utility.
268322 decodingResult , err := decodeContentIfNeeded (body .Body , c .responseEncoding )
269323 if err != nil {
@@ -272,12 +326,20 @@ func (c *completionsProcessorUpstreamFilter) ProcessResponseBody(_ context.Conte
272326
273327 // Assume all responses have a valid status code header.
274328 if code , _ := strconv .Atoi (c .responseHeaders [":status" ]); ! isGoodStatusCode (code ) {
329+ recordRequestCompletionErr = true
275330 var headerMutation * extprocv3.HeaderMutation
276331 var bodyMutation * extprocv3.BodyMutation
277332 headerMutation , bodyMutation , err = c .translator .ResponseError (c .responseHeaders , decodingResult .reader )
278333 if err != nil {
279334 return nil , fmt .Errorf ("failed to transform response error: %w" , err )
280335 }
336+ if c .span != nil {
337+ b := bodyMutation .GetBody ()
338+ if b == nil {
339+ b = body .Body
340+ }
341+ c .span .EndSpanOnError (code , b )
342+ }
281343 return & extprocv3.ProcessingResponse {
282344 Response : & extprocv3.ProcessingResponse_ResponseBody {
283345 ResponseBody : & extprocv3.BodyResponse {
@@ -290,11 +352,14 @@ func (c *completionsProcessorUpstreamFilter) ProcessResponseBody(_ context.Conte
290352 }, nil
291353 }
292354
293- headerMutation , bodyMutation , tokenUsage , responseModel , err := c .translator .ResponseBody (c .responseHeaders , decodingResult .reader , body .EndOfStream )
355+ headerMutation , bodyMutation , tokenUsage , responseModel , err := c .translator .ResponseBody (c .responseHeaders , decodingResult .reader , body .EndOfStream , c . span )
294356 if err != nil {
295357 return nil , fmt .Errorf ("failed to transform response: %w" , err )
296358 }
297359
360+ // Set the response model for metrics
361+ c .metrics .SetResponseModel (responseModel )
362+
298363 // Remove content-encoding header if original body encoded but was mutated in the processor.
299364 headerMutation = removeContentEncodingIfNeeded (headerMutation , bodyMutation , decodingResult .isEncoded )
300365
@@ -314,6 +379,18 @@ func (c *completionsProcessorUpstreamFilter) ProcessResponseBody(_ context.Conte
314379 c .costs .OutputTokens += tokenUsage .OutputTokens
315380 c .costs .TotalTokens += tokenUsage .TotalTokens
316381
382+ // Record metrics.
383+ if c .stream {
384+ // Token latency is only recorded for streaming responses
385+ c .metrics .RecordTokenLatency (ctx , tokenUsage .OutputTokens , body .EndOfStream , c .requestHeaders )
386+ // Emit usage once at end-of-stream using final totals.
387+ if body .EndOfStream {
388+ c .metrics .RecordTokenUsage (ctx , c .costs .InputTokens , c .costs .OutputTokens , c .requestHeaders )
389+ }
390+ } else {
391+ c .metrics .RecordTokenUsage (ctx , tokenUsage .InputTokens , tokenUsage .OutputTokens , c .requestHeaders )
392+ }
393+
317394 // Log the response model for debugging
318395 if responseModel != "" {
319396 c .logger .Debug ("completion response model" , "model" , responseModel )
@@ -324,18 +401,32 @@ func (c *completionsProcessorUpstreamFilter) ProcessResponseBody(_ context.Conte
324401 if err != nil {
325402 return nil , fmt .Errorf ("failed to build dynamic metadata: %w" , err )
326403 }
404+ // Merge token latency metadata if streaming.
405+ if c .stream {
406+ c .mergeWithTokenLatencyMetadata (resp .DynamicMetadata )
407+ }
408+ }
409+
410+ if body .EndOfStream && c .span != nil {
411+ c .span .EndSpan ()
327412 }
328413
329414 return resp , nil
330415}
331416
332417// SetBackend implements [Processor.SetBackend].
333- func (c * completionsProcessorUpstreamFilter ) SetBackend (_ context.Context , b * filterapi.Backend , backendHandler backendauth.Handler , routeProcessor Processor ) (err error ) {
418+ func (c * completionsProcessorUpstreamFilter ) SetBackend (ctx context.Context , b * filterapi.Backend , backendHandler backendauth.Handler , routeProcessor Processor ) (err error ) {
419+ defer func () {
420+ if err != nil {
421+ c .metrics .RecordRequestCompletion (ctx , false , c .requestHeaders )
422+ }
423+ }()
334424 rp , ok := routeProcessor .(* completionsProcessorRouterFilter )
335425 if ! ok {
336426 panic ("BUG: expected routeProcessor to be of type *completionsProcessorRouterFilter" )
337427 }
338428 rp .upstreamFilterCount ++
429+ c .metrics .SetBackend (b )
339430 c .modelNameOverride = b .ModelNameOverride
340431 c .backendName = b .Name
341432 c .originalRequestBody = rp .originalRequestBody
@@ -353,9 +444,22 @@ func (c *completionsProcessorUpstreamFilter) SetBackend(_ context.Context, b *fi
353444 c .requestHeaders [internalapi .ModelNameHeaderKeyDefault ] = c .modelNameOverride
354445 }
355446 rp .upstreamFilter = c
447+ c .span = rp .span
356448 return
357449}
358450
451+ func (c * completionsProcessorUpstreamFilter ) mergeWithTokenLatencyMetadata (metadata * structpb.Struct ) {
452+ timeToFirstTokenMs := c .metrics .GetTimeToFirstTokenMs ()
453+ interTokenLatencyMs := c .metrics .GetInterTokenLatencyMs ()
454+ innerVal := metadata .Fields [internalapi .AIGatewayFilterMetadataNamespace ].GetStructValue ()
455+ if innerVal == nil {
456+ innerVal = & structpb.Struct {Fields : make (map [string ]* structpb.Value )}
457+ metadata .Fields [internalapi .AIGatewayFilterMetadataNamespace ] = structpb .NewStructValue (innerVal )
458+ }
459+ innerVal .Fields ["token_latency_ttft" ] = & structpb.Value {Kind : & structpb.Value_NumberValue {NumberValue : timeToFirstTokenMs }}
460+ innerVal .Fields ["token_latency_itl" ] = & structpb.Value {Kind : & structpb.Value_NumberValue {NumberValue : interTokenLatencyMs }}
461+ }
462+
359463func parseOpenAICompletionBody (body * extprocv3.HttpBody ) (modelName string , rb * openai.CompletionRequest , err error ) {
360464 var openAIReq openai.CompletionRequest
361465 if err := json .Unmarshal (body .Body , & openAIReq ); err != nil {
0 commit comments