Skip to content

Commit 321ee6b

Browse files
committed
feat: add native TTFT, TPOT, ITL, and E2E latency tracking to framework
-Implements critical inference metrics directly within the IGW framework, removing the dependency on the SLO predictor plugin for observability. -Framework now natively tracks Time to First Token (TTFT), Time Per Output Token (TPOT), Inter-Token Latency (ITL), Decode Duration and End-to-End (E2E) latency for all inference requests. -Added tests to validate metrics tracking Signed-off-by: Sathvik <[email protected]>
1 parent e5b1a41 commit 321ee6b

File tree

5 files changed

+185
-3
lines changed

5 files changed

+185
-3
lines changed

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu
218218

219219
if predictedLatencyCtx.ttft > 0 {
220220
logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTTFT", predictedLatencyCtx.ttft, "avgPredictedTTFT", predictedLatencyCtx.predictedTTFT)
221-
metrics.RecordRequestTTFT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.ttft/1000)
221+
// metrics.RecordRequestTTFT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.ttft/1000)
222222
metrics.RecordRequestPredictedTTFT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.predictedTTFT/1000)
223223
if predictedLatencyCtx.ttftSLO > 0 {
224224
metrics.RecordRequestTTFTWithSLO(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.ttft, predictedLatencyCtx.ttftSLO)
@@ -227,7 +227,7 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu
227227

228228
if predictedLatencyCtx.avgTPOT > 0 {
229229
logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTPOT", predictedLatencyCtx.avgTPOT, "avgPredictedTPOT", predictedLatencyCtx.avgPredictedTPOT)
230-
metrics.RecordRequestTPOT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgTPOT/1000)
230+
// metrics.RecordRequestTPOT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgTPOT/1000)
231231
metrics.RecordRequestPredictedTPOT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgPredictedTPOT/1000)
232232
if predictedLatencyCtx.avgTPOTSLO > 0 {
233233
metrics.RecordRequestTPOTWithSLO(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgTPOT, predictedLatencyCtx.avgTPOTSLO)

pkg/epp/handlers/response.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"encoding/json"
2222
"fmt"
2323
"strings"
24+
"time"
2425

2526
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2627
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
@@ -133,6 +134,22 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context,
133134
logger.Error(err, "error in HandleResponseBodyStreaming")
134135
}
135136

137+
// Record TTFT on the first token chunk.
138+
// We check for "data: " prefix to ensure it's a data chunk, and exclude "[DONE]" message.
139+
if reqCtx.GeneratedTokenCount == 0 && strings.Contains(responseText, streamingRespPrefix) && !strings.Contains(responseText, streamingEndMsg) {
140+
ttft := time.Since(reqCtx.RequestReceivedTimestamp).Seconds()
141+
reqCtx.TTFT = ttft
142+
metrics.RecordRequestTTFT(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, ttft)
143+
reqCtx.GeneratedTokenCount = 1
144+
reqCtx.LastTokenTimestamp = time.Now()
145+
} else if reqCtx.GeneratedTokenCount > 0 && strings.Contains(responseText, streamingRespPrefix) && !strings.Contains(responseText, streamingEndMsg) {
146+
// Record ITL for subsequent tokens
147+
itl := time.Since(reqCtx.LastTokenTimestamp).Seconds()
148+
metrics.RecordRequestITL(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, itl)
149+
reqCtx.LastTokenTimestamp = time.Now()
150+
reqCtx.GeneratedTokenCount++
151+
}
152+
136153
// Parse usage on EVERY chunk to catch split streams (where usage and [DONE] are in different chunks).
137154
if resp := parseRespForUsage(ctx, responseText); resp.Usage.TotalTokens > 0 {
138155
reqCtx.Usage = resp.Usage
@@ -147,6 +164,22 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context,
147164
cachedToken = reqCtx.Usage.PromptTokenDetails.CachedTokens
148165
}
149166
metrics.RecordPromptCachedTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, cachedToken)
167+
168+
// Record Time Per Output Token
169+
// TPOT = (Total Duration - TTFT) / (OutputTokens - 1)
170+
if reqCtx.Usage.CompletionTokens > 1 && reqCtx.TTFT > 0 {
171+
totalDuration := time.Since(reqCtx.RequestReceivedTimestamp).Seconds()
172+
generationDuration := totalDuration - reqCtx.TTFT
173+
metrics.RecordRequestDecodeDuration(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, generationDuration)
174+
metrics.RecordRequestE2ELatency(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, totalDuration)
175+
176+
// Avoid division by zero just in case
177+
if count := float64(reqCtx.Usage.CompletionTokens - 1); count > 0 {
178+
avgTPOT := generationDuration / count
179+
reqCtx.TPOT = avgTPOT
180+
metrics.RecordRequestTPOT(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, avgTPOT)
181+
}
182+
}
150183
}
151184
}
152185

pkg/epp/handlers/response_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"encoding/json"
2222
"testing"
23+
"time"
2324

