Skip to content

Commit b7a66bd

Browse files
committed
add scorer test
1 parent e2f9478 commit b7a66bd

File tree

6 files changed

+625
-32
lines changed

6 files changed

+625
-32
lines changed

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,23 @@ import (
1919
"strings"
2020
)
2121

22+
var DefaultSamplingMean = func() float64 {
23+
if value, exists := os.LookupEnv("SAMPLING_MEAN"); exists {
24+
if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue > 0 {
25+
return parsedValue
26+
}
27+
}
28+
return 100.0 // default value
29+
}()
30+
31+
var MaxSampledTokens = func() int {
32+
if value, exists := os.LookupEnv("MAX_SAMPLED_TOKENS"); exists {
33+
if parsedValue, err := strconv.Atoi(value); err == nil && parsedValue > 0 {
34+
return parsedValue
35+
}
36+
}
37+
return 20 // default value
38+
}()
2239
var SLOBufferFactor = func() float64 {
2340
if value, exists := os.LookupEnv("SLO_BUFFER_FACTOR"); exists {
2441
if parsedValue, err := strconv.ParseFloat(value, 64); err == nil {

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ func (s *SLOAwareRouter) buildCompositeChoices(
133133
*total += w
134134
choices = append(choices, Choice{PodName: p.Pod, Weight: w})
135135

136-
log.FromContext(ctx).V(logutil.DEBUG).Info("Composite (neg/pos) score",
136+
log.FromContext(ctx).V(logutil.TRACE).Info("Composite (neg/pos) score",
137137
"pod", p.Pod.GetPod().String(),
138138
"kvUsage", kvUsage, "kvFree", kvFree,
139139
"queue", q, "relQueue", relQueue,

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,6 @@ import (
2828
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
2929
)
3030

31-
const (
32-
// Poisson sampling parameters for predictions
33-
defaultSamplingMean = 100 // Mean interval between prediction samples (tokens)
34-
maxSampledTokens = 20 // Maximum number of prediction samples per request
35-
)
36-
3731
// RefreshLastSeenMetrics updates sloCtx.LastSeenMetrics from the latest scheduling result.
3832
func RefreshLastSeenMetrics(ctx context.Context, sloCtx *SLORequestContext) {
3933
if sr := sloCtx.SchedulingResult; sr != nil {
@@ -136,7 +130,7 @@ func ProcessFirstTokenForLatencyPrediction(
136130
// Initialize sampler
137131
if sloCtx.TokenSampler == nil {
138132
requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey]
139-
sloCtx.TokenSampler = NewTokenSampler(requestID, defaultSamplingMean, maxSampledTokens)
133+
sloCtx.TokenSampler = NewTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens)
140134
logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken())
141135
}
142136

@@ -214,7 +208,7 @@ func ProcessTokenForLatencyPrediction(
214208
// Initialize sampler if not yet
215209
if sloCtx.TokenSampler == nil {
216210
requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey]
217-
sloCtx.TokenSampler = NewTokenSampler(requestID, defaultSamplingMean, maxSampledTokens)
211+
sloCtx.TokenSampler = NewTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens)
218212
logger.V(logutil.DEBUG).Info("Initialized token sampler for subsequent tokens", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken())
219213
}
220214

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,21 @@ import (
2222
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2323
)
2424

25+
type PodPredictionResult struct {
26+
Pod schedulingtypes.Pod
27+
TTFT float64
28+
TPOT float64
29+
TTFTValid bool
30+
TPOTValid bool
31+
IsValid bool
32+
Error error
33+
Headroom float64 // Headroom for the pod, if applicable
34+
TTFTHeadroom float64 // TTFT headroom for the pod
35+
PrefixCacheScore float64 // Prefix cache score for the pod
36+
}
37+
2538
// generatePredictions creates prediction results for all candidate pods
26-
func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *SLORequestContext, candidatePods []schedulingtypes.Pod) []PodPredictionResult {
39+
func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *SLORequestContext, candidatePods []schedulingtypes.Pod) ([]PodPredictionResult, error) {
2740
logger := log.FromContext(ctx)
2841
predictions := make([]PodPredictionResult, 0, len(candidatePods))
2942

@@ -42,10 +55,9 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul
4255
// Generate prediction
4356
prediction, err := PredictWithMetrics(ctx, s.latencypredictor, pod.GetMetrics(), request.Body.Completions.Prompt, 1, prefixCacheScore)
4457
if err != nil {
45-
logger.V(logutil.DEBUG).Info("Skipping pod due to prediction error", "pod", pod.GetPod().String(), "error", err)
58+
logger.V(logutil.DEBUG).Error(err, "Skipping pod due to prediction error", "pod", pod.GetPod().String())
4659
predResult.Error = err
47-
predictions = append(predictions, predResult)
48-
continue
60+
return nil, err
4961
}
5062
predResult.PrefixCacheScore = prefixCacheScore
5163
predResult.TTFT = prediction.TTFT
@@ -76,7 +88,7 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul
7688
predictions = append(predictions, predResult)
7789
}
7890

79-
return predictions
91+
return predictions, nil
8092
}
8193

8294
// updateRequestContextWithPredictions updates the request context with prediction data

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,6 @@ import (
3535
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3636
)
3737

38-
type PodPredictionResult struct {
39-
Pod schedulingtypes.Pod
40-
TTFT float64
41-
TPOT float64
42-
TTFTValid bool
43-
TPOTValid bool
44-
IsValid bool
45-
Error error
46-
Headroom float64 // Headroom for the pod, if applicable
47-
TTFTHeadroom float64 // TTFT headroom for the pod
48-
PrefixCacheScore float64 // Prefix cache score for the pod
49-
}
50-
5138
type SLOAwareRouter struct {
5239
tn plugins.TypedName
5340
latencypredictor latencypredictor.PredictorInterface
@@ -126,6 +113,48 @@ func (s *SLOAwareRouter) epsilonGreedyAffinityGate(
126113
return eligible, true
127114
}
128115

116+
// scoreWithoutPredictions provides fallback scoring based only on prefix cache scores
117+
// when latency predictions are unavailable
118+
func (s *SLOAwareRouter) scoreWithoutPredictions(
119+
ctx context.Context,
120+
state *schedulingtypes.CycleState,
121+
pods []schedulingtypes.Pod,
122+
r *rand.Rand,
123+
) map[schedulingtypes.Pod]float64 {
124+
logger := log.FromContext(ctx)
125+
logger.V(logutil.TRACE).Info("Using composite-only scoring without predictions")
126+
127+
scores := make(map[schedulingtypes.Pod]float64, len(pods))
128+
for _, pod := range pods {
129+
scores[pod] = 0
130+
}
131+
132+
if len(pods) == 0 {
133+
return scores
134+
}
135+
136+
// Build prediction results with only prefix cache scores
137+
podResults := make([]PodPredictionResult, 0, len(pods))
138+
for _, pod := range pods {
139+
prefixScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
140+
podResults = append(podResults, PodPredictionResult{
141+
Pod: pod,
142+
PrefixCacheScore: prefixScore,
143+
IsValid: true, // All pods are valid when we don't check predictions
144+
})
145+
}
146+
147+
// Select based on composite scores (prefix cache + other non-prediction metrics)
148+
selectedPod := s.selectFromCompositeScores(ctx, podResults, r, HeadroomStrategyCompositeOnly)
149+
150+
if selectedPod != nil {
151+
scores[selectedPod] = 1
152+
logger.V(logutil.TRACE).Info("Selected pod using composite-only scoring", "pod", selectedPod.GetPod().String())
153+
}
154+
155+
return scores
156+
}
157+
129158
func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) map[schedulingtypes.Pod]float64 {
130159
logger := log.FromContext(ctx)
131160
if s.latencypredictor == nil {
@@ -158,11 +187,6 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle
158187
return nil
159188
}
160189

161-
predictions := s.generatePredictions(ctx, state, request, sloCtx, pods)
162-
s.updateRequestContextWithPredictions(sloCtx, predictions)
163-
164-
allPreds := append([]PodPredictionResult(nil), predictions...)
165-
166190
// Initialize scores map with all pods having score 0
167191
scores := make(map[schedulingtypes.Pod]float64, len(pods))
168192
for _, pod := range pods {
@@ -171,6 +195,17 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle
171195

172196
source := rand.NewSource(time.Now().UnixNano())
173197
r := rand.New(source)
198+
199+
predictions, err := s.generatePredictions(ctx, state, request, sloCtx, pods)
200+
if err != nil {
201+
logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Error generating predictions, falling back to composite-only scoring")
202+
// Fall back to composite-only scoring using prefix cache scores
203+
return s.scoreWithoutPredictions(ctx, state, pods, r)
204+
}
205+
s.updateRequestContextWithPredictions(sloCtx, predictions)
206+
207+
allPreds := append([]PodPredictionResult(nil), predictions...)
208+
174209
allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal)
175210

176211
// Check if all pods are invalid and all have running requests

0 commit comments

Comments
 (0)