Commit 95241ab
Make slo_aware_router modular via PredictionRequestBuilder interface and add public accessors
- Add PredictionRequestBuilder interface for customizable prediction/training request construction
- Implement DefaultPredictionRequestBuilder preserving the current monolithic behavior
- Make requestBuilder injectable through Config with backward compatibility
- Update bulkPredictWithMetrics and the training helpers to use the builder pattern
- Add GetSchedulingResultForRequest, GetLastSeenMetricsForRequest, GetPrefixCacheScoresForRequest, and GetRequestPrompt public accessors
- Add GetRequestBuilder and GetLatencyPredictor accessors
- Update ResponseStreaming to pass requestBuilder and pod to the helpers
- Update tests to use the builder pattern

Enables downstream projects to customize the PodType field for disaggregated serving without exposing internal state or modifying core router behavior.
1 parent adb0ea3 commit 95241ab

8 files changed: +277 −46 lines
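The PredictionRequestBuilder interface itself is defined in one of the changed files that is not expanded in this view. Reconstructed from the call sites in the hunks below (using the import aliases from latencypredictor_helper.go), it plausibly looks like the following sketch; the parameter names and ordering are assumptions, not the committed source.

// Sketch of the PredictionRequestBuilder contract, inferred from the call
// sites in this commit; parameter names are assumptions.
type PredictionRequestBuilder interface {
	// Assembles one per-endpoint prediction request (see bulkPredictWithMetrics).
	BuildPredictionRequest(
		ctx context.Context,
		pod framework.Endpoint,
		metrics *fwkdl.Metrics,
		prompt string,
		generatedTokenCount int,
		prefixCacheScore float64,
	) latencypredictor.PredictionRequest

	// Assembles one TTFT or TPOT training sample (see recordTTFTTrainingData
	// and processTokenForLatencyPrediction below).
	BuildTrainingEntry(
		ctx context.Context,
		pod framework.Endpoint,
		metrics *fwkdl.Metrics,
		prompt string,
		actualTTFT float64,
		actualTPOT float64,
		timestamp time.Time,
		numTokensGenerated int,
		prefixCacheScore float64,
	) latencypredictor.TrainingEntry
}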

pkg/epp/framework/interface/datalayer/endpoint.go

Lines changed: 15 additions & 0 deletions

@@ -90,6 +90,21 @@ func (srv *ModelServer) GetAttributes() AttributeMap {
 	return srv.attributes
 }
 
+// Get retrieves an attribute value by key, forwarding to the underlying AttributeMap.
+func (srv *ModelServer) Get(key string) (Cloneable, bool) {
+	return srv.attributes.Get(key)
+}
+
+// Put stores an attribute value by key, forwarding to the underlying AttributeMap.
+func (srv *ModelServer) Put(key string, value Cloneable) {
+	srv.attributes.Put(key, value)
+}
+
+// Keys returns all attribute keys, forwarding to the underlying AttributeMap.
+func (srv *ModelServer) Keys() []string {
+	return srv.attributes.Keys()
+}
+
 func (srv *ModelServer) Clone() *ModelServer {
 	clone := &ModelServer{
 		attributes: srv.attributes.Clone(),
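A hypothetical use of these forwarding accessors from a downstream project, for example to tag endpoints for disaggregated serving. The "pod-type" key and podType value type are illustrative, not part of this commit, and the Clone signature assumes Cloneable is satisfied by returning the value itself.

import (
	fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
)

// podType is an illustrative attribute value; not part of this commit.
type podType string

// Clone assumes Cloneable requires a Clone method returning Cloneable.
func (p podType) Clone() fwkdl.Cloneable { return p }

// tagPrefill stores a role marker without touching ModelServer internals.
func tagPrefill(srv *fwkdl.ModelServer) {
	srv.Put("pod-type", podType("prefill"))
}

// isPrefill reads the marker back through the new Get accessor.
func isPrefill(srv *fwkdl.ModelServer) bool {
	v, ok := srv.Get("pod-type")
	if !ok {
		return false
	}
	t, ok := v.(podType)
	return ok && t == "prefill"
}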

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper.go

Lines changed: 46 additions & 37 deletions

@@ -28,6 +28,7 @@ import (
 
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
 	fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+	framework "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
 	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
@@ -120,7 +121,9 @@ func processFirstTokenForLatencyPrediction(
 	ctx context.Context,
 	predictor latencypredictor.PredictorInterface,
 	streamingMode bool,
+	requestBuilder PredictionRequestBuilder,
 	predictedLatencyCtx *predictedLatencyCtx,
+	pod framework.Endpoint,
 	now time.Time,
 	samplingMean float64,
 	maxSampledTokens int,
@@ -139,7 +142,7 @@
 	targetPod := predictedLatencyCtx.targetMetadata
 	prefixCacheScore := predictedLatencyCtx.prefixCacheScoresForEndpoints[targetPod.NamespacedName.Name]
 	logger.V(logutil.DEBUG).Info("Recording TTFT training data", "ttft_ms", predictedLatencyCtx.ttft, "prefixCacheScore", prefixCacheScore)
-	recordTTFTTrainingData(ctx, predictor, predictedLatencyCtx, m, now, prefixCacheScore)
+	recordTTFTTrainingData(ctx, predictor, requestBuilder, predictedLatencyCtx, m, pod, now, prefixCacheScore)
 
 	if streamingMode {
 		predictFirstTPOT(ctx, predictor, predictedLatencyCtx)
@@ -163,24 +166,26 @@ func initializeSampler(ctx context.Context, predictedLatencyCtx *predictedLatenc
 func recordTTFTTrainingData(
 	ctx context.Context,
 	predictor latencypredictor.PredictorInterface,
+	requestBuilder PredictionRequestBuilder,
 	predictedLatencyCtx *predictedLatencyCtx,
 	m *fwkdl.Metrics,
+	pod framework.Endpoint,
 	now time.Time,
 	prefixCacheScore float64,
 ) {
 	logger := log.FromContext(ctx)
-	// Train TTFT
-	entry := latencypredictor.TrainingEntry{
-		KVCachePercentage:  m.KVCacheUsagePercent,
-		InputTokenLength:   len(strings.Fields(predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt)),
-		ActualTTFT:         predictedLatencyCtx.ttft,
-		ActualTPOT:         0,
-		Timestamp:          now,
-		NumRequestWaiting:  m.WaitingQueueSize,
-		NumRequestRunning:  m.RunningRequestsSize,
-		NumTokensGenerated: 0,
-		PrefixCacheScore:   prefixCacheScore,
-	}
+	// Build training entry using the builder
+	entry := requestBuilder.BuildTrainingEntry(
+		ctx,
+		pod,
+		m,
+		predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
+		predictedLatencyCtx.ttft,
+		0, // TTFT training
+		now,
+		0,
+		prefixCacheScore,
+	)
 	if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
 		logger.V(logutil.DEBUG).Error(err, "record TTFT training failed")
 	}
@@ -227,7 +232,9 @@ func predictFirstTPOT(
 func processTokenForLatencyPrediction(
 	ctx context.Context,
 	predictor latencypredictor.PredictorInterface,
+	requestBuilder PredictionRequestBuilder,
 	predictedLatencyCtx *predictedLatencyCtx,
+	pod framework.Endpoint,
 	now time.Time,
 	samplingMean float64,
 	maxSampledTokens int,
@@ -257,18 +264,18 @@
 			"error", err)
 		return
 	}
-	// Record actual TPOT
-	entry := latencypredictor.TrainingEntry{
-		KVCachePercentage:  m.KVCacheUsagePercent,
-		InputTokenLength:   len(strings.Fields(predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt)),
-		ActualTTFT:         0,
-		ActualTPOT:         latencyMs,
-		Timestamp:          now,
-		NumRequestWaiting:  m.WaitingQueueSize,
-		NumRequestRunning:  m.RunningRequestsSize,
-		NumTokensGenerated: predictedLatencyCtx.generatedTokenCount - 1,
-		PrefixCacheScore:   0, // TPOT does not use prefix cache score
-	}
+	// Record actual TPOT using builder
+	entry := requestBuilder.BuildTrainingEntry(
+		ctx,
+		pod,
+		m,
+		predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
+		0, // TTFT not recorded for TPOT
+		latencyMs,
+		now,
+		predictedLatencyCtx.generatedTokenCount-1,
+		0, // TPOT does not use prefix cache score
+	)
 	if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
 		logger.V(logutil.DEBUG).Error(err, "record TPOT training failed")
 	}
@@ -312,16 +319,18 @@ func bulkPredictWithMetrics(
 	ctx context.Context,
 	predictor latencypredictor.PredictorInterface,
 	metricsStates []*fwkdl.Metrics,
+	requestBuilder PredictionRequestBuilder,
+	pods []framework.Endpoint,
 	prompts []string,
 	generatedTokenCounts []int,
 	prefixCacheScores []float64,
 ) ([]*latencypredictor.PredictionResponse, error) {
 	logger := log.FromContext(ctx)
 
 	// Validate input lengths
-	if len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) {
-		return nil, fmt.Errorf("input slice lengths must match: metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d",
-			len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores))
+	if len(pods) != len(metricsStates) || len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) {
+		return nil, fmt.Errorf("input slice lengths must match: pods=%d, metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d",
+			len(pods), len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores))
 	}
 
 	if len(metricsStates) == 0 {
@@ -335,17 +344,17 @@
 		}
 	}
 
-	// Build bulk prediction requests
+	// Build bulk prediction requests using the builder
 	bulkRequests := make([]latencypredictor.PredictionRequest, len(metricsStates))
 	for i := range metricsStates {
-		bulkRequests[i] = latencypredictor.PredictionRequest{
-			KVCachePercentage:  metricsStates[i].KVCacheUsagePercent,
-			InputTokenLength:   len(strings.Fields(prompts[i])),
-			NumRequestWaiting:  metricsStates[i].WaitingQueueSize,
-			NumRequestRunning:  metricsStates[i].RunningRequestsSize,
-			NumTokensGenerated: generatedTokenCounts[i],
-			PrefixCacheScore:   prefixCacheScores[i],
-		}
+		bulkRequests[i] = requestBuilder.BuildPredictionRequest(
+			ctx,
+			pods[i],
+			metricsStates[i],
+			prompts[i],
+			generatedTokenCounts[i],
+			prefixCacheScores[i],
+		)
 	}
 
 	// Perform bulk prediction
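DefaultPredictionRequestBuilder is defined in one of the files not expanded here, but comparing the removed struct literals with the new builder calls suggests its prediction path reduces to roughly the following. This is a reconstruction of the preserved monolithic behavior, not the committed source.

// Reconstructed from the removed inline literals above; ctx and pod are
// accepted to satisfy the interface but plausibly unused by the default.
func (DefaultPredictionRequestBuilder) BuildPredictionRequest(
	ctx context.Context,
	pod framework.Endpoint,
	m *fwkdl.Metrics,
	prompt string,
	generatedTokenCount int,
	prefixCacheScore float64,
) latencypredictor.PredictionRequest {
	return latencypredictor.PredictionRequest{
		KVCachePercentage:  m.KVCacheUsagePercent,
		InputTokenLength:   len(strings.Fields(prompt)),
		NumRequestWaiting:  m.WaitingQueueSize,
		NumRequestRunning:  m.RunningRequestsSize,
		NumTokensGenerated: generatedTokenCount,
		PrefixCacheScore:   prefixCacheScore,
	}
}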

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper_test.go

Lines changed: 44 additions & 4 deletions

@@ -23,7 +23,10 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
 	fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
 	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
 )
 
@@ -39,11 +42,24 @@ func TestBulkPredictWithMetrics(t *testing.T) {
 		{KVCacheUsagePercent: 0.5},
 		{KVCacheUsagePercent: 0.6},
 	}
+	requestBuilder := &DefaultPredictionRequestBuilder{}
+	pods := []schedulingtypes.Endpoint{
+		&schedulingtypes.PodMetrics{
+			EndpointMetadata: &fwkdl.EndpointMetadata{
+				NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+			},
+		},
+		&schedulingtypes.PodMetrics{
+			EndpointMetadata: &fwkdl.EndpointMetadata{
+				NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod2"},
+			},
+		},
+	}
 	prompts := []string{"prompt1", "prompt2"}
 	generatedTokenCounts := []int{1, 1}
 	prefixCacheScores := []float64{0.0, 0.0}
 
-	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
 
 	assert.NoError(t, err)
 	assert.Len(t, results, 2)
@@ -61,11 +77,19 @@ func TestBulkPredictWithMetrics_Error(t *testing.T) {
 	metricsStates := []*fwkdl.Metrics{
 		{KVCacheUsagePercent: 0.5},
 	}
+	requestBuilder := &DefaultPredictionRequestBuilder{}
+	pods := []schedulingtypes.Endpoint{
+		&schedulingtypes.PodMetrics{
+			EndpointMetadata: &fwkdl.EndpointMetadata{
+				NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+			},
+		},
+	}
 	prompts := []string{"prompt1"}
 	generatedTokenCounts := []int{1}
 	prefixCacheScores := []float64{0.0}
 
-	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
 
 	assert.Error(t, err)
 	assert.Nil(t, results)
@@ -74,11 +98,19 @@
 func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
 	mockPredictor := &mockPredictor{}
 	metricsStates := []*fwkdl.Metrics{{}}
+	requestBuilder := &DefaultPredictionRequestBuilder{}
+	pods := []schedulingtypes.Endpoint{
+		&schedulingtypes.PodMetrics{
+			EndpointMetadata: &fwkdl.EndpointMetadata{
+				NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+			},
+		},
+	}
 	prompts := []string{"prompt1", "prompt2"} // Mismatch length
 	generatedTokenCounts := []int{1}
 	prefixCacheScores := []float64{0.0}
 
-	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
 
 	assert.Error(t, err)
 	assert.Nil(t, results)
@@ -88,11 +120,19 @@ func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
 func TestBulkPredictWithMetrics_NilMetricsState(t *testing.T) {
 	mockPredictor := &mockPredictor{}
 	metricsStates := []*fwkdl.Metrics{nil} // Nil metrics state
+	requestBuilder := &DefaultPredictionRequestBuilder{}
+	pods := []schedulingtypes.Endpoint{
+		&schedulingtypes.PodMetrics{
+			EndpointMetadata: &fwkdl.EndpointMetadata{
+				NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+			},
+		},
+	}
 	prompts := []string{"prompt1"}
 	generatedTokenCounts := []int{1}
 	prefixCacheScores := []float64{0.0}
 
-	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
 
 	assert.Error(t, err)
 	assert.Nil(t, results)

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/prediction.go

Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@ func (s *PredictedLatency) generatePredictions(ctx context.Context, request *sch
 	}
 
 	// Bulk predict
-	bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+	bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, s.requestBuilder, candidateEndpoints, prompts, generatedTokenCounts, prefixCacheScores)
 	if err != nil {
 		logger.V(logutil.DEBUG).Error(err, "Bulk prediction failed")
 		return nil, err
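With the builder threaded through generatePredictions, a downstream disaggregated-serving project can inject its own implementation in place of DefaultPredictionRequestBuilder. A sketch, assuming the downstream PredictionRequest type carries a PodType field and that an endpoint's role can be determined somehow; neither assumption is part of this commit.

// pdAwareBuilder is illustrative; PodType and classifyPod are downstream assumptions.
type pdAwareBuilder struct {
	def DefaultPredictionRequestBuilder // delegate for the shared fields
}

func (b *pdAwareBuilder) BuildPredictionRequest(
	ctx context.Context,
	pod framework.Endpoint,
	m *fwkdl.Metrics,
	prompt string,
	generatedTokenCount int,
	prefixCacheScore float64,
) latencypredictor.PredictionRequest {
	req := b.def.BuildPredictionRequest(ctx, pod, m, prompt, generatedTokenCount, prefixCacheScore)
	req.PodType = classifyPod(pod) // hypothetical field and helper
	return req
}

// classifyPod stands in for downstream prefill/decode detection logic.
func classifyPod(pod framework.Endpoint) string {
	return "decode"
}

// BuildTrainingEntry would delegate and stamp PodType the same way.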

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks.go

Lines changed: 65 additions & 3 deletions

@@ -98,6 +98,57 @@ func (s *PredictedLatency) deletePredictedLatencyContextForRequest(request *sche
 	s.sloContextStore.Delete(id)
 }
 
+// GetSchedulingResultForRequest returns the scheduling result for a request.
+// This is exposed to allow wrapper implementations (e.g., P/D-aware routers)
+// to access scheduling information for custom hook logic.
+func (s *PredictedLatency) GetSchedulingResultForRequest(request *schedulingtypes.LLMRequest) (*schedulingtypes.SchedulingResult, error) {
+	predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
+	if err != nil {
+		return nil, err
+	}
+	return predictedLatencyCtx.schedulingResult, nil
+}
+
+// GetLastSeenMetricsForRequest returns the last seen metrics for all profiles in a request.
+// This is exposed to allow wrapper implementations to access metrics for custom training logic.
+func (s *PredictedLatency) GetLastSeenMetricsForRequest(request *schedulingtypes.LLMRequest) (map[string]*fwkdl.Metrics, error) {
+	predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
+	if err != nil {
+		return nil, err
+	}
+	return predictedLatencyCtx.lastSeenMetrics, nil
+}
+
+// GetPrefixCacheScoresForRequest returns the prefix cache scores for all pods in a request.
+func (s *PredictedLatency) GetPrefixCacheScoresForRequest(request *schedulingtypes.LLMRequest) (map[string]float64, error) {
+	predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
+	if err != nil {
+		return nil, err
+	}
+	return predictedLatencyCtx.prefixCacheScoresForEndpoints, nil
+}
+
+// GetRequestPrompt returns the prompt for a request.
+func (s *PredictedLatency) GetRequestPrompt(request *schedulingtypes.LLMRequest) (string, error) {
+	predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
+	if err != nil {
+		return "", err
+	}
+	return predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt, nil
+}
+
+// GetRequestBuilder returns the PredictionRequestBuilder used by this router.
+// This allows wrappers to use the same builder for consistency.
+func (s *PredictedLatency) GetRequestBuilder() PredictionRequestBuilder {
+	return s.requestBuilder
+}
+
+// GetLatencyPredictor returns the latency predictor client.
+// This allows wrappers to record training data using the same predictor.
+func (s *PredictedLatency) GetLatencyPredictor() interface{} {
+	return s.latencypredictor
+}
+
 // --- RequestControl Hooks ---
 
 func (t *PredictedLatency) PreRequest(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult) {
@@ -186,10 +237,16 @@ func (t *PredictedLatency) ResponseStreaming(ctx context.Context, request *sched
 		return
 	}
 
+	// Create a schedulingtypes.Endpoint wrapper for the metadata
+	podWrapper := fwkdl.NewEndpoint(
+		targetMetadata,
+		predictedLatencyCtx.lastSeenMetrics[predictedLatencyCtx.schedulingResult.PrimaryProfileName],
+	)
+
 	if predictedLatencyCtx.ttft == 0 {
-		processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
+		processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, t.requestBuilder, predictedLatencyCtx, podWrapper, now, t.config.SamplingMean, t.config.MaxSampledTokens)
 	} else {
-		processTokenForLatencyPrediction(ctx, t.latencypredictor, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
+		processTokenForLatencyPrediction(ctx, t.latencypredictor, t.requestBuilder, predictedLatencyCtx, podWrapper, now, t.config.SamplingMean, t.config.MaxSampledTokens)
 	}
 
 }
@@ -213,7 +270,12 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu
 	}
 	now := time.Now()
 	if !t.config.StreamingMode {
-		processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
+		// Create a schedulingtypes.Endpoint wrapper for non-streaming responses
+		podWrapper := fwkdl.NewEndpoint(
+			targetMetadata,
+			predictedLatencyCtx.lastSeenMetrics[predictedLatencyCtx.schedulingResult.PrimaryProfileName],
+		)
+		processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, t.requestBuilder, predictedLatencyCtx, podWrapper, now, t.config.SamplingMean, t.config.MaxSampledTokens)
 	}
 
 	if predictedLatencyCtx.ttft > 0 {
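Taken together, the new accessors let a wrapper record its own training samples against the router's per-request state without reaching into it. An illustrative sketch follows; pdWrapper and recordSample are hypothetical, and only the accessor and predictor calls come from this commit (the BuildTrainingEntry argument order matches the interface sketch near the top of this page).

// pdWrapper is a hypothetical P/D-aware wrapper around the router.
type pdWrapper struct {
	inner *PredictedLatency
}

func (w *pdWrapper) recordSample(ctx context.Context, req *schedulingtypes.LLMRequest, pod schedulingtypes.Endpoint, ttftMs float64) error {
	result, err := w.inner.GetSchedulingResultForRequest(req)
	if err != nil {
		return err
	}
	metricsByProfile, err := w.inner.GetLastSeenMetricsForRequest(req)
	if err != nil {
		return err
	}
	prompt, err := w.inner.GetRequestPrompt(req)
	if err != nil {
		return err
	}
	// Reuse the router's builder so downstream fields (e.g. PodType) stay consistent.
	entry := w.inner.GetRequestBuilder().BuildTrainingEntry(
		ctx, pod, metricsByProfile[result.PrimaryProfileName], prompt,
		ttftMs, 0, time.Now(), 0, 0,
	)
	// GetLatencyPredictor returns interface{}, so assert the concrete interface.
	predictor, ok := w.inner.GetLatencyPredictor().(latencypredictor.PredictorInterface)
	if !ok {
		return fmt.Errorf("unexpected predictor type %T", w.inner.GetLatencyPredictor())
	}
	return predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry})
}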

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks_test.go

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ func createTestRouter() *PredictedLatency {
 		sloContextStore:     sync.Map{},
 		runningRequestLists: make(map[types.NamespacedName]*requestPriorityQueue),
 		latencypredictor:    nil,
+		requestBuilder:      &DefaultPredictionRequestBuilder{},
 		config:              DefaultConfig,
 	}
 }
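The Config change that makes the builder injectable is in one of the changed files not expanded in this view. Based on the commit message, wiring a custom builder plausibly looks like the following; the RequestBuilder field name is an assumption.

// Hypothetical wiring; the actual Config field name is not shown in this diff.
func newPDRouterConfig() Config {
	cfg := DefaultConfig
	cfg.RequestBuilder = &pdAwareBuilder{} // unset presumably falls back to DefaultPredictionRequestBuilder
	return cfg
}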