2425
"github.com/google/go-cmp/cmp"
2526
"github.com/stretchr/testify/assert"
@@ -327,3 +328,68 @@ func TestGenerateResponseHeaders_Sanitization(t *testing.T) {
327328
assert.NotContains(t, gotHeaders, metadata.DestinationEndpointKey)
328329
assert.NotContains(t, gotHeaders, "content-length")
329330
}
331+
332+
func TestHandleResponseBodyModelStreaming_Metrics(t *testing.T) {
333+
t.Parallel()
334+
ctx := context.Background()
335+
336+
t.Run("TTFT Recording", func(t *testing.T) {
337+
server := &StreamingServer{director: &mockDirector{}}
338+
reqCtx := &RequestContext{
339+
RequestReceivedTimestamp: time.Now().Add(-100 * time.Millisecond),
340+
IncomingModelName: "model-a",
341+
TargetModelName: "model-b",
342+
}
343+
344+
chunk := `data: {"choices":[{"text":"First token"}]}`
345+
server.HandleResponseBodyModelStreaming(ctx, reqCtx, chunk)
346+
347+
assert.Greater(t, reqCtx.TTFT, 0.0, "TTFT should be recorded and greater than 0")
348+
assert.Equal(t, 1, reqCtx.GeneratedTokenCount, "GeneratedTokenCount should be 1")
349+
assert.False(t, reqCtx.LastTokenTimestamp.IsZero(), "LastTokenTimestamp should be set")
350+
})
351+
352+
t.Run("ITL Recording", func(t *testing.T) {
353+
server := &StreamingServer{director: &mockDirector{}}
354+
reqCtx := &RequestContext{
355+
RequestReceivedTimestamp: time.Now().Add(-1 * time.Second),
356+
IncomingModelName: "model-a",
357+
TargetModelName: "model-b",
358+
// Simulate first token already received
359+
GeneratedTokenCount: 1,
360+
LastTokenTimestamp: time.Now().Add(-50 * time.Millisecond),
361+
TTFT: 0.1,
362+
}
363+
364+
chunk := `data: {"choices":[{"text":"Second token"}]}`
365+
server.HandleResponseBodyModelStreaming(ctx, reqCtx, chunk)
366+
367+
// ITL is not stored in ReqCtx, but we can verify state updates
368+
assert.Equal(t, 2, reqCtx.GeneratedTokenCount, "GeneratedTokenCount should increment")
369+
assert.True(t, time.Since(reqCtx.LastTokenTimestamp) < 10*time.Millisecond, "LastTokenTimestamp should be updated to Now")
370+
})
371+
372+
t.Run("TPOT and E2E Recording", func(t *testing.T) {
373+
server := &StreamingServer{director: &mockDirector{}}
374+
reqCtx := &RequestContext{
375+
RequestReceivedTimestamp: time.Now().Add(-1 * time.Second),
376+
IncomingModelName: "model-a",
377+
TargetModelName: "model-b",
378+
TTFT: 0.1,
379+
GeneratedTokenCount: 10,
380+
}
381+
382+
// Usage that triggers TPOT calc
383+
chunk := `data: {"usage":{"prompt_tokens":5,"completion_tokens":11,"total_tokens":16}}` + "\n" + `data: [DONE]`
384+
server.HandleResponseBodyModelStreaming(ctx, reqCtx, chunk)
385+
386+
assert.True(t, reqCtx.ResponseComplete, "Response should be complete")
387+
assert.Greater(t, reqCtx.TPOT, 0.0, "TPOT should be calculated")
388+
389+
// Expected TPOT calc: (TotalDuration - TTFT) / (CompletionTokens - 1)
390+
// TotalDuration ~ 1.0s, TTFT = 0.1s -> GenDuration ~ 0.9s
391+
// Tokens - 1 = 10
392+
// TPOT ~ 0.09
393+
assert.InDelta(t, 0.09, reqCtx.TPOT, 0.05, "TPOT should be approximately correct")
394+
})
395+
}

pkg/epp/handlers/server.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ type RequestContext struct {
9999

100100
Response *Response
101101

102+
// Metrics
103+
TTFT float64
104+
TPOT float64
105+
LastTokenTimestamp time.Time
106+
GeneratedTokenCount int
107+
102108
reqHeaderResp *extProcPb.ProcessingResponse
103109
reqBodyResp []*extProcPb.ProcessingResponse
104110
reqTrailerResp *extProcPb.ProcessingResponse
@@ -145,7 +151,8 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
145151
// Create request context to share states during life time of an HTTP request.
146152
// See https://github.com/envoyproxy/envoy/issues/17540.
147153
reqCtx := &RequestContext{
148-
RequestState: RequestReceived,
154+
RequestState: RequestReceived,
155+
RequestReceivedTimestamp: time.Now(),
149156
Request: &Request{
150157
Headers: make(map[string]string),
151158
Body: make(map[string]any),

pkg/epp/metrics/metrics.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ const (
5757
TypeTTFTPredictionDuration = "ttft_prediction_duration"
5858
TypeTTFTSLOViolation = "ttft_slo_violation"
5959
TypeTTFTSLOThreshold = "ttft_slo_threshold"
60+
61+
TypeITL = "itl"
62+
TypeDecodeDuration = "decode_duration"
63+
TypeE2ELatency = "e2e_latency"
6064
)
6165

6266
var (
@@ -176,6 +180,36 @@ var (
176180
ModelLabels,
177181
)
178182

183+
requestITL = prometheus.NewHistogramVec(
184+
prometheus.HistogramOpts{
185+
Subsystem: InferenceObjectiveComponent,
186+
Name: "request_itl_seconds",
187+
Help: metricsutil.HelpMsgWithStability("Inference model Inter-Token Latency distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
188+
Buckets: TPOTBuckets,
189+
},
190+
ModelLabels,
191+
)
192+
193+
requestDecodeDuration = prometheus.NewHistogramVec(
194+
prometheus.HistogramOpts{
195+
Subsystem: InferenceObjectiveComponent,
196+
Name: "request_decode_duration_seconds",
197+
Help: metricsutil.HelpMsgWithStability("Inference model Decode Duration distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
198+
Buckets: GeneralLatencyBuckets,
199+
},
200+
ModelLabels,
201+
)
202+
203+
requestE2ELatency = prometheus.NewHistogramVec(
204+
prometheus.HistogramOpts{
205+
Subsystem: InferenceObjectiveComponent,
206+
Name: "request_e2e_latency_seconds",
207+
Help: metricsutil.HelpMsgWithStability("Inference model E2E Latency (TTFT + Decode) distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
208+
Buckets: GeneralLatencyBuckets,
209+
},
210+
ModelLabels,
211+
)
212+
179213
sloViolationCounter = prometheus.NewCounterVec(
180214
prometheus.CounterOpts{
181215
Subsystem: InferenceObjectiveComponent,
@@ -443,6 +477,9 @@ func Register(customCollectors ...prometheus.Collector) {
443477
metrics.Registry.MustRegister(requestPredictedTTFT)
444478
metrics.Registry.MustRegister(requestTPOTPredictionDuration)
445479
metrics.Registry.MustRegister(requestTTFTPredictionDuration)
480+
metrics.Registry.MustRegister(requestITL)
481+
metrics.Registry.MustRegister(requestDecodeDuration)
482+
metrics.Registry.MustRegister(requestE2ELatency)
446483

447484
// Register SLO violation counters
448485
metrics.Registry.MustRegister(sloViolationCounter)
@@ -490,6 +527,9 @@ func Reset() {
490527
requestPredictedTTFT.Reset()
491528
requestTPOTPredictionDuration.Reset()
492529
requestTTFTPredictionDuration.Reset()
530+
requestITL.Reset()
531+
requestDecodeDuration.Reset()
532+
requestE2ELatency.Reset()
493533

494534
// Reset SLO violation counter
495535
sloViolationCounter.Reset()
@@ -667,6 +707,42 @@ func RecordRequestTTFTPredictionDuration(ctx context.Context, modelName, targetM
667707
return true
668708
}
669709

710+
// RecordRequestITL records the Inter-Token Latency.
711+
func RecordRequestITL(ctx context.Context, modelName, targetModelName string, itl float64) bool {
712+
if itl < 0 {
713+
log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "ITL value must be non-negative",
714+
"modelName", modelName, "targetModelName", targetModelName, "itl", itl)
715+
return false
716+
}
717+
requestITL.WithLabelValues(modelName, targetModelName).Observe(itl)
718+
inferenceGauges.WithLabelValues(modelName, targetModelName, TypeITL).Set(itl)
719+
return true
720+
}
721+
722+
// RecordRequestDecodeDuration records the Decode Duration.
723+
func RecordRequestDecodeDuration(ctx context.Context, modelName, targetModelName string, duration float64) bool {
724+
if duration < 0 {
725+
log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "Decode duration value must be non-negative",
726+
"modelName", modelName, "targetModelName", targetModelName, "duration", duration)
727+
return false
728+
}
729+
requestDecodeDuration.WithLabelValues(modelName, targetModelName).Observe(duration)
730+
inferenceGauges.WithLabelValues(modelName, targetModelName, TypeDecodeDuration).Set(duration)
731+
return true
732+
}
733+
734+
// RecordRequestE2ELatency records the E2E Latency (TTFT + Decode).
735+
func RecordRequestE2ELatency(ctx context.Context, modelName, targetModelName string, duration float64) bool {
736+
if duration < 0 {
737+
log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "E2E latency value must be non-negative",
738+
"modelName", modelName, "targetModelName", targetModelName, "duration", duration)
739+
return false
740+
}
741+
requestE2ELatency.WithLabelValues(modelName, targetModelName).Observe(duration)
742+
inferenceGauges.WithLabelValues(modelName, targetModelName, TypeE2ELatency).Set(duration)
743+
return true
744+
}
745+
670746
// RecordResponseSizes records the response sizes.
671747
func RecordResponseSizes(modelName, targetModelName string, size int) {
672748
responseSizes.WithLabelValues(modelName, targetModelName).Observe(float64(size))

0 commit comments

Comments (0)