diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index 59c8976cd8..24743ccd38 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -63,6 +63,104 @@ var ( []string{"model_name", "target_model_name", "error_code"}, ) + // Gauge for various inference request metrics + inferenceGauges = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "inference_request_metric", + Help: metricsutil.HelpMsgWithStability("Consolidated gauge for various inference request metrics including TTFT, TPOT, SLOs, and prediction durations.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name", "type"}, + ) + + requestTTFT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model TTFT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.005, 0.025, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 2, 3, + 4, 5, 6, 8, 10, 15, 20, 30, 45, 60, 120, 180, 240, 300, 360, 480, 600, 900, 1200, 1800, 2700, 3600, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestPredictedTTFT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_predicted_ttft_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model Predicted TTFT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.005, 0.025, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 2, 3, + 4, 5, 6, 8, 10, 15, 20, 30, 45, 60, 120, 180, 240, 300, 360, 480, 600, 900, 1200, 1800, 2700, 3600, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + // New metrics for TTFT prediction duration + requestTTFTPredictionDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_prediction_duration_seconds", + Help: metricsutil.HelpMsgWithStability("Duration taken to generate TTFT predictions in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model TPOT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.0005, 0.00205, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.125, 0.15, 0.2, 0.3, + 0.4, 0.5, 0.6, 0.8, 1, 1.5, 2, 3, 4.5, 6, 12, 18, 24, 30, 36, 48, 60, 90, 120, 180, 270, 360, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestPredictedTPOT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_predicted_tpot_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model Predicted TPOT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.0005, 0.00205, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.125, 0.15, 0.2, 0.3, + 0.4, 0.5, 0.6, 0.8, 1, 1.5, 2, 3, 4.5, 6, 12, 18, 24, 30, 36, 48, 60, 90, 120, 180, 270, 360, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + // New metrics for TPOT 
prediction duration + requestTPOTPredictionDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_prediction_duration_seconds", + Help: metricsutil.HelpMsgWithStability("Duration taken to generate TPOT predictions in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + // Counter for SLO Violations + sloViolationCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_slo_violation_total", + Help: metricsutil.HelpMsgWithStability("Counter of SLO violations for each model, target model, and violation type.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name", "type"}, + ) + requestLatencies = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Subsystem: InferenceObjectiveComponent, @@ -282,6 +380,21 @@ var registerMetrics sync.Once // Register all metrics. func Register(customCollectors ...prometheus.Collector) { registerMetrics.Do(func() { + // Register inference gauges + metrics.Registry.MustRegister(inferenceGauges) + + // Register Histograms + metrics.Registry.MustRegister(requestTPOT) + metrics.Registry.MustRegister(requestTTFT) + metrics.Registry.MustRegister(requestPredictedTPOT) + metrics.Registry.MustRegister(requestPredictedTTFT) + metrics.Registry.MustRegister(requestTPOTPredictionDuration) + metrics.Registry.MustRegister(requestTTFTPredictionDuration) + + // Register SLO violation counters + metrics.Registry.MustRegister(sloViolationCounter) + + // Register other metrics metrics.Registry.MustRegister(requestCounter) metrics.Registry.MustRegister(requestErrCounter) metrics.Registry.MustRegister(requestLatencies) @@ -311,6 +424,21 @@ func Register(customCollectors ...prometheus.Collector) { // Just for integration test func Reset() { + // Reset inference gauges + inferenceGauges.Reset() + + // Reset Histograms + requestTPOT.Reset() + requestTTFT.Reset() + requestPredictedTPOT.Reset() + requestPredictedTTFT.Reset() + requestTPOTPredictionDuration.Reset() + requestTTFTPredictionDuration.Reset() + + // Reset SLO violation counter + sloViolationCounter.Reset() + + // Reset other metrics requestCounter.Reset() requestErrCounter.Reset() requestLatencies.Reset() @@ -363,6 +491,123 @@ func RecordRequestLatencies(ctx context.Context, modelName, targetModelName stri return true } +func RecordRequestTPOT(ctx context.Context, modelName, targetModelName string, tpot float64) bool { + if tpot < 0 { + log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT value must be non-negative", + "modelName", modelName, "targetModelName", targetModelName, "tpot", tpot) + return false + } + requestTPOT.WithLabelValues(modelName, targetModelName).Observe(tpot) + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot"}).Set(tpot) + return true +} + +// RecordRequestTPOTWithSLO records TPOT and checks for SLO violation. +// If tpot exceeds the threshold, it records a violation (sets gauge to 1 and increments counter). +// If tpot is within limits, it sets gauge to 0. 
+func RecordRequestTPOTWithSLO(ctx context.Context, modelName, targetModelName string, tpot float64, sloThreshold float64) bool {
+	if tpot < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "tpot", tpot)
+		return false
+	}
+
+	// Check for SLO violation (tpot exceeds threshold)
+	if tpot > sloThreshold {
+		inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot_slo_violation"}).Set(1)
+		sloViolationCounter.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot"}).Inc()
+		log.FromContext(ctx).V(logutil.DEFAULT).Info("TPOT SLO violation detected",
+			"modelName", modelName, "targetModelName", targetModelName, "tpot", tpot, "threshold", sloThreshold)
+	} else {
+		inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot_slo_violation"}).Set(0)
+	}
+
+	return true
+}
+
+// RecordRequestPredictedTPOT records the predicted TPOT for each model and target model.
+func RecordRequestPredictedTPOT(ctx context.Context, modelName, targetModelName string, predictedTPOT float64) bool {
+	if predictedTPOT < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "Predicted TPOT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "predictedTPOT", predictedTPOT)
+		return false
+	}
+	requestPredictedTPOT.WithLabelValues(modelName, targetModelName).Observe(predictedTPOT)
+	inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "predicted_tpot"}).Set(predictedTPOT)
+	return true
+}
+
+// RecordRequestTPOTPredictionDuration records the duration taken to generate TPOT predictions.
+func RecordRequestTPOTPredictionDuration(ctx context.Context, modelName, targetModelName string, duration float64) bool {
+	if duration < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT prediction duration must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "duration", duration)
+		return false
+	}
+	requestTPOTPredictionDuration.WithLabelValues(modelName, targetModelName).Observe(duration)
+	inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot_prediction_duration"}).Set(duration)
+	return true
+}
+
+// RecordRequestTTFT records the time to first token (TTFT) for each model and target model.
+func RecordRequestTTFT(ctx context.Context, modelName, targetModelName string, ttft float64) bool {
+	if ttft < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "ttft", ttft)
+		return false
+	}
+	requestTTFT.WithLabelValues(modelName, targetModelName).Observe(ttft)
+	inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft"}).Set(ttft)
+	return true
+}
+
+// RecordRequestTTFTWithSLO records TTFT and checks for SLO violation.
+// If ttft exceeds the threshold, it records a violation (sets gauge to 1 and increments counter).
+// If ttft is within limits, it sets gauge to 0.
+func RecordRequestTTFTWithSLO(ctx context.Context, modelName, targetModelName string, ttft float64, sloThreshold float64) bool {
+	if ttft < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "ttft", ttft)
+		return false
+	}
+
+	// Check for SLO violation (ttft exceeds threshold)
+	if ttft > sloThreshold {
+		inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft_slo_violation"}).Set(1)
+		sloViolationCounter.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft"}).Inc()
+		log.FromContext(ctx).V(logutil.DEFAULT).Info("TTFT SLO violation detected",
+			"modelName", modelName, "targetModelName", targetModelName, "ttft", ttft, "threshold", sloThreshold)
+	} else {
+		inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft_slo_violation"}).Set(0)
+	}
+
+	return true
+}
+
+// RecordRequestPredictedTTFT records the predicted TTFT for each model and target model.
+func RecordRequestPredictedTTFT(ctx context.Context, modelName, targetModelName string, predictedTTFT float64) bool {
+	if predictedTTFT < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "Predicted TTFT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "predictedTTFT", predictedTTFT)
+		return false
+	}
+	requestPredictedTTFT.WithLabelValues(modelName, targetModelName).Observe(predictedTTFT)
+	inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "predicted_ttft"}).Set(predictedTTFT)
+	return true
+}
+
+// RecordRequestTTFTPredictionDuration records the duration taken to generate TTFT predictions.
+func RecordRequestTTFTPredictionDuration(ctx context.Context, modelName, targetModelName string, duration float64) bool {
+	if duration < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT prediction duration must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "duration", duration)
+		return false
+	}
+	requestTTFTPredictionDuration.WithLabelValues(modelName, targetModelName).Observe(duration)
+	inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft_prediction_duration"}).Set(duration)
+	return true
+}
+
 // RecordResponseSizes records the response sizes.
 func RecordResponseSizes(modelName, targetModelName string, size int) {
 	responseSizes.WithLabelValues(modelName, targetModelName).Observe(float64(size))
@@ -480,3 +725,15 @@ func IncFlowControlQueueSize(fairnessID, priority string) {
 func DecFlowControlQueueSize(fairnessID, priority string) {
 	flowControlQueueSize.WithLabelValues(fairnessID, priority).Dec()
 }
+
+// SetTTFTSLOThreshold sets the TTFT SLO threshold for a model.
+// This allows dynamic threshold management and makes the threshold visible in metrics.
+func SetTTFTSLOThreshold(modelName, targetModelName string, threshold float64) {
+	inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft_slo_threshold"}).Set(threshold)
+}
+
+// SetTPOTSLOThreshold sets the TPOT SLO threshold for a model.
+// This allows dynamic threshold management and makes the threshold visible in metrics.
+func SetTPOTSLOThreshold(modelName, targetModelName string, threshold float64) { + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot_slo_threshold"}).Set(threshold) +} diff --git a/pkg/epp/metrics/metrics_test.go b/pkg/epp/metrics/metrics_test.go index 7d41681830..754d6d2947 100644 --- a/pkg/epp/metrics/metrics_test.go +++ b/pkg/epp/metrics/metrics_test.go @@ -46,6 +46,8 @@ const ( KVCacheAvgUsageMetric = InferencePoolComponent + "_average_kv_cache_utilization" QueueAvgSizeMetric = InferencePoolComponent + "_average_queue_size" PerPodQueueSizeMetrics = InferencePoolComponent + "_per_pod_queue_size" + RequestTTFTSecondsMetric = InferenceObjectiveComponent + "_request_ttft_seconds" + RequestTPOTSecondsMetric = InferenceObjectiveComponent + "_request_tpot_seconds" ) func TestMain(m *testing.M) { diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go new file mode 100644 index 0000000000..bbfd772232 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go @@ -0,0 +1,191 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package requestcontrol contains helpers to decouple latency-predictor logic. 
+package slo_aware_router + +import ( + "os" + "strconv" + "strings" +) + +var DefaultSamplingMean = func() float64 { + if value, exists := os.LookupEnv("SAMPLING_MEAN"); exists { + if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue > 0 { + return parsedValue + } + } + return 100.0 // default value +}() + +var MaxSampledTokens = func() int { + if value, exists := os.LookupEnv("MAX_SAMPLED_TOKENS"); exists { + if parsedValue, err := strconv.Atoi(value); err == nil && parsedValue > 0 { + return parsedValue + } + } + return 20 // default value +}() + +var SLOBufferFactor = func() float64 { + if value, exists := os.LookupEnv("SLO_BUFFER_FACTOR"); exists { + if parsedValue, err := strconv.ParseFloat(value, 64); err == nil { + return parsedValue + } + } + return 1.0 // default value +}() + +var NegHeadroomTTFTWeight = func() float64 { + if value, exists := os.LookupEnv("NEG_HEADROOM_TTFT_WEIGHT"); exists { + if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 { + return parsedValue + } + } + return 0.8 // default: TTFT dominates when violating SLOs +}() + +var NegHeadroomTPOTWeight = func() float64 { + if value, exists := os.LookupEnv("NEG_HEADROOM_TPOT_WEIGHT"); exists { + if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 { + return parsedValue + } + } + return 0.2 // default: TPOT less important in your tiny-output scenario +}() + +var HeadroomTTFTWeight = func() float64 { + if value, exists := os.LookupEnv("HEADROOM_TTFT_WEIGHT"); exists { + if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 { + return parsedValue + } + } + return 0.8 // default +}() + +var HeadroomTPOTWeight = func() float64 { + if value, exists := os.LookupEnv("HEADROOM_TPOT_WEIGHT"); exists { + if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 { + return parsedValue + } + } + return 0.2 // default +}() + +var HeadroomSelectionStrategy = func() headroomStrategy { + if value, exists := os.LookupEnv("HEADROOM_SELECTION_STRATEGY"); exists { + switch strings.ToLower(value) { + case "least": + return headroomStrategyLeast + case "most": + return headroomStrategyMost + case "composite-least": + return headroomStrategyCompositeLeast + case "composite-most": + return headroomStrategyCompositeMost + case "composite-only": + return headroomStrategyCompositeOnly + } + } + return headroomStrategyLeast // default to least (better packing) +}() + +// If using composite headroom, weights for each component. Not used by default +var CompositeKVWeight = func() float64 { + if v, ok := os.LookupEnv("COMPOSITE_KV_WEIGHT"); ok { + if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 { + return f + } + } + return 1 +}() + +var CompositeQueueWeight = func() float64 { + if v, ok := os.LookupEnv("COMPOSITE_QUEUE_WEIGHT"); ok { + if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 { + return f + } + } + return 1 +}() + +var CompositePrefixWeight = func() float64 { + if v, ok := os.LookupEnv("COMPOSITE_PREFIX_WEIGHT"); ok { + if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 { + return f + } + } + return 1 +}() + +// With probability ε, explore (ignore affinity gate); otherwise exploit. +var EpsilonExploreSticky = func() float64 { + // Prefer new env; fall back to old for compatibility. 
+ if v, ok := os.LookupEnv("STICKY_EPSILON"); ok { + if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 { + return f + } + } + return 0.01 // default 1% exploration +}() + +var EpsilonExploreNeg = func() float64 { + // Prefer new env; fall back to old for compatibility. + if v, ok := os.LookupEnv("NEG_HEADROOM_EPSILON"); ok { + if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 { + return f + } + } + return 0.01 // default 1% exploration +}() + +// τ for per-path affinity gate (aka "stickiness" threshold). +var AffinityGateTau = func() float64 { + // Prefer new env; fall back to old for compatibility. + if v, ok := os.LookupEnv("AFFINITY_GATE_TAU"); ok { + if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 { + return f + } + } + return 0.80 +}() + +// Global τ for the overall candidate set (previously "overall stickiness"). +var AffinityGateTauGlobal = func() float64 { + // Prefer new env; fall back to old for compatibility. + if v, ok := os.LookupEnv("AFFINITY_GATE_TAU_GLOBAL"); ok { + if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 { + return f + } + } + return 0.99 +}() + +// Read once at init. Values: "linear" (default) or "max". +var SelectionMode = func() podSelectionMode { + if v, ok := os.LookupEnv("POD_SELECTION_MODE"); ok { + switch strings.ToLower(v) { + case "max": + return podSelectionMax + case "linear": + fallthrough + default: + return podSelectionLinear + } + } + return podSelectionLinear +}() diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go new file mode 100644 index 0000000000..da0a86f200 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go @@ -0,0 +1,69 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package requestcontrol contains helpers to decouple latency-predictor logic. +package slo_aware_router + +import ( + "strconv" + + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" +) + +// parseFloatHeader retrieves a header by name, parses it as a float64, +// and returns the value or an error if the header is missing or invalid. +func parseFloatHeader(request schedulingtypes.LLMRequest, headerName string) (float64, error) { + // 1. Get header value from the map + headerValue, ok := request.Headers[headerName] + if !ok { + return 0, nil // Header not found, return 0 and false + } + + // 2. Parse the header value to a float64 + parsedFloat, err := strconv.ParseFloat(headerValue, 64) + if err != nil { + return 0, errutil.Error{ + Code: errutil.BadRequest, + Msg: headerName + " must be a float", + } + } + + // 3. 
Return the successfully parsed value + return parsedFloat, nil +} + +// parseFloatHeader retrieves a header by name, parses it as a bool, +// and returns the value or an error if the header is missing or invalid. +func parseBoolHeader(request schedulingtypes.LLMRequest, headerName string) (bool, error) { + // 1. Get header value from the map + headerValue, ok := request.Headers[headerName] + if !ok { + return false, nil // Header not found, return 0 and false + } + + // 2. Parse the header value to a bool + parsedBool, err := strconv.ParseBool(headerValue) + if err != nil { + return false, errutil.Error{ + Code: errutil.BadRequest, + Msg: headerName + " must be a bool", + } + } + + // 3. Return the successfully parsed value + return parsedBool, nil +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go new file mode 100644 index 0000000000..3b5820610e --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go @@ -0,0 +1,145 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "math" + "math/rand" + + "sigs.k8s.io/controller-runtime/pkg/log" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds []podPredictionResult, r *rand.Rand, strategy headroomStrategy) schedulingtypes.Pod { + total := 0 + choices := s.buildCompositeChoices( + ctx, allPreds, CompositeKVWeight, CompositeQueueWeight, CompositePrefixWeight, &total, + ) + if strategy == headroomStrategyCompositeLeast { + // Invert weights for "least" strategy + for i := range choices { + choices[i].weight = minWeight + wMax - choices[i].weight + } + } + selectedPod := s.performWeightedRandomSelection(choices, total, allPreds, r) + return selectedPod +} +func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []choice, total int, candidates []podPredictionResult, r *rand.Rand) schedulingtypes.Pod { + if total == 0 { + return nil + } + logger := log.FromContext(context.Background()) + // Check if MAX_SCORE_SELECTION env variable is set + if SelectionMode == podSelectionMax { + + logger.V(logutil.DEBUG).Info("Pod selection mode: MAX - selecting pod with highest weight") + maxWeight := 0 + var selectedPod schedulingtypes.Pod + for _, c := range weightedChoices { + if c.weight > maxWeight { + maxWeight = c.weight + selectedPod = c.podName + } + } + if selectedPod != nil { + return selectedPod + } + // Fallback to first pod if no selection made + return candidates[0].Pod + } + + // Original weighted random selection logic + logger.V(logutil.DEBUG).Info("Pod selection mode: LINEAR - performing weighted random selection") + idx := r.Intn(total) + var selectedPod schedulingtypes.Pod + + for _, c := range weightedChoices { + if 
idx < c.weight { + selectedPod = c.podName + break + } + idx -= c.weight + } + + // If no pod was selected (shouldn't happen), fallback to first pod + if selectedPod == nil { + selectedPod = candidates[0].Pod + } + + return selectedPod +} +func (s *SLOAwareRouter) buildCompositeChoices( + ctx context.Context, + candidates []podPredictionResult, + wkv, wq, wpref float64, + total *int, +) []choice { + + // Normalize weights + sumw := wkv + wq + wpref + if sumw <= 0 { + wkv, wq, wpref = 1, 0, 0 + } else { + wkv /= sumw + wq /= sumw + wpref /= sumw + } + + // Precompute queue stats + minQ, maxQ := math.MaxInt32, -1 + queueCounts := make(map[string]int, len(candidates)) + for _, p := range candidates { + q := p.Pod.GetMetrics().WaitingQueueSize + queueCounts[p.Pod.GetPod().String()] = q + if q < minQ { + minQ = q + } + if q > maxQ { + maxQ = q + } + } + den := float64(maxQ - minQ) + + choices := make([]choice, 0, len(candidates)) + for _, p := range candidates { + q := queueCounts[p.Pod.GetPod().String()] + relQueue := 1.0 + if den > 0 { + relQueue = (float64(maxQ-q) / den) + } + + kvUsage := p.Pod.GetMetrics().KVCacheUsagePercent + kvFree := (1.0 - kvUsage) + prefix := (p.PrefixCacheScore) + + composite := wkv*kvFree + wq*relQueue + wpref*prefix + w := int(math.Round(float64(minWeight) + (float64(wMax-minWeight) * composite))) + *total += w + choices = append(choices, choice{podName: p.Pod, weight: w}) + + log.FromContext(ctx).V(logutil.TRACE).Info("Composite (neg/pos) score", + "pod", p.Pod.GetPod().String(), + "kvUsage", kvUsage, "kvFree", kvFree, + "queue", q, "relQueue", relQueue, + "prefix", prefix, + "wkv", wkv, "wq", wq, "wprefix", wpref, + "composite", composite, "weight", w) + } + return choices +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go new file mode 100644 index 0000000000..7482de93c7 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go @@ -0,0 +1,405 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package requestcontrol contains helpers to decouple latency-predictor logic. +package slo_aware_router + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "sigs.k8s.io/controller-runtime/pkg/log" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +// refreshLastSeenMetrics updates sloCtx.LastSeenMetrics from the latest scheduling result. 
+func refreshLastSeenMetrics(ctx context.Context, sloCtx *sloRequestContext) { + if sr := sloCtx.schedulingResult; sr != nil { + if pr := sr.ProfileResults[sr.PrimaryProfileName]; pr != nil && pr.TargetPods != nil { + for profileName, profileResult := range sr.ProfileResults { + if profileResult != nil && profileResult.TargetPods != nil && len(profileResult.TargetPods) > 0 { + sloCtx.lastSeenMetrics[profileName] = profileResult.TargetPods[0].GetMetrics().Clone() + } + } + } + } else { + log.FromContext(ctx).V(logutil.DEBUG).Info("No scheduling result found, skipping metrics refresh") + } +} + +// GetMetricsForPrediction retrieves the latest metrics for prediction from sloCtx.LastSeenMetrics. +func getLatestMetricsForProfile(sloCtx *sloRequestContext) (*backendmetrics.MetricsState, error) { + if len(sloCtx.lastSeenMetrics) == 0 { + return nil, errors.New("no last seen metrics available for prediction") + } + + primaryProfileName := sloCtx.schedulingResult.PrimaryProfileName + if metrics, exists := sloCtx.lastSeenMetrics[primaryProfileName]; exists { + return metrics, nil + } + + return nil, fmt.Errorf("no metrics found for primary profile %s", primaryProfileName) +} + +// ProcessHeader refreshes metrics, applies TTFT prediction, updates sloCtx.PredictedTTFT and timestamp. +func processHeaderForLatencyPrediction( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + sloCtx *sloRequestContext, +) error { + logger := log.FromContext(ctx) + + // just for debugging, print the req context scheduling result cycle state + // print the raw scores in scheduling result + + // Build prediction request + m, err := getLatestMetricsForProfile(sloCtx) + if err != nil { + logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err) + return err + } + + targetPod := sloCtx.targetPod + prefix_cache_score := sloCtx.prefixCacheScoresForPods[targetPod.String()] + + in := latencypredictor.PredictionRequest{ + KVCachePercentage: m.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(sloCtx.schedulingRequest.Body.Completions.Prompt)), + NumRequestWaiting: m.WaitingQueueSize, + NumRequestRunning: m.RunningQueueSize, + NumTokensGenerated: 0, + PrefixCacheScore: prefix_cache_score, + } + + // Predict TTFT + start := time.Now() + p, err := predictor.Predict(ctx, in) + dur := time.Since(start) + switch { + case err != nil: + logger.V(logutil.DEBUG).Error(err, "header TTFT predict failed", "duration_ms", dur.Milliseconds()) + sloCtx.predictedTTFT = 0 + case p == nil: + logger.V(logutil.DEBUG).Info("header TTFT predict nil", "duration_ms", dur.Milliseconds()) + sloCtx.predictedTTFT = 0 + default: + logger.V(logutil.DEBUG).Info("header TTFT succeeded", "value_ms", p.TTFT, "duration_ms", dur.Milliseconds()) + metrics.RecordRequestTTFTPredictionDuration(ctx, sloCtx.schedulingRequest.TargetModel, sloCtx.incomingModelName, dur.Seconds()) + + sloCtx.predictedTTFT = p.TTFT + } + + // Advance timestamp for first token reference + sloCtx.lastTokenTimestamp = time.Now() + refreshLastSeenMetrics(ctx, sloCtx) + return err +} + +// ProcessFirstToken records actual TTFT, trains, predicts first TPOT, updates sloCtx, and advances timestamp. 
+func processFirstTokenForLatencyPrediction( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + sloCtx *sloRequestContext, + now time.Time, +) { + logger := log.FromContext(ctx) + + initializeSampler(ctx, sloCtx) + + // Actual TTFT + sloCtx.ttft = float64(now.Sub(sloCtx.requestReceivedTimestamp).Milliseconds()) + sloCtx.generatedTokenCount = 1 + m, err := getLatestMetricsForProfile(sloCtx) + if err != nil { + logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err) + return + } + targetPod := sloCtx.targetPod + prefixCacheScore := sloCtx.prefixCacheScoresForPods[targetPod.String()] + + recordTTFTTrainingData(ctx, predictor, sloCtx, m, now, prefixCacheScore) + + predictFirstTPOT(ctx, predictor, sloCtx) + + // Advance timestamp + sloCtx.lastTokenTimestamp = now + // Refresh metrics + refreshLastSeenMetrics(ctx, sloCtx) +} + +func initializeSampler(ctx context.Context, sloCtx *sloRequestContext) { + if sloCtx.tokenSampler == nil { + logger := log.FromContext(ctx) + requestID := sloCtx.schedulingRequest.Headers[requtil.RequestIdHeaderKey] + sloCtx.tokenSampler = newTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) + logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", sloCtx.tokenSampler.getNextSampleToken()) + } +} + +func recordTTFTTrainingData( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + sloCtx *sloRequestContext, + m *backendmetrics.MetricsState, + now time.Time, + prefixCacheScore float64, +) { + logger := log.FromContext(ctx) + // Train TTFT + entry := latencypredictor.TrainingEntry{ + KVCachePercentage: m.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(sloCtx.schedulingRequest.Body.Completions.Prompt)), + ActualTTFT: sloCtx.ttft, + ActualTPOT: 0, + Timestamp: now, + NumRequestWaiting: m.WaitingQueueSize, + NumRequestRunning: m.RunningQueueSize, + NumTokensGenerated: 0, + PrefixCacheScore: prefixCacheScore, + } + if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil { + logger.V(logutil.DEBUG).Error(err, "record TTFT training failed") + } +} + +func predictFirstTPOT( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + sloCtx *sloRequestContext, +) { + logger := log.FromContext(ctx) + m, err := getLatestMetricsForProfile(sloCtx) + if err != nil { + logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics", + "error", err) + return + } + + // Predict first TPOT + in := latencypredictor.PredictionRequest{ + KVCachePercentage: m.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(sloCtx.schedulingRequest.Body.Completions.Prompt)), + NumRequestWaiting: m.WaitingQueueSize, + NumRequestRunning: m.RunningQueueSize, + NumTokensGenerated: sloCtx.generatedTokenCount, + PrefixCacheScore: 0, + } + start := time.Now() + p, err := predictor.Predict(ctx, in) + dur := time.Since(start) + if err != nil || p == nil { + logger.V(logutil.DEBUG).Error(err, "first TPOT predict failed", "duration_ms", dur.Milliseconds()) + sloCtx.predictedTPOTObservations = append(sloCtx.predictedTPOTObservations, 0) + sloCtx.avgPredictedTPOT = calculateRunningAverage(sloCtx.avgPredictedTPOT, 0, len(sloCtx.predictedTPOTObservations)) + } else { + logger.V(logutil.DEBUG).Info("first TPOT succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds()) + sloCtx.predictedTPOTObservations = append(sloCtx.predictedTPOTObservations, p.TPOT) + 
sloCtx.avgPredictedTPOT = calculateRunningAverage(sloCtx.avgPredictedTPOT, p.TPOT, len(sloCtx.predictedTPOTObservations)) + } + metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.schedulingRequest.TargetModel, sloCtx.incomingModelName, dur.Seconds()) +} + +// ProcessToken records actual inter-token latency, trains, predicts sampled TPOT, updates sloCtx, and advances timestamp. +func processTokenForLatencyPrediction( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + sloCtx *sloRequestContext, + now time.Time, +) { + logger := log.FromContext(ctx) + + // Initialize sampler if not yet + if sloCtx.tokenSampler == nil { + requestID := sloCtx.schedulingRequest.Headers[requtil.RequestIdHeaderKey] + sloCtx.tokenSampler = newTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) + logger.V(logutil.DEBUG).Info("Initialized token sampler for subsequent tokens", "request_id", requestID, "next_prediction_token", sloCtx.tokenSampler.getNextSampleToken()) + } + + // Inter-token latency + latencyMs := float64(now.Sub(sloCtx.lastTokenTimestamp).Milliseconds()) + sloCtx.generatedTokenCount++ + + // log the inter-token latency for predicted samples + if sloCtx.generatedTokenCount == 2 || sloCtx.tokenSampler.shouldPredict(sloCtx.generatedTokenCount) { // tricky logic, since next sample token is always +1 from current token + sloCtx.tpotObservations = append(sloCtx.tpotObservations, latencyMs) + sloCtx.avgTPOT = calculateRunningAverage(sloCtx.avgTPOT, latencyMs, len(sloCtx.tpotObservations)) + } + + m, err := getLatestMetricsForProfile(sloCtx) + if err != nil { + logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics", + "error", err) + return + } + // Record actual TPOT + entry := latencypredictor.TrainingEntry{ + KVCachePercentage: m.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(sloCtx.schedulingRequest.Body.Completions.Prompt)), + ActualTTFT: 0, + ActualTPOT: latencyMs, + Timestamp: now, + NumRequestWaiting: m.WaitingQueueSize, + NumRequestRunning: m.RunningQueueSize, + NumTokensGenerated: sloCtx.generatedTokenCount - 1, + PrefixCacheScore: 0, // TPOT does not use prefix cache score + } + if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil { + logger.V(logutil.DEBUG).Error(err, "record TPOT training failed") + } + + // Sampled predict + if sloCtx.tokenSampler.shouldPredict(sloCtx.generatedTokenCount) { + in := latencypredictor.PredictionRequest{ + KVCachePercentage: m.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(sloCtx.schedulingRequest.Body.Completions.Prompt)), + NumRequestWaiting: m.WaitingQueueSize, + NumRequestRunning: m.RunningQueueSize, + NumTokensGenerated: sloCtx.generatedTokenCount, + PrefixCacheScore: 0, // TPOT does not use prefix cache score + } + start := time.Now() + p, err := predictor.Predict(ctx, in) + dur := time.Since(start) + if err != nil || p == nil { + logger.V(logutil.DEBUG).Error(err, "TPOT predict failed", "duration_ms", dur.Milliseconds()) + sloCtx.predictedTPOTObservations = append(sloCtx.predictedTPOTObservations, 0) + sloCtx.avgPredictedTPOT = calculateRunningAverage(sloCtx.avgPredictedTPOT, 0, len(sloCtx.predictedTPOTObservations)) + } else { + logger.V(logutil.DEBUG).Info("TPOT predict succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds()) + sloCtx.predictedTPOTObservations = append(sloCtx.predictedTPOTObservations, p.TPOT) + sloCtx.avgPredictedTPOT = calculateRunningAverage(sloCtx.avgPredictedTPOT, p.TPOT, 
len(sloCtx.predictedTPOTObservations)) + } + metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.schedulingRequest.TargetModel, sloCtx.incomingModelName, dur.Seconds()) + + sloCtx.tokenSampler.recordPrediction(sloCtx.generatedTokenCount) + } + + // Advance timestamp + sloCtx.lastTokenTimestamp = now + // Refresh metrics + refreshLastSeenMetrics(ctx, sloCtx) +} + +// bulkPredictWithMetrics performs bulk predictions for multiple pods using their metrics states. +// Returns predictions in the same order as the input slices. +func bulkPredictWithMetrics( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + metricsStates []*backendmetrics.MetricsState, + prompts []string, + generatedTokenCounts []int, + prefixCacheScores []float64, +) ([]*latencypredictor.PredictionResponse, error) { + logger := log.FromContext(ctx) + + // Validate input lengths + if len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) { + return nil, fmt.Errorf("input slice lengths must match: metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d", + len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores)) + } + + if len(metricsStates) == 0 { + return []*latencypredictor.PredictionResponse{}, nil + } + + // Validate that no metrics state is nil + for i, metricsState := range metricsStates { + if metricsState == nil { + return nil, fmt.Errorf("metrics state at index %d cannot be nil", i) + } + } + + // Build bulk prediction requests + bulkRequests := make([]latencypredictor.PredictionRequest, len(metricsStates)) + for i := range metricsStates { + bulkRequests[i] = latencypredictor.PredictionRequest{ + KVCachePercentage: metricsStates[i].KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(prompts[i])), + NumRequestWaiting: metricsStates[i].WaitingQueueSize, + NumRequestRunning: metricsStates[i].RunningQueueSize, + NumTokensGenerated: generatedTokenCounts[i], + PrefixCacheScore: prefixCacheScores[i], + } + } + + // Perform bulk prediction + start := time.Now() + bulkResponse, err := predictor.PredictBulkStrict(ctx, bulkRequests) + duration := time.Since(start) + + if err != nil { + logger.V(logutil.DEBUG).Error(err, "bulk prediction failed", + "duration_ms", duration.Milliseconds(), + "request_count", len(bulkRequests)) + return nil, err + } + + if bulkResponse == nil { + logger.V(logutil.DEBUG).Info("bulk prediction returned nil", + "duration_ms", duration.Milliseconds()) + return nil, errors.New("bulk prediction returned nil result") + } + + // Convert to pointer slice for consistency with single prediction + results := make([]*latencypredictor.PredictionResponse, len(bulkResponse.Predictions)) + for i := range bulkResponse.Predictions { + results[i] = &bulkResponse.Predictions[i] + } + + logger.V(logutil.DEBUG).Info("bulk prediction succeeded", + "duration_ms", duration.Milliseconds(), + "request_count", len(bulkRequests), + "successful_predictions", bulkResponse.SuccessfulPredictions, + "failed_predictions", bulkResponse.FailedPredictions, + "processing_time_ms", bulkResponse.ProcessingTimeMs) + + // Log detailed results if at trace level + if logger.V(logutil.TRACE).Enabled() { + for i, result := range results { + logger.V(logutil.TRACE).Info("bulk prediction result", + "index", i, + "ttft_ms", result.TTFT, + "tpot_ms", result.TPOT, + "input_tokens", bulkRequests[i].InputTokenLength, + "generated_tokens", bulkRequests[i].NumTokensGenerated, + "kv_cache_percent", 
bulkRequests[i].KVCachePercentage, + "waiting_queue", bulkRequests[i].NumRequestWaiting, + "running_queue", bulkRequests[i].NumRequestRunning, + "prefix_cache_score", bulkRequests[i].PrefixCacheScore) + } + } + + return results, nil +} + +// calculateRunningAverage calculates the running average efficiently +func calculateRunningAverage(currentAvg float64, newValue float64, count int) float64 { + if count == 0 { + return 0 + } + if count == 1 { + return newValue + } + return currentAvg + (newValue-currentAvg)/float64(count) +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper_test.go new file mode 100644 index 0000000000..92227cba62 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper_test.go @@ -0,0 +1,100 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +func TestBulkPredictWithMetrics(t *testing.T) { + mockPredictor := &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.5": {TTFT: 0.5, TPOT: 0.03}, + "0.6": {TTFT: 0.6, TPOT: 0.04}, + }, + } + + metricsStates := []*backendmetrics.MetricsState{ + {KVCacheUsagePercent: 0.5}, + {KVCacheUsagePercent: 0.6}, + } + prompts := []string{"prompt1", "prompt2"} + generatedTokenCounts := []int{1, 1} + prefixCacheScores := []float64{0.0, 0.0} + + results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores) + + assert.NoError(t, err) + assert.Len(t, results, 2) + assert.Equal(t, 0.5, results[0].TTFT) + assert.Equal(t, 0.03, results[0].TPOT) + assert.Equal(t, 0.6, results[1].TTFT) + assert.Equal(t, 0.04, results[1].TPOT) +} + +func TestBulkPredictWithMetrics_Error(t *testing.T) { + mockPredictor := &mockPredictor{ + err: errors.New("prediction failed"), + } + + metricsStates := []*backendmetrics.MetricsState{ + {KVCacheUsagePercent: 0.5}, + } + prompts := []string{"prompt1"} + generatedTokenCounts := []int{1} + prefixCacheScores := []float64{0.0} + + results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores) + + assert.Error(t, err) + assert.Nil(t, results) +} + +func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) { + mockPredictor := &mockPredictor{} + metricsStates := []*backendmetrics.MetricsState{{}} + prompts := []string{"prompt1", "prompt2"} // Mismatch length + generatedTokenCounts := []int{1} + prefixCacheScores := []float64{0.0} + + results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, 
prompts, generatedTokenCounts, prefixCacheScores) + + assert.Error(t, err) + assert.Nil(t, results) + assert.True(t, strings.Contains(err.Error(), "input slice lengths must match")) +} + +func TestBulkPredictWithMetrics_NilMetricsState(t *testing.T) { + mockPredictor := &mockPredictor{} + metricsStates := []*backendmetrics.MetricsState{nil} // Nil metrics state + prompts := []string{"prompt1"} + generatedTokenCounts := []int{1} + prefixCacheScores := []float64{0.0} + + results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores) + + assert.Error(t, err) + assert.Nil(t, results) + assert.True(t, strings.Contains(err.Error(), "metrics state at index 0 cannot be nil")) +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go new file mode 100644 index 0000000000..7d41de95c9 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go @@ -0,0 +1,148 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package requestcontrol contains helpers to decouple latency-predictor logic. +package slo_aware_router + +import ( + "context" + + "sigs.k8s.io/controller-runtime/pkg/log" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +type podPredictionResult struct { + Pod schedulingtypes.Pod + TTFT float64 + TPOT float64 + TTFTValid bool + TPOTValid bool + IsValid bool + Error error + Headroom float64 // Headroom for the pod, if applicable + TTFTHeadroom float64 // TTFT headroom for the pod + PrefixCacheScore float64 // Prefix cache score for the pod +} + +// generatePredictions creates prediction results for all candidate pods +func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidatePods []schedulingtypes.Pod) ([]podPredictionResult, error) { + logger := log.FromContext(ctx) + predictions := make([]podPredictionResult, 0, len(candidatePods)) + + // Prepare inputs for bulk prediction + metricsStates := make([]*backendmetrics.MetricsState, len(candidatePods)) + prompts := make([]string, len(candidatePods)) + generatedTokenCounts := make([]int, len(candidatePods)) + prefixCacheScores := make([]float64, len(candidatePods)) + + for i, pod := range candidatePods { + logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String()) + + // Get prefix cache score for the pod + prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod) + 
sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore + + logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore) + + metricsStates[i] = pod.GetMetrics() + prompts[i] = request.Body.Completions.Prompt + generatedTokenCounts[i] = 1 + prefixCacheScores[i] = prefixCacheScore + } + + // Bulk predict + bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores) + if err != nil { + logger.V(logutil.DEBUG).Error(err, "Bulk prediction failed") + return nil, err + } + + // Process results + for i, pod := range candidatePods { + prediction := bulkPredictions[i] + predResult := podPredictionResult{Pod: pod} + + predResult.PrefixCacheScore = prefixCacheScores[i] + predResult.TTFT = prediction.TTFT + predResult.TPOT = prediction.TPOT + + podMinTPOTSLO := s.getPodMinTPOTSLO(pod) + predResult.TTFTValid, predResult.TPOTValid, predResult.IsValid, predResult.Headroom, predResult.TTFTHeadroom = s.validatePrediction(prediction, sloCtx, podMinTPOTSLO) + + logger.V(logutil.DEBUG).Info("Prediction for scheduling", + "pod", pod.GetPod().String(), + "prefixCacheScore", predResult.PrefixCacheScore, + "TTFT", prediction.TTFT, + "TPOT", prediction.TPOT, + "buffer", SLOBufferFactor, + "podMinTPOTSLO", podMinTPOTSLO, + "ttftSLO", sloCtx.ttftSLO, + "requestTPOTSLO", sloCtx.avgTPOTSLO, + "tpotHeadroom", predResult.Headroom, + "ttftHeadroom", predResult.TTFTHeadroom, + "tpotValid", predResult.TPOTValid, + "ttftValid", predResult.TTFTValid, + "headroomStrategy", s.headroomStrategy) + + predictions = append(predictions, predResult) + } + + return predictions, nil +} + +// updateRequestContextWithPredictions updates the request context with prediction data +func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *sloRequestContext, predictions []podPredictionResult) { + for _, pred := range predictions { + if pred.Error == nil { + podKey := pred.Pod.GetPod().String() + if sloCtx.predictedTTFTForScheduling == nil { + sloCtx.predictedTTFTForScheduling = make(map[string]float64) + } + if sloCtx.predictedTPOTForScheduling == nil { + sloCtx.predictedTPOTForScheduling = make(map[string]float64) + } + sloCtx.predictedTTFTForScheduling[podKey] = pred.TTFT + sloCtx.predictedTPOTForScheduling[podKey] = pred.TPOT + } + } +} + +func (s *SLOAwareRouter) validatePrediction( + pred *latencypredictor.PredictionResponse, + sloCtx *sloRequestContext, + podMinTPOTSLO float64, +) (ttftOk, tpotOk, isValid bool, headroom float64, ttftHeadroom float64) { + + bufferedTPOT := sloCtx.avgTPOTSLO * SLOBufferFactor + // a podMinTPOTSLO of 0 means no either no requests, or no TPOT SLOs specified on running requests + if podMinTPOTSLO > 0 { + if podMinTPOTSLO < sloCtx.avgTPOTSLO { + log.FromContext(context.Background()).V(logutil.DEBUG).Info("Pod min TPOT SLO is less than the req SLO, adjusting", "podMinTPOTSLO", podMinTPOTSLO, "bufferedTPOT", sloCtx.avgTPOTSLO) + } + bufferedTPOT = min(bufferedTPOT, podMinTPOTSLO*SLOBufferFactor) + } + + tpotOk = pred.TPOT < bufferedTPOT + ttftOk = pred.TTFT < sloCtx.ttftSLO + + isValid = ttftOk && tpotOk + headroom = bufferedTPOT - pred.TPOT + ttftHeadroom = sloCtx.ttftSLO - pred.TTFT + return +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go new file mode 100644 index 0000000000..d505b11e38 --- /dev/null +++ 
b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go @@ -0,0 +1,263 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/go-logr/logr" + "sigs.k8s.io/controller-runtime/pkg/log" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" +) + +var _ requestcontrol.PreRequest = &SLOAwareRouter{} +var _ requestcontrol.ResponseReceived = &SLOAwareRouter{} +var _ requestcontrol.ResponseStreaming = &SLOAwareRouter{} +var _ requestcontrol.ResponseComplete = &SLOAwareRouter{} + +type sloRequestContext struct { + schedulingRequest schedulingtypes.LLMRequest + targetPod *backend.Pod + schedulingResult *schedulingtypes.SchedulingResult + lastSeenMetrics map[string]*backendmetrics.MetricsState + lastTokenTimestamp time.Time + requestReceivedTimestamp time.Time + generatedTokenCount int + incomingModelName string + ttft float64 + predictedTTFT float64 + avgTPOT float64 + avgPredictedTPOT float64 + tokenSampler *tokenSampler + tpotObservations []float64 + predictedTPOTObservations []float64 + + prefixCacheScoresForPods map[string]float64 + + // ttftSLO is the target time to first token SLO for the request. + ttftSLO float64 + // TPOTSLO is the target time per output token SLO for the request. + avgTPOTSLO float64 + + // predictorBasedScheduling indicates whether to use predictor based scheduling. + predictorBasedScheduling bool + // predictedTTFTForScheduling is the map of pod names to predicted TTFT values for scheduling. + predictedTTFTForScheduling map[string]float64 + // predictedTPOTForScheduling is the map of pod names to predicted TPOT values for scheduling. 
+ predictedTPOTForScheduling map[string]float64 + + // boolean set if request has valid pod based on predictions + hasValidPod bool +} + +func newSLORequestContext(request *schedulingtypes.LLMRequest) *sloRequestContext { + return &sloRequestContext{ + schedulingRequest: *request, + lastSeenMetrics: make(map[string]*backendmetrics.MetricsState), + prefixCacheScoresForPods: make(map[string]float64), + predictedTTFTForScheduling: make(map[string]float64), + predictedTPOTForScheduling: make(map[string]float64), + } +} + +func (s *SLOAwareRouter) getSLOContextForRequest(request *schedulingtypes.LLMRequest) (*sloRequestContext, error) { + id := request.Headers[requtil.RequestIdHeaderKey] + if ctx, exists := s.sloContextStore.Load(id); exists { + return ctx.(*sloRequestContext), nil + } + return nil, fmt.Errorf("SLO context not found for request ID: %s", id) +} + +func (s *SLOAwareRouter) setSLOContextForRequest(request *schedulingtypes.LLMRequest, ctx *sloRequestContext) { + id := request.Headers[requtil.RequestIdHeaderKey] + s.sloContextStore.Store(id, ctx) +} + +func (s *SLOAwareRouter) deleteSLOContextForRequest(request *schedulingtypes.LLMRequest) { + id := request.Headers[requtil.RequestIdHeaderKey] + s.sloContextStore.Delete(id) +} + +// --- RequestControl Hooks --- + +func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult) { + logger := log.FromContext(ctx) + + if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PreRequest because no scheduling result was provided.") + return + } + + targetPod := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName].TargetPods[0].GetPod() + if !t.checkPredictor(logger, targetPod) { + return + } + + podName := types.NamespacedName{ + Name: targetPod.NamespacedName.Name, + Namespace: targetPod.NamespacedName.Namespace, + } + + logger.V(logutil.TRACE).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "podName", podName) + if request.Headers[requtil.RequestIdHeaderKey] == "" { + logger.V(logutil.DEBUG).Error(errors.New("missing request ID"), "SLOAwareRouter.PreRequest: Request is missing request ID header") + } + + id := request.Headers[requtil.RequestIdHeaderKey] + podRequestList, ok := t.runningRequestLists[podName] + if !ok { + podRequestList = newRequestPriorityQueue() + t.runningRequestLists[podName] = podRequestList + } + + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + id := request.Headers[requtil.RequestIdHeaderKey] + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.PreRequest: Failed to get SLO context for request", "requestID", id) + return + } + + added := podRequestList.Add(id, sloCtx.avgTPOTSLO) + if !added { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Item already exists in queue", "podName", podName, "requestID", id) + } + + // Set up SLO request context + sloCtx.targetPod = targetPod + sloCtx.schedulingResult = schedulingResult + sloCtx.requestReceivedTimestamp = time.Now() + refreshLastSeenMetrics(ctx, sloCtx) + t.setSLOContextForRequest(request, sloCtx) +} + +func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { + logger := log.FromContext(ctx) + if !t.checkPredictor(logger, targetPod) { + return + } + + id := request.Headers[requtil.RequestIdHeaderKey] + + sloCtx, err := 
t.getSLOContextForRequest(request) + if err != nil { + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Failed to get SLO context for request", "requestID", id) + return + } + + if err := processHeaderForLatencyPrediction(ctx, t.latencypredictor, sloCtx); err != nil { + logger.V(logutil.DEBUG).Error(err, "ProcessHeader in latencypredictor failed") + } + +} + +func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) { + logger := log.FromContext(ctx) + if !t.checkPredictor(logger, pod) || response.EndOfStream { + return + } + + now := time.Now() + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + id := request.Headers[requtil.RequestIdHeaderKey] + logger.V(logutil.TRACE).Error(err, "SLOAwareRouter.ResponseStreaming: Failed to get SLO context for request", "requestID", id) + return + } + + if sloCtx.ttft == 0 { + processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now) + } else { + processTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now) + } + +} + +func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) { + logger := log.FromContext(ctx) + targetPod := pod + if !t.checkPredictor(logger, targetPod) { + return + } + + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + id := request.Headers[requtil.RequestIdHeaderKey] + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.ResponseComplete: Failed to get SLO context for request", "requestID", id) + return + } + + if sloCtx.ttft > 0 { + logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTTFT", sloCtx.ttft, "avgPredictedTTFT", sloCtx.predictedTTFT) + metrics.RecordRequestTTFT(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.ttft/1000) + metrics.RecordRequestPredictedTTFT(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.predictedTTFT/1000) + if sloCtx.ttftSLO > 0 { + metrics.RecordRequestTTFTWithSLO(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.ttft, sloCtx.ttftSLO) + } + } + + if sloCtx.avgTPOT > 0 { + logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTPOT", sloCtx.avgTPOT, "avgPredictedTPOT", sloCtx.avgPredictedTPOT) + metrics.RecordRequestTPOT(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.avgTPOT/1000) + metrics.RecordRequestPredictedTPOT(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.avgPredictedTPOT/1000) + if sloCtx.avgTPOTSLO > 0 { + metrics.RecordRequestTPOTWithSLO(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.avgTPOT, sloCtx.avgTPOTSLO) + } + } + + logger.V(logutil.TRACE).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", sloCtx.predictorBasedScheduling) + + podName := types.NamespacedName{ + Name: targetPod.NamespacedName.Name, + Namespace: targetPod.NamespacedName.Namespace, + } + + id := request.Headers[requtil.RequestIdHeaderKey] + podRequestList, ok := t.runningRequestLists[podName] + if !ok { + err := fmt.Errorf("no running request list found for pod %s", podName.String()) + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Failed to remove request from queue", "requestID", id) + // No queue exists for this pod; clean up the SLO context and return early rather than calling Remove on a nil queue. + t.deleteSLOContextForRequest(request) + return + } + + _, removed := podRequestList.Remove(id) + if !removed { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Item not found in queue", "podName", podName, "requestID", id) + } + t.deleteSLOContextForRequest(request) +} + +func (t *SLOAwareRouter) checkPredictor(logger logr.Logger, targetPod
*backend.Pod) bool { + if targetPod == nil { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because no target pod was provided.") + return false + } + if t.latencypredictor == nil { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because predictor missing") + return false + } + return true +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go new file mode 100644 index 0000000000..5aaf1a2a24 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go @@ -0,0 +1,952 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/go-logr/logr" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" +) + +const ( + testModelName = "test-model" + kvUsage = 1 + runningQueue = 1 + waitingQueue = 1 +) + +// Helper functions + +func createTestSchedulingResult(pod *backend.Pod) *schedulingtypes.SchedulingResult { + + mockPod := createTestPod(pod.NamespacedName.Name, kvUsage, runningQueue, waitingQueue) + + return &schedulingtypes.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*schedulingtypes.ProfileRunResult{ + "default": { + TargetPods: []schedulingtypes.Pod{mockPod}, + }, + }, + } +} + +func createTestRouter() *SLOAwareRouter { + return &SLOAwareRouter{ + sloContextStore: sync.Map{}, + runningRequestLists: make(map[types.NamespacedName]*requestPriorityQueue), + latencypredictor: nil, + } +} + +// Test cases + +func TestNewSLORequestContext(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + + ctx := newSLORequestContext(request) + + assert.NotNil(t, ctx) + assert.Equal(t, *request, ctx.schedulingRequest) + assert.NotNil(t, ctx.lastSeenMetrics) + assert.NotNil(t, ctx.prefixCacheScoresForPods) + assert.NotNil(t, ctx.predictedTTFTForScheduling) + assert.NotNil(t, ctx.predictedTPOTForScheduling) + assert.Empty(t, ctx.lastSeenMetrics) + assert.Empty(t, ctx.prefixCacheScoresForPods) +} + +func TestSLOAwareRouter_SetAndGetSLOContext(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := newSLORequestContext(request) + + // Set context + router.setSLOContextForRequest(request, sloCtx) + + // Get context + retrievedCtx, err := router.getSLOContextForRequest(request) + + require.NoError(t, err) + 
assert.Equal(t, sloCtx, retrievedCtx) +} + +func TestSLOAwareRouter_GetSLOContext_NotFound(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + + // Try to get context that doesn't exist + ctx, err := router.getSLOContextForRequest(request) + + assert.Error(t, err) + assert.Nil(t, ctx) + assert.Contains(t, err.Error(), "SLO context not found") +} + +func TestSLOAwareRouter_DeleteSLOContext(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := newSLORequestContext(request) + + // Set and then delete context + router.setSLOContextForRequest(request, sloCtx) + router.deleteSLOContextForRequest(request) + + // Verify it's deleted + ctx, err := router.getSLOContextForRequest(request) + assert.Error(t, err) + assert.Nil(t, ctx) +} + +func TestSLOAwareRouter_PreRequest_NoSchedulingResult(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + + // Call PreRequest with nil scheduling result + router.PreRequest(ctx, request, nil) + + // Should not create SLO context + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_PreRequest_EmptySchedulingResult(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + + schedulingResult := &schedulingtypes.SchedulingResult{ + ProfileResults: map[string]*schedulingtypes.ProfileRunResult{}, + } + + // Call PreRequest with empty scheduling result + router.PreRequest(ctx, request, schedulingResult) + + // Should not create SLO context + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_PreRequest_Success(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod()) + + // Create and set initial SLO context + sloCtx := newSLORequestContext(request) + sloCtx.avgTPOTSLO = 50 + router.setSLOContextForRequest(request, sloCtx) + + // Initialize the request priority queue + router.runningRequestLists[pod.GetPod().NamespacedName] = newRequestPriorityQueue() + + beforeTime := time.Now() + router.PreRequest(ctx, request, schedulingResult) + afterTime := time.Now() + + // Verify SLO context was updated + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.Equal(t, pod.GetPod(), retrievedCtx.targetPod) + assert.Equal(t, schedulingResult, retrievedCtx.schedulingResult) + assert.True(t, retrievedCtx.requestReceivedTimestamp.After(beforeTime) || + retrievedCtx.requestReceivedTimestamp.Equal(beforeTime)) + assert.True(t, retrievedCtx.requestReceivedTimestamp.Before(afterTime) || + retrievedCtx.requestReceivedTimestamp.Equal(afterTime)) +} + +func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod()) + + // Create and set initial SLO context + sloCtx := newSLORequestContext(request) + sloCtx.avgTPOTSLO = 
50 + router.setSLOContextForRequest(request, sloCtx) + + // PreRequest should create the queue + router.PreRequest(ctx, request, schedulingResult) + + // Verify queue was created and request was added + queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + assert.True(t, exists, "Queue should be created for pod") + assert.NotNil(t, queue) +} + +func TestSLOAwareRouter_PreRequest_QueueAlreadyExists(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request1 := createTestLLMRequest("test-id-1", 100, 50, true) + request2 := createTestLLMRequest("test-id-2", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod()) + + // Create and set initial SLO contexts + sloCtx1 := newSLORequestContext(request1) + sloCtx1.avgTPOTSLO = 50 + router.setSLOContextForRequest(request1, sloCtx1) + + sloCtx2 := newSLORequestContext(request2) + sloCtx2.avgTPOTSLO = 50 + router.setSLOContextForRequest(request2, sloCtx2) + + // Add first request + router.PreRequest(ctx, request1, schedulingResult) + + // Add second request to same pod + router.PreRequest(ctx, request2, schedulingResult) + + // Verify both are in the same queue + queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + assert.True(t, exists) + assert.NotNil(t, queue) +} + +func TestSLOAwareRouter_ResponseReceived_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := newSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic and should return early + router.ResponseReceived(ctx, request, response, pod.GetPod()) + + // Context should still exist + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} + +func TestSLOAwareRouter_ResponseReceived_NoPod(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := newSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic with nil pod + router.ResponseReceived(ctx, request, response, nil) + + // Predictor should not be called + +} + +func TestSLOAwareRouter_ResponseReceived_NoContext(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Don't set SLO context + router.ResponseReceived(ctx, request, response, pod.GetPod()) + + // Should handle missing context gracefully + +} + +func TestSLOAwareRouter_ResponseStreaming_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := newSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic and should 
return early + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // Context should still exist + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} +func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + schedulingResult := createTestSchedulingResult(pod.GetPod()) + + sloCtx := newSLORequestContext(request) + sloCtx.requestReceivedTimestamp = time.Now() + sloCtx.schedulingResult = schedulingResult + sloCtx.schedulingRequest = *request + sloCtx.ttftSLO = 100 + sloCtx.avgTPOTSLO = 50 + sloCtx.incomingModelName = testModelName + sloCtx.predictedTTFT = 80.0 + sloCtx.avgPredictedTPOT = 30.0 + // ADD THIS - populate metrics + sloCtx.lastSeenMetrics["prefill"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + sloCtx.lastSeenMetrics["default"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + router.setSLOContextForRequest(request, sloCtx) + + // Initialize the queue and add the request + queue := newRequestPriorityQueue() + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + + beforeTime := time.Now() + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + afterTime := time.Now() + + // Verify first token timestamp was set + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.True(t, retrievedCtx.lastTokenTimestamp.After(beforeTime) || + retrievedCtx.lastTokenTimestamp.Equal(beforeTime)) + assert.True(t, retrievedCtx.lastTokenTimestamp.Before(afterTime) || + retrievedCtx.lastTokenTimestamp.Equal(afterTime)) +} + +func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + schedulingResult := createTestSchedulingResult(pod.GetPod()) + + sloCtx := newSLORequestContext(request) + sloCtx.requestReceivedTimestamp = time.Now() + sloCtx.schedulingResult = schedulingResult + sloCtx.schedulingRequest = *request + sloCtx.ttftSLO = 100 + sloCtx.avgTPOTSLO = 50 + sloCtx.incomingModelName = testModelName + sloCtx.predictedTTFT = 80.0 + sloCtx.avgPredictedTPOT = 30.0 + // ADD THIS - populate metrics + sloCtx.lastSeenMetrics["prefill"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + sloCtx.lastSeenMetrics["default"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + firstTokenTime := time.Now().Add(-100 * time.Millisecond) + + router.setSLOContextForRequest(request, sloCtx) + + // Initialize the queue and add the request + queue := newRequestPriorityQueue() + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // Verify token timestamp was updated + 
retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.True(t, retrievedCtx.lastTokenTimestamp.After(firstTokenTime)) +} + +func TestSLOAwareRouter_ResponseComplete_QueueNotFound(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := newSLORequestContext(request) + sloCtx.incomingModelName = testModelName + sloCtx.targetPod = pod.GetPod() // ADD THIS to avoid other issues + router.setSLOContextForRequest(request, sloCtx) + + // Create an EMPTY queue (not nil, but empty) to test queue.Remove behavior + router.runningRequestLists[pod.GetPod().NamespacedName] = newRequestPriorityQueue() + + // Should handle gracefully when request is not in queue + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Context should be deleted + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} +func TestSLOAwareRouter_ResponseStreaming_NoContext(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Don't set SLO context - should handle gracefully + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // Should not panic + +} + +func TestSLOAwareRouter_ResponseComplete_Success(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Create queue and add request + queue := newRequestPriorityQueue() + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) + + sloCtx := newSLORequestContext(request) + sloCtx.ttft = 80 + sloCtx.avgTPOT = 30 + sloCtx.predictedTTFT = 85 + sloCtx.avgPredictedTPOT = 32 + sloCtx.ttftSLO = 100 + sloCtx.avgTPOTSLO = 50 + sloCtx.incomingModelName = "incoming-model" + router.setSLOContextForRequest(request, sloCtx) + + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Verify context was deleted + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) + + // Verify request was removed from queue + assert.Equal(t, 0, queue.Len()) +} + +func TestSLOAwareRouter_ResponseComplete_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := newSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Context should still exist (deletion happens only with predictor) + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} + +func TestSLOAwareRouter_ResponseComplete_NoPod(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor 
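For orientation while reading these per-hook tests: the hooks are meant to fire in a fixed order over a request's life. The sketch below is a hypothetical in-package driver (the function name and the direct call sequence are mine; in the real system the requestcontrol layer invokes the hooks, and the scorer must already have created the SLO context before PreRequest runs):

```go
// Hypothetical driver showing the intended hook order.
func driveHooks(ctx context.Context, r *SLOAwareRouter,
	req *schedulingtypes.LLMRequest, result *schedulingtypes.SchedulingResult,
	resp *requestcontrol.Response, pod *backend.Pod) {
	r.PreRequest(ctx, req, result)           // registers the request in the target pod's priority queue
	r.ResponseReceived(ctx, req, resp, pod)  // response headers arrived; header-time prediction
	r.ResponseStreaming(ctx, req, resp, pod) // first call records TTFT, later calls feed TPOT sampling
	r.ResponseComplete(ctx, req, resp, pod)  // records metrics, dequeues the request, drops the SLO context
}
```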
+ + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := newSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic with nil pod + router.ResponseComplete(ctx, request, response, nil) + + // Context should still exist (deletion happens only with validpod.GetPod()) + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} + +func TestSLOAwareRouter_ResponseComplete_NoContext(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Don't set SLO context - should handle gracefully + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Should not panic + +} + +func TestSLOAwareRouter_ResponseComplete_WithMetrics(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Create queue + queue := newRequestPriorityQueue() + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) + + sloCtx := newSLORequestContext(request) + sloCtx.ttft = 80 + sloCtx.avgTPOT = 30 + sloCtx.predictedTTFT = 85 + sloCtx.avgPredictedTPOT = 32 + sloCtx.ttftSLO = 100 + sloCtx.avgTPOTSLO = 50 + sloCtx.incomingModelName = "incoming-model" + router.setSLOContextForRequest(request, sloCtx) + + // Should record metrics without panicking + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Verify cleanup + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_ResponseComplete_NoSLOs(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test-id", 0, 0, true) // No SLOs + response := &requestcontrol.Response{} + + // Create queue + queue := newRequestPriorityQueue() + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 0) + + sloCtx := newSLORequestContext(request) + sloCtx.ttft = 80 + sloCtx.avgTPOT = 30 + sloCtx.incomingModelName = testModelName + router.setSLOContextForRequest(request, sloCtx) + + // Should handle missing SLOs gracefully + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Verify cleanup + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_CheckPredictor_NilPod(t *testing.T) { + router := createTestRouter() + logger := logr.Discard() + + result := router.checkPredictor(logger, nil) + + assert.False(t, result) +} + +func TestSLOAwareRouter_CheckPredictor_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + logger := logr.Discard() + pod := createTestPod("test-pod", 1, 1, 1) + + result := router.checkPredictor(logger, pod.GetPod()) + + assert.False(t, result) +} + +func TestSLOAwareRouter_CheckPredictor_Success(t *testing.T) { + router := createTestRouter() + 
mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + logger := logr.Discard() + pod := createTestPod("test-pod", 1, 1, 1) + + result := router.checkPredictor(logger, pod.GetPod()) + + assert.True(t, result) +} + +func TestSLORequestContext_Fields(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := newSLORequestContext(request) + + // Test all field initialization + assert.NotNil(t, ctx.lastSeenMetrics) + assert.NotNil(t, ctx.prefixCacheScoresForPods) + assert.NotNil(t, ctx.predictedTTFTForScheduling) + assert.NotNil(t, ctx.predictedTPOTForScheduling) + assert.Empty(t, ctx.tpotObservations) + assert.Empty(t, ctx.predictedTPOTObservations) + assert.Zero(t, ctx.generatedTokenCount) + assert.Zero(t, ctx.ttft) + assert.Zero(t, ctx.avgTPOT) + assert.Nil(t, ctx.targetPod) + assert.Nil(t, ctx.schedulingResult) + assert.Nil(t, ctx.tokenSampler) +} + +func TestSLORequestContext_UpdateMetrics(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := newSLORequestContext(request) + + // Add some metrics + metricsState := &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 3, + } + ctx.lastSeenMetrics["test-pod"] = metricsState + + assert.Len(t, ctx.lastSeenMetrics, 1) + assert.Equal(t, 0.5, ctx.lastSeenMetrics["test-pod"].KVCacheUsagePercent) + assert.Equal(t, 3, ctx.lastSeenMetrics["test-pod"].WaitingQueueSize) +} + +func TestSLORequestContext_PredictionData(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := newSLORequestContext(request) + + // Set prediction data + ctx.predictedTTFTForScheduling["pod1"] = 80.0 + ctx.predictedTPOTForScheduling["pod1"] = 30.0 + ctx.predictedTTFTForScheduling["pod2"] = 90.0 + ctx.predictedTPOTForScheduling["pod2"] = 35.0 + + assert.Len(t, ctx.predictedTTFTForScheduling, 2) + assert.Len(t, ctx.predictedTPOTForScheduling, 2) + assert.Equal(t, 80.0, ctx.predictedTTFTForScheduling["pod1"]) + assert.Equal(t, 30.0, ctx.predictedTPOTForScheduling["pod1"]) +} + +func TestSLORequestContext_PrefixCacheScores(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := newSLORequestContext(request) + + // Set prefix cache scores + ctx.prefixCacheScoresForPods["pod1"] = 0.8 + ctx.prefixCacheScoresForPods["pod2"] = 0.6 + ctx.prefixCacheScoresForPods["pod3"] = 0.9 + + assert.Len(t, ctx.prefixCacheScoresForPods, 3) + assert.Equal(t, 0.8, ctx.prefixCacheScoresForPods["pod1"]) + assert.Equal(t, 0.9, ctx.prefixCacheScoresForPods["pod3"]) +} + +func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) { + router := createTestRouter() + + // Test concurrent access to context store + var wg sync.WaitGroup + numGoroutines := 100 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + requestID := uuid.New().String() + request := createTestLLMRequest(requestID, 100, 50, true) + sloCtx := newSLORequestContext(request) + + // Set context + router.setSLOContextForRequest(request, sloCtx) + + // Get context + retrievedCtx, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) + assert.NotNil(t, retrievedCtx) + + // Delete context + router.deleteSLOContextForRequest(request) + }() + } + + wg.Wait() +} + +func TestSLOAwareRouter_MultipleRequests_SamePod(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + + request1 := 
createTestLLMRequest("test-id-1", 100, 50, true) + request2 := createTestLLMRequest("test-id-2", 100, 50, true) + request3 := createTestLLMRequest("test-id-3", 100, 50, true) + + schedulingResult := createTestSchedulingResult(pod.GetPod()) + + // Create and set SLO contexts + for _, req := range []*schedulingtypes.LLMRequest{request1, request2, request3} { + sloCtx := newSLORequestContext(req) + sloCtx.avgTPOTSLO = 50 + router.setSLOContextForRequest(req, sloCtx) + } + + // Add all requests + router.PreRequest(ctx, request1, schedulingResult) + router.PreRequest(ctx, request2, schedulingResult) + router.PreRequest(ctx, request3, schedulingResult) + + // Verify queue has all requests + queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + assert.True(t, exists) + assert.NotNil(t, queue) +} + +func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + schedulingResult := createTestSchedulingResult(pod.GetPod()) + + // Create initial context + sloCtx := newSLORequestContext(request) + sloCtx.avgTPOTSLO = 50 + sloCtx.incomingModelName = testModelName + router.setSLOContextForRequest(request, sloCtx) + + // 1. PreRequest + router.PreRequest(ctx, request, schedulingResult) + + // Verify context exists + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.NotNil(t, retrievedCtx.targetPod) + + // 2. ResponseReceived + router.ResponseReceived(ctx, request, response, pod.GetPod()) + + // 3. ResponseStreaming (first token) + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // 4. ResponseStreaming (subsequent tokens) + retrievedCtx, _ = router.getSLOContextForRequest(request) + retrievedCtx.ttft = 100 // Mark first token received + router.setSLOContextForRequest(request, retrievedCtx) + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // 5. 
ResponseComplete + retrievedCtx, _ = router.getSLOContextForRequest(request) + retrievedCtx.ttft = 80 + retrievedCtx.avgTPOT = 30 + router.setSLOContextForRequest(request, retrievedCtx) + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Verify context was cleaned up + _, err = router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + + pod1 := createTestPod("test-pod-1", 1, 1, 1) + pod2 := createTestPod("test-pod-2", 1, 1, 1) + + request1 := createTestLLMRequest("test-id-1", 100, 50, true) + request2 := createTestLLMRequest("test-id-2", 100, 50, true) + + schedulingResult1 := createTestSchedulingResult(pod1.GetPod()) + schedulingResult2 := createTestSchedulingResult(pod2.GetPod()) + + // Create and set SLO contexts + sloCtx1 := newSLORequestContext(request1) + sloCtx1.avgTPOTSLO = 50 + router.setSLOContextForRequest(request1, sloCtx1) + + sloCtx2 := newSLORequestContext(request2) + sloCtx2.avgTPOTSLO = 50 + router.setSLOContextForRequest(request2, sloCtx2) + + // Add requests to different pods + router.PreRequest(ctx, request1, schedulingResult1) + router.PreRequest(ctx, request2, schedulingResult2) + + // Verify separate queues were created + queue1, exists1 := router.runningRequestLists[pod1.GetPod().NamespacedName] + queue2, exists2 := router.runningRequestLists[pod2.GetPod().NamespacedName] + + assert.True(t, exists1) + assert.True(t, exists2) + assert.NotNil(t, queue1) + assert.NotNil(t, queue2) + assert.NotEqual(t, queue1, queue2) +} + +func TestSLORequestContext_SLOValidation(t *testing.T) { + tests := []struct { + name string + ttftSLO float64 + tpotSLO float64 + expectSLOs bool + }{ + { + name: "Both SLOs set", + ttftSLO: 100, + tpotSLO: 50, + expectSLOs: true, + }, + { + name: "No SLOs", + ttftSLO: 0, + tpotSLO: 0, + expectSLOs: false, + }, + { + name: "Only TTFT SLO", + ttftSLO: 100, + tpotSLO: 0, + expectSLOs: false, + }, + { + name: "Only TPOT SLO", + ttftSLO: 0, + tpotSLO: 50, + expectSLOs: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := createTestLLMRequest("test-id", tt.ttftSLO, tt.tpotSLO, true) + ctx := newSLORequestContext(request) + ctx.ttftSLO = tt.ttftSLO + ctx.avgTPOTSLO = tt.tpotSLO + + hasBothSLOs := ctx.ttftSLO > 0 && ctx.avgTPOTSLO > 0 + assert.Equal(t, tt.expectSLOs, hasBothSLOs) + }) + } +} + +// Benchmark tests + +func BenchmarkSLOAwareRouter_PreRequest(b *testing.B) { + router := createTestRouter() + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + requestID := uuid.New().String() + request := createTestLLMRequest(requestID, 100, 50, true) + sloCtx := newSLORequestContext(request) + sloCtx.avgTPOTSLO = 50 + router.setSLOContextForRequest(request, sloCtx) + router.PreRequest(ctx, request, schedulingResult) + } +} + +func BenchmarkSLOAwareRouter_ContextOperations(b *testing.B) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := newSLORequestContext(request) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + router.setSLOContextForRequest(request, sloCtx) + _, _ = router.getSLOContextForRequest(request) + router.deleteSLOContextForRequest(request) + } +} + +func 
BenchmarkSLORequestContext_Creation(b *testing.B) { + request := createTestLLMRequest("test", 100, 50, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = newSLORequestContext(request) + } +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go new file mode 100644 index 0000000000..37017fbdcb --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go @@ -0,0 +1,243 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "container/heap" + "fmt" + "sort" + "strings" + "sync" +) + +// request represents an element in the priority queue. +// The index is needed by heap.Remove and is maintained by the heap.Interface methods. +type request struct { + id string // Unique identifier + tpot float64 // The priority value (lower is higher priority) + index int +} + +// requestPriorityQueue implements a priority queue with item removal by ID. +type requestPriorityQueue struct { + items []*request + lookup map[string]*request + mutex sync.RWMutex +} + +// newRequestPriorityQueue initializes and returns a new PriorityQueue. +func newRequestPriorityQueue() *requestPriorityQueue { + return &requestPriorityQueue{ + lookup: make(map[string]*request), + items: []*request{}, + } +} + +// Clone creates a deep copy of the priority queue. +// The new queue is completely independent of the original. +func (pq *requestPriorityQueue) Clone() *requestPriorityQueue { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + + // Initialize a new priority queue with pre-allocated capacity. + clonedPq := &requestPriorityQueue{ + items: make([]*request, len(pq.items)), + lookup: make(map[string]*request, len(pq.lookup)), + } + + // Iterate through the original items to create deep copies. + for i, oldItem := range pq.items { + // Create a new Request struct, copying all values. + newItem := &request{ + id: oldItem.id, + tpot: oldItem.tpot, + index: oldItem.index, + } + + // Assign the new item to the cloned queue's items slice. + clonedPq.items[i] = newItem + // Update the lookup map in the cloned queue to point to the new item. + clonedPq.lookup[newItem.id] = newItem + } + + return clonedPq +} + +// Len is the number of items in the queue. +func (pq *requestPriorityQueue) Len() int { return len(pq.items) } + +// Less reports whether the item with index i should sort before the item with index j. +func (pq *requestPriorityQueue) Less(i, j int) bool { + return pq.items[i].tpot < pq.items[j].tpot +} + +// Swap swaps the items with indexes i and j. +func (pq *requestPriorityQueue) Swap(i, j int) { + pq.items[i], pq.items[j] = pq.items[j], pq.items[i] + pq.items[i].index = i + pq.items[j].index = j +} + +// Push adds an item to the heap. 
+func (pq *requestPriorityQueue) Push(x any) { + item := x.(*request) + item.index = len(pq.items) + pq.items = append(pq.items, item) +} + +// Pop removes and returns the minimum item from the heap. +func (pq *requestPriorityQueue) Pop() any { + n := len(pq.items) + item := pq.items[n-1] + pq.items[n-1] = nil // avoid memory leak + item.index = -1 // for safety + pq.items = pq.items[0 : n-1] + return item +} + +// Add adds a new item to the queue. +// Returns true if the item was added, false if an item with the same ID already exists. +func (pq *requestPriorityQueue) Add(id string, tpot float64) bool { + pq.mutex.Lock() + defer pq.mutex.Unlock() + + // Validate input + if id == "" { + return false + } + if tpot < 0 { + return false + } + + // If item already exists, do not add + if _, exists := pq.lookup[id]; exists { + return false + } + + item := &request{ + id: id, + tpot: tpot, + } + pq.lookup[id] = item + heap.Push(pq, item) + return true +} + +// Update modifies the TPOT value of an existing item in the queue. +// If the item doesn't exist, this method does nothing. +func (pq *requestPriorityQueue) Update(id string, tpot float64) bool { + pq.mutex.Lock() + defer pq.mutex.Unlock() + + // Validate input + if tpot < 0 { + return false + } + + item, exists := pq.lookup[id] + if !exists { + return false + } + + item.tpot = tpot + heap.Fix(pq, item.index) + return true +} + +// Remove removes an item from the queue by its ID. +func (pq *requestPriorityQueue) Remove(id string) (*request, bool) { + pq.mutex.Lock() + defer pq.mutex.Unlock() + + item, ok := pq.lookup[id] + if !ok { + return nil, false + } + removed := heap.Remove(pq, item.index).(*request) + delete(pq.lookup, id) + return removed, true +} + +// Peek returns the item with the lowest value without removing it. +func (pq *requestPriorityQueue) Peek() *request { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + + if len(pq.items) == 0 { + return nil + } + return pq.items[0] +} + +// GetSize returns the current number of items in the queue. +func (pq *requestPriorityQueue) GetSize() int { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + return len(pq.items) +} + +// Contains checks if an item with the given ID exists in the queue. +func (pq *requestPriorityQueue) Contains(id string) bool { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + _, exists := pq.lookup[id] + return exists +} + +// ToSlice returns a copy of all items in the queue, sorted by ID for stable comparison. +// This is primarily intended for testing and validation. +func (pq *requestPriorityQueue) ToSlice() []*request { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + + // Create a copy to avoid returning a reference to the internal slice. + itemsCopy := make([]*request, len(pq.items)) + copy(itemsCopy, pq.items) + + // Sort by ID to have a deterministic order for comparison in tests. + sort.Slice(itemsCopy, func(i, j int) bool { + return itemsCopy[i].id < itemsCopy[j].id + }) + + return itemsCopy +} + +// String returns a string representation of the queue for debugging. 
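Taken together, the locked methods above form a by-ID priority queue ordered by TPOT target (lowest first). A minimal in-package usage sketch follows; exampleQueueUsage is a made-up name, and callers should stay on Add/Update/Remove rather than driving heap.Push or heap.Pop directly, since only the named methods take the mutex and keep the id lookup map consistent:

```go
// Illustrative only: composing the queue API defined in this file.
func exampleQueueUsage() string {
	pq := newRequestPriorityQueue()
	pq.Add("req-a", 0.040)    // e.g. a 40 ms/token TPOT target
	pq.Add("req-b", 0.020)    // tighter target sorts ahead of req-a
	pq.Update("req-a", 0.010) // heap.Fix re-sorts; req-a moves to the head
	head := pq.Peek()
	pq.Remove("req-b")
	return fmt.Sprintf("head=%s tpot=%.3f size=%d", head.id, head.tpot, pq.GetSize())
	// head=req-a tpot=0.010 size=1
}
```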
+func (pq *requestPriorityQueue) String() string { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + + if len(pq.items) == 0 { + return "RequestPriorityQueue: []" + } + + var builder strings.Builder + builder.WriteString("RequestPriorityQueue: [") + + for i, item := range pq.items { + if i > 0 { + builder.WriteString(", ") + } + builder.WriteString(item.id) + builder.WriteString("(") + builder.WriteString(fmt.Sprintf("%.2f", item.tpot)) + builder.WriteString(")") + } + + builder.WriteString("]") + return builder.String() +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go new file mode 100644 index 0000000000..ef34c84b50 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go @@ -0,0 +1,391 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func TestNewRequestPriorityQueue(t *testing.T) { + pq := newRequestPriorityQueue() + + if pq == nil { + t.Fatal("NewRequestPriorityQueue returned nil") + } + + if pq.GetSize() != 0 { + t.Errorf("Expected empty queue, got size %d", pq.GetSize()) + } + + if pq.Peek() != nil { + t.Error("Expected nil from Peek on empty queue") + } +} + +func TestAdd(t *testing.T) { + pq := newRequestPriorityQueue() + + // Test successful add + if !pq.Add("req1", 2.5) { + t.Error("Expected Add to return true for new item") + } + + if pq.GetSize() != 1 { + t.Errorf("Expected size 1, got %d", pq.GetSize()) + } + + // Test duplicate add + if pq.Add("req1", 3.0) { + t.Error("Expected Add to return false for duplicate ID") + } + + if pq.GetSize() != 1 { + t.Errorf("Expected size 1 after duplicate add, got %d", pq.GetSize()) + } + + // Test validation + if pq.Add("", 1.0) { + t.Error("Expected Add to return false for empty ID") + } + + if pq.Add("req2", -1.0) { + t.Error("Expected Add to return false for negative TPOT") + } +} + +func TestPriorityOrdering(t *testing.T) { + pq := newRequestPriorityQueue() + + // Add items with different priorities + pq.Add("high", 1.0) // highest priority (lowest TPOT) + pq.Add("medium", 5.0) // medium priority + pq.Add("low", 10.0) // lowest priority (highest TPOT) + + // Check that highest priority item is at the top + peek := pq.Peek() + if peek == nil || peek.id != "high" || peek.tpot != 1.0 { + t.Errorf("Expected high priority item at top, got %+v", peek) + } + + // Test removal order + expected := []struct { + id string + tpot float64 + }{ + {"high", 1.0}, + {"medium", 5.0}, + {"low", 10.0}, + } + + for _, exp := range expected { + item := pq.Peek() + if item.id != exp.id || item.tpot != exp.tpot { + t.Errorf("Expected %s(%.1f), got %s(%.1f)", exp.id, exp.tpot, item.id, item.tpot) + } + + removed, ok := pq.Remove(item.id) + if !ok || removed.id != exp.id { + t.Errorf("Failed to remove %s", exp.id) + } + } +} + +func TestRemove(t *testing.T) 
{ + pq := newRequestPriorityQueue() + + // Test remove from empty queue + if _, ok := pq.Remove("nonexistent"); ok { + t.Error("Expected Remove to return false for empty queue") + } + + // Add some items + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Test successful remove + removed, ok := pq.Remove("req2") + if !ok || removed.id != "req2" || removed.tpot != 2.0 { + t.Errorf("Expected to remove req2(2.0), got %+v, ok=%v", removed, ok) + } + + if pq.GetSize() != 2 { + t.Errorf("Expected size 2 after removal, got %d", pq.GetSize()) + } + + // Test remove nonexistent + if _, ok := pq.Remove("req2"); ok { + t.Error("Expected Remove to return false for already removed item") + } + + // Verify remaining items are still in correct order + if peek := pq.Peek(); peek.id != "req1" { + t.Errorf("Expected req1 at top, got %s", peek.id) + } +} + +func TestUpdate(t *testing.T) { + pq := newRequestPriorityQueue() + + // Test update nonexistent item + if pq.Update("nonexistent", 1.0) { + t.Error("Expected Update to return false for nonexistent item") + } + + // Add items + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Update to make req3 highest priority + if !pq.Update("req3", 0.5) { + t.Error("Expected Update to return true for existing item") + } + + // Check that req3 is now at the top + if peek := pq.Peek(); peek.id != "req3" || peek.tpot != 0.5 { + t.Errorf("Expected req3(0.5) at top, got %s(%.1f)", peek.id, peek.tpot) + } + + // Test validation + if pq.Update("req1", -1.0) { + t.Error("Expected Update to return false for negative TPOT") + } +} + +func TestContains(t *testing.T) { + pq := newRequestPriorityQueue() + + // Test empty queue + if pq.Contains("req1") { + t.Error("Expected Contains to return false for empty queue") + } + + // Add item + pq.Add("req1", 1.0) + + // Test existing item + if !pq.Contains("req1") { + t.Error("Expected Contains to return true for existing item") + } + + // Test nonexistent item + if pq.Contains("req2") { + t.Error("Expected Contains to return false for nonexistent item") + } + + // Test after removal + pq.Remove("req1") + if pq.Contains("req1") { + t.Error("Expected Contains to return false after removal") + } +} + +func TestClone(t *testing.T) { + pq := newRequestPriorityQueue() + + // Test clone of empty queue + clone := pq.Clone() + if clone.GetSize() != 0 { + t.Error("Expected cloned empty queue to be empty") + } + + // Add items to original + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Clone with items + clone = pq.Clone() + + // Verify clone has same items + if clone.GetSize() != pq.GetSize() { + t.Errorf("Expected clone size %d, got %d", pq.GetSize(), clone.GetSize()) + } + + // Verify independence - modify original + pq.Add("req4", 4.0) + if clone.GetSize() == pq.GetSize() { + t.Error("Clone should be independent of original") + } + + // Verify independence - modify clone + clone.Remove("req1") + if !pq.Contains("req1") { + t.Error("Original should not be affected by clone modifications") + } + + // Verify deep copy - items should be different instances + origPeek := pq.Peek() + clonePeek := clone.Peek() + if origPeek == clonePeek { + t.Error("Clone should create new Request instances, not share pointers") + } +} + +func TestString(t *testing.T) { + pq := newRequestPriorityQueue() + + // Test empty queue + str := pq.String() + expected := "RequestPriorityQueue: []" + if str != expected { + t.Errorf("Expected %q, got %q", expected, str) + } + + // Test with items + 
pq.Add("req1", 1.5) + pq.Add("req2", 2.25) + + str = pq.String() + // Should contain both items in priority order + if !contains(str, "req1(1.50)") || !contains(str, "req2(2.25)") { + t.Errorf("String output missing expected items: %s", str) + } +} + +func TestConcurrency(t *testing.T) { + pq := newRequestPriorityQueue() + const numWorkers = 10 + const itemsPerWorker = 100 + + var wg sync.WaitGroup + + // Launch workers that add items + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func(workerID int) { + defer wg.Done() + for j := 0; j < itemsPerWorker; j++ { + id := fmt.Sprintf("worker%d-item%d", workerID, j) + tpot := float64(j) + float64(workerID)*0.1 + pq.Add(id, tpot) + } + }(i) + } + + // Launch workers that read from the queue + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < itemsPerWorker/2; j++ { + pq.Peek() + pq.GetSize() + time.Sleep(time.Microsecond) + } + }() + } + + wg.Wait() + + // Verify final state + expectedSize := numWorkers * itemsPerWorker + if pq.GetSize() != expectedSize { + t.Errorf("Expected final size %d, got %d", expectedSize, pq.GetSize()) + } +} + +func TestLargeQueue(t *testing.T) { + pq := newRequestPriorityQueue() + const numItems = 10000 + + // Add many items + for i := 0; i < numItems; i++ { + id := fmt.Sprintf("item%d", i) + tpot := float64(numItems - i) // Reverse order so item0 has highest priority + pq.Add(id, tpot) + } + + if pq.GetSize() != numItems { + t.Errorf("Expected size %d, got %d", numItems, pq.GetSize()) + } + + // Verify priority ordering by removing items + lastTPOT := -1.0 + for i := 0; i < numItems; i++ { + item := pq.Peek() + if item.tpot < lastTPOT { + t.Errorf("Priority order violated: %.1f < %.1f", item.tpot, lastTPOT) + } + lastTPOT = item.tpot + pq.Remove(item.id) + } + + if pq.GetSize() != 0 { + t.Errorf("Expected empty queue after removing all items, got size %d", pq.GetSize()) + } +} + +func BenchmarkAdd(b *testing.B) { + pq := newRequestPriorityQueue() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := fmt.Sprintf("item%d", i) + pq.Add(id, float64(i)) + } +} + +func BenchmarkPeek(b *testing.B) { + pq := newRequestPriorityQueue() + + // Pre-populate queue + for i := 0; i < 1000; i++ { + pq.Add(fmt.Sprintf("item%d", i), float64(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pq.Peek() + } +} + +func BenchmarkRemove(b *testing.B) { + pq := newRequestPriorityQueue() + + // Pre-populate queue + for i := 0; i < b.N; i++ { + pq.Add(fmt.Sprintf("item%d", i), float64(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pq.Remove(fmt.Sprintf("item%d", i)) + } +} + +// Helper function to check if a string contains a substring +func contains(s, substr string) bool { + return len(s) >= len(substr) && + (s == substr || + s[:len(substr)] == substr || + s[len(s)-len(substr):] == substr || + containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go new file mode 100644 index 0000000000..13d2543a60 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go @@ -0,0 +1,111 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "hash/fnv" + "math" + "math/rand" + "time" +) + +// tokenSampler handles Poisson-distributed sampling for predictions only +// Training happens on every token regardless of sampling +type tokenSampler struct { + rng *rand.Rand + nextSampleToken int + samplingMean float64 + maxSamples int + sampleCount int +} + +func newTokenSampler(requestID string, samplingMean float64, maxSamples int) *tokenSampler { + // Use request ID hash as seed for reproducibility + seed := int64(0) + if requestID != "" { + hash := fnv.New64a() + hash.Write([]byte(requestID)) + seed = int64(hash.Sum64()) + } + if seed == 0 { + seed = time.Now().UnixNano() + } + + sampler := &tokenSampler{ + rng: rand.New(rand.NewSource(seed)), + samplingMean: samplingMean, + maxSamples: maxSamples, + } + + // Set first sample token (skip token 1 since that's TTFT) + sampler.nextSampleToken = 2 + sampler.poissonNext() + + return sampler +} + +// poissonNext generates the next interval using Poisson distribution +func (ts *tokenSampler) poissonNext() int { + lambda := ts.samplingMean + if lambda <= 0 { + return 1 + } + + // For small lambda, use Knuth's algorithm + if lambda < 30 { + l := math.Exp(-lambda) + k := 0 + p := 1.0 + + for p > l { + k++ + p *= ts.rng.Float64() + } + return k - 1 + } + + // For larger lambda, use normal approximation + normal := ts.rng.NormFloat64() + interval := int(math.Round(lambda + math.Sqrt(lambda)*normal)) + if interval < 1 { + return 1 + } + return interval +} + +// shouldPredict determines if we should make a prediction for the current token +func (ts *tokenSampler) shouldPredict(currentToken int) bool { + return currentToken == ts.nextSampleToken && ts.sampleCount < ts.maxSamples +} + +// recordPrediction records that a prediction was made and calculates the next sample token +func (ts *tokenSampler) recordPrediction(currentToken int) { + if ts.sampleCount >= ts.maxSamples { + return + } + + ts.sampleCount++ + + if ts.sampleCount < ts.maxSamples { + interval := ts.poissonNext() + ts.nextSampleToken = currentToken + interval + } +} + +// getNextSampleToken returns the next token to predict for +func (ts *tokenSampler) getNextSampleToken() int { + return ts.nextSampleToken +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go new file mode 100644 index 0000000000..8d05418e43 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go @@ -0,0 +1,264 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
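A note on the sampler defined just above: poissonNext draws prediction intervals from a Poisson distribution, using Knuth's multiply-uniforms method for small lambda and a normal approximation for large lambda. The standalone sketch below (package, names, and the empirical check are mine, not part of the patch) reproduces the Knuth branch so you can verify that the drawn intervals average out to the configured samplingMean:

```go
package main

import (
	"fmt"
	"math"
	"math/rand"
)

// poissonKnuth multiplies U(0,1) draws until the product drops below e^-lambda;
// the number of multiplications minus one is Poisson(lambda) distributed.
func poissonKnuth(rng *rand.Rand, lambda float64) int {
	l := math.Exp(-lambda)
	k, p := 0, 1.0
	for p > l {
		k++
		p *= rng.Float64()
	}
	return k - 1
}

func main() {
	rng := rand.New(rand.NewSource(42))
	const lambda, n = 8.0, 100000
	sum := 0
	for i := 0; i < n; i++ {
		sum += poissonKnuth(rng, lambda)
	}
	fmt.Printf("empirical mean %.2f vs lambda %.1f\n", float64(sum)/float64(n), lambda)
}
```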
+*/ + +package slo_aware_router + +import ( + "context" + "fmt" + "math/rand" + "sync" + "time" + + "sigs.k8s.io/controller-runtime/pkg/log" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +type SLOAwareRouter struct { + tn plugins.TypedName + latencypredictor latencypredictor.PredictorInterface + runningRequestLists map[types.NamespacedName]*requestPriorityQueue + sloContextStore sync.Map // map[string]*SLORequestContext + headroomStrategy headroomStrategy +} + +var _ framework.Scorer = &SLOAwareRouter{} + +func NewSLOAwareRouter(latencypredictor latencypredictor.PredictorInterface, strategy headroomStrategy) *SLOAwareRouter { + return &SLOAwareRouter{ + tn: plugins.TypedName{Type: SLOAwareRouterPluginType, Name: SLOAwareRouterPluginType}, + latencypredictor: latencypredictor, + runningRequestLists: make(map[types.NamespacedName]*requestPriorityQueue), + sloContextStore: sync.Map{}, + headroomStrategy: strategy, + } +} + +func (s *SLOAwareRouter) TypedName() plugins.TypedName { + return s.tn +} + +func (s *SLOAwareRouter) WithName(name string) *SLOAwareRouter { + s.tn.Name = name + return s +} + +func (s *SLOAwareRouter) epsilonGreedyAffinityGate( + ctx context.Context, + candidates []podPredictionResult, + r *rand.Rand, + label string, // e.g. "positive" or "negative" + prefixStickyThreshold float64, +) ([]podPredictionResult, bool) { + logger := log.FromContext(ctx) + if prefixStickyThreshold <= 0 { + // Affinity gating disabled + logger.V(logutil.DEBUG).Info("Affinity gating disabled (threshold <= 0)", "path", label) + return candidates, false + } + eligible := make([]podPredictionResult, 0, len(candidates)) + for _, p := range candidates { + if p.PrefixCacheScore >= prefixStickyThreshold { + eligible = append(eligible, p) + } + } + + // No eligible sticky pods? Explore (no gating). 
+ if len(eligible) == 0 { + return candidates, false + } + + // ε-exploration branch + if r.Float64() < EpsilonExploreSticky { + logger.V(logutil.DEBUG).Info("ε-greedy: exploring (ignoring affinity gate)", + "path", label, "epsilon", EpsilonExploreSticky, "eligibleCount", len(eligible)) + return candidates, false + } + + logger.V(logutil.DEBUG).Info("ε-greedy: exploiting (apply affinity gate)", + "path", label, "threshold", prefixStickyThreshold, "eligibleCount", len(eligible), "total", len(candidates)) + return eligible, true +} + +// scoreWithoutPredictions provides fallback scoring based only on prefix cache scores +// when latency predictions are unavailable +func (s *SLOAwareRouter) scoreWithoutPredictions( + ctx context.Context, + state *schedulingtypes.CycleState, + pods []schedulingtypes.Pod, + r *rand.Rand, +) map[schedulingtypes.Pod]float64 { + logger := log.FromContext(ctx) + logger.V(logutil.TRACE).Info("Using composite-only scoring without predictions") + + scores := make(map[schedulingtypes.Pod]float64, len(pods)) + for _, pod := range pods { + scores[pod] = 0 + } + + if len(pods) == 0 { + return scores + } + + // Build prediction results with only prefix cache scores + podResults := make([]podPredictionResult, 0, len(pods)) + for _, pod := range pods { + prefixScore := s.getPrefixCacheScoreForPod(ctx, state, pod) + podResults = append(podResults, podPredictionResult{ + Pod: pod, + PrefixCacheScore: prefixScore, + IsValid: true, // All pods are valid when we don't check predictions + }) + } + + // Select based on composite scores (prefix cache + other non-prediction metrics) + selectedPod := s.selectFromCompositeScores(ctx, podResults, r, headroomStrategyCompositeOnly) + + if selectedPod != nil { + scores[selectedPod] = 1 + logger.V(logutil.TRACE).Info("Selected pod using composite-only scoring", "pod", selectedPod.GetPod().String()) + } + + return scores +} + +func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) map[schedulingtypes.Pod]float64 { + logger := log.FromContext(ctx) + if s.latencypredictor == nil { + logger.V(logutil.DEBUG).Info("SLOAwareRouter: no predictor configured, returning nil scores") + return nil + } + + sloCtx := s.getOrMakeSLORequestContext(request) + + s.parseSLOHeaders(ctx, request, sloCtx) + + // Check if SLOs are provided + if !sloCtx.predictorBasedScheduling { + logger.V(logutil.DEBUG).Info("PredictorBasedScheduling turned off, skipping prediction-based filtering") + s.setSLOContextForRequest(request, sloCtx) + return nil + } + + // Initialize scores map with all pods having score 0 + scores := make(map[schedulingtypes.Pod]float64, len(pods)) + for _, pod := range pods { + scores[pod] = 0 + } + + source := rand.NewSource(time.Now().UnixNano()) + r := rand.New(source) + predictions, err := s.generatePredictions(ctx, state, request, sloCtx, pods) + if err != nil { + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Error generating predictions, falling back to composite-only scoring") + // Fall back to composite-only scoring using prefix cache scores + s.setSLOContextForRequest(request, sloCtx) + return s.scoreWithoutPredictions(ctx, state, pods, r) + } + s.updateRequestContextWithPredictions(sloCtx, predictions) + + allPreds := append([]podPredictionResult(nil), predictions...) 
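+	// Work on a copy of the predictions; the global affinity gate below may narrow
+	// the candidate set before headroom classification.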
+ allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal) + + // Check if all pods are invalid and all have running requests + allPodsInvalid := true + allPodsHaveRunningRequests := true + + for _, pred := range allPreds { + if pred.IsValid { + allPodsInvalid = false + } + + runningRequestCount := s.getPodRunningRequestCount(pred.Pod) + if runningRequestCount == 0 { + allPodsHaveRunningRequests = false + } + } + + // Set HasValidPod to false if all pods are invalid and all have running requests + if allPodsInvalid && allPodsHaveRunningRequests && !sticky { + sloCtx.hasValidPod = false + logger.V(logutil.DEBUG).Info("All pods are invalid and have running requests, setting HasValidPod to false") + } + + // 2) Tiered selection: positive headroom pods get 99% probability, negative get 1% + posHeadroomPods, negHeadroomPods := s.classifyPodsByHeadroom(allPreds) + + logger.V(logutil.DEBUG).Info("Pod headroom distribution", + "positivePods", len(posHeadroomPods), + "negativePods", len(negHeadroomPods)) + + selectedPod := s.selectPodBasedOnStrategy(ctx, r, allPreds, posHeadroomPods, negHeadroomPods) + + // Set score = 1 for selected pod, 0 for all others + if selectedPod != nil { + scores[selectedPod] = 1 + logger.V(logutil.DEBUG).Info("Selected pod for scheduling", "pod", selectedPod.GetPod().String()) + } + + s.setSLOContextForRequest(request, sloCtx) + + return scores +} + +func (t *SLOAwareRouter) getOrMakeSLORequestContext(request *schedulingtypes.LLMRequest) *sloRequestContext { + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + sloCtx = newSLORequestContext(request) + } + return sloCtx +} + +func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleState *schedulingtypes.CycleState, pod schedulingtypes.Pod) float64 { + log.FromContext(ctx).V(logutil.DEBUG).Info("Running getPrefixCacheScoreForPod, getting prefix cache score for pod", "pod", pod.GetPod().String()) + plugintype := prefix.PrefixCachePluginType + pluginname := prefix.PrefixCachePluginType + cycleStateKey := (plugins.TypedName{Type: plugintype, Name: pluginname}).String() + stateData, err := cycleState.Read(plugins.StateKey(cycleStateKey)) + + log.FromContext(ctx).V(logutil.DEBUG).Info("Reading prefix cache state from cycle state", "stateKey", cycleStateKey) + + if err != nil { + // The prefix cache plugin might not be enabled, which is a valid scenario. + log.FromContext(ctx).V(logutil.DEBUG).Info("prefix cache state not found in cycle state, returning prefix cache score of 0.0", "pod", pod.GetPod().String()) + return 0.0 + } + + prefixCacheState, ok := stateData.(*prefix.SchedulingContextState) + if !ok { + // This should not happen if the plugin is configured correctly. 
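+		// Log the unexpected state type and fall back to a neutral prefix score of 0.0
+		// rather than failing the scheduling cycle.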
+ log.FromContext(ctx).Error(fmt.Errorf("unexpected state type: %T", stateData), "failed to read prefix cache state") + return 0.0 + } + + total := len(prefixCacheState.PrefixHashes) + if total == 0 { + // if the request has no prefixes, return 0.0 + log.FromContext(ctx).V(logutil.DEBUG).Info("No prefixes found in request, returning prefix cache score of 0.0") + return 0.0 + } + + matchLen := prefixCacheState.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)] + log.FromContext(ctx).V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "matchLen", matchLen, "totalPrefixes", total) + return float64(matchLen) / float64(total) +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go new file mode 100644 index 0000000000..5b3fea8887 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go @@ -0,0 +1,102 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "fmt" + "math/rand" + + "sigs.k8s.io/controller-runtime/pkg/log" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +func (s *SLOAwareRouter) parseSLOHeaders(ctx context.Context, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext) { + logger := log.FromContext(ctx) + var err error + + // Get Request SLOs from request header + sloCtx.ttftSLO, err = parseFloatHeader(*request, ttftSLOHeaderKey) + if err != nil { + logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", ttftSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TTFT SLO from header") + } + + sloCtx.avgTPOTSLO, err = parseFloatHeader(*request, tpotSLOHeaderKey) + if err != nil { + logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", tpotSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header") + } + sloCtx.predictorBasedScheduling, err = parseBoolHeader(*request, "x-prediction-based-scheduling") + if err != nil { + logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-prediction-based-scheduling must be a bool: %v", err)}, "SLOAwareRouter: Error parsing PredictorBasedScheduling from header") + } +} + +func (s *SLOAwareRouter) classifyPodsByHeadroom(allPreds []podPredictionResult) (posHeadroomPods, negHeadroomPods []podPredictionResult) { + for _, p := range allPreds { + // A pod has positive headroom only if BOTH TTFT and TPOT have positive headroom + if p.Headroom > 0 && p.TTFTHeadroom > 0 { + posHeadroomPods = append(posHeadroomPods, p) + } else { + // A pod has negative headroom if EITHER TTFT or TPOT has negative/zero headroom + 
negHeadroomPods = append(negHeadroomPods, p) + } + } + return +} + +func (s *SLOAwareRouter) selectPodBasedOnStrategy( + ctx context.Context, + r *rand.Rand, + allPreds, posHeadroomPods, negHeadroomPods []podPredictionResult, +) schedulingtypes.Pod { + logger := log.FromContext(ctx) + var selectedPod schedulingtypes.Pod + + switch { + case s.headroomStrategy == headroomStrategyCompositeOnly: + logger.V(logutil.DEBUG).Info("Selecting from composite scores only") + selectedPod = s.selectFromCompositeScores(ctx, allPreds, r, headroomStrategyCompositeOnly) + case len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0: + // 99% chance to select from positive headroom pods, 1% from negative + if r.Float64() < EpsilonExploreNeg { + logger.V(logutil.DEBUG).Info("Selecting from negative headroom pods (1% chance)") + selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r) + } else { + logger.V(logutil.DEBUG).Info("Selecting from positive headroom pods (99% chance)") + selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) + } + case len(posHeadroomPods) > 0: + // If only positive headroom pods exist, select from them + logger.V(logutil.DEBUG).Info("Only positive headroom pods available") + selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) + case len(negHeadroomPods) > 0: + // If only negative headroom pods exist, select from them + logger.V(logutil.DEBUG).Info("Only negative headroom pods available") + selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r) + case len(allPreds) > 0: + // fallback - select randomly from valid pods + logger.V(logutil.DEBUG).Info("No headroom pods available, selecting randomly from valid pods") + selectedPod = allPreds[r.Intn(len(allPreds))].Pod + default: + // No valid pods - return nil (caller handles this) + logger.V(logutil.DEBUG).Info("No valid pods available") + return nil + } + return selectedPod +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go new file mode 100644 index 0000000000..a15cb29ac4 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go @@ -0,0 +1,516 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package slo_aware_router + +import ( + "context" + "errors" + "fmt" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +// mockPredictor implements PredictorInterface for testing +type mockPredictor struct { + predictions map[string]*latencypredictor.PredictionResponse + err error +} + +func (m *mockPredictor) Predict(ctx context.Context, request latencypredictor.PredictionRequest) (*latencypredictor.PredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + return pred, nil + } + // Default prediction + return &latencypredictor.PredictionResponse{TTFT: 0.5, TPOT: 0.03}, nil +} + +func (m *mockPredictor) PredictBulk(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + responses := make([]latencypredictor.PredictionResponse, 0, len(requests)) + for _, request := range requests { + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + responses = append(responses, *pred) + } else { + return nil, fmt.Errorf("no prediction for key %s", key) + } + } + return &latencypredictor.BulkPredictionResponse{Predictions: responses}, nil +} + +func (m *mockPredictor) PredictBulkStrict(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + responses := make([]latencypredictor.PredictionResponse, 0, len(requests)) + for _, request := range requests { + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + responses = append(responses, *pred) + } else { + return nil, fmt.Errorf("no prediction for key %s", key) + } + } + return &latencypredictor.BulkPredictionResponse{Predictions: responses}, nil +} + +func (m *mockPredictor) AddTrainingDataBulk(data []latencypredictor.TrainingEntry) error { + return nil +} + +func (m *mockPredictor) AddTrainingData(data latencypredictor.TrainingEntry) error { + return nil +} + +func (m *mockPredictor) HealthCheck() error { + return nil +} + +func (m *mockPredictor) GetServerStatus(ctx context.Context) (*latencypredictor.ServerStatusResponse, error) { + return &latencypredictor.ServerStatusResponse{}, nil +} + +func createTestPod(name string, kvCacheUsage float64, runningQueueSize, waitingQueueSize int) schedulingtypes.Pod { + return &schedulingtypes.PodMetrics{ + Pod: &backend.Pod{ + NamespacedName: types.NamespacedName{ + Name: name, + Namespace: "default", + }, + }, + MetricsState: &backendmetrics.MetricsState{ + KVCacheUsagePercent: kvCacheUsage, + RunningQueueSize: runningQueueSize, + 
WaitingQueueSize: waitingQueueSize, + }, + } +} + +func createTestLLMRequest(reqID string, ttftSLO, tpotSLO float64, predictionBased bool) *schedulingtypes.LLMRequest { + headers := make(map[string]string) + headers[requtil.RequestIdHeaderKey] = reqID + if ttftSLO > 0 { + headers["x-ttft-slo"] = fmt.Sprintf("%f", ttftSLO) + } + if tpotSLO > 0 { + headers["x-avg-tpot-slo"] = fmt.Sprintf("%f", tpotSLO) + } + headers["x-prediction-based-scheduling"] = strconv.FormatBool(predictionBased) + + return &schedulingtypes.LLMRequest{ + Headers: headers, + Body: &schedulingtypes.LLMRequestBody{ + Completions: &schedulingtypes.CompletionsRequest{ + Prompt: "test prompt", + }, + }, + } +} + +func TestSLOAwareRouter_Score(t *testing.T) { + tests := []struct { + name string + predictor *mockPredictor + strategy headroomStrategy + request *schedulingtypes.LLMRequest + pods []schedulingtypes.Pod + expectedScores map[string]float64 // Map of pod name to expected score + expectNil bool + }{ + { + name: "Prediction-based scheduling disabled", + predictor: &mockPredictor{}, + strategy: headroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, false), // predictionBased = false + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), // 50% KV cache, 2 running, 1 waiting + createTestPod("pod2", 0.7, 3, 2), // 70% KV cache, 3 running, 2 waiting + }, + expectNil: true, + }, + { + name: "No predictor configured", + predictor: nil, + strategy: headroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), + }, + expectNil: true, + }, + { + name: "All pods have positive headroom", + predictor: &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.5": {TTFT: 0.5, TPOT: 0.03}, // 50% KV cache + "0.6": {TTFT: 0.6, TPOT: 0.04}, // 60% KV cache + "0.3": {TTFT: 0.4, TPOT: 0.02}, // 30% KV cache + }, + }, + strategy: headroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), // 50% KV cache + createTestPod("pod2", 0.6, 3, 2), // 60% KV cache + createTestPod("pod3", 0.3, 1, 0), // 30% KV cache + }, + // One pod should be selected with score 1, others 0 + expectedScores: map[string]float64{ + // We can't predict which one due to randomness, but exactly one should be 1 + }, + }, + { + name: "All pods have negative headroom", + predictor: &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.8": {TTFT: 1.5, TPOT: 0.08}, // 80% KV cache - high load + "0.9": {TTFT: 1.8, TPOT: 0.09}, // 90% KV cache - very high load + }, + }, + strategy: headroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.8, 5, 3), // 80% KV cache, high load + createTestPod("pod2", 0.9, 6, 4), // 90% KV cache, very high load + }, + // One pod should still be selected even with negative headroom + expectedScores: map[string]float64{}, + }, + { + name: "Mixed positive and negative headroom", + predictor: &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.3": {TTFT: 0.5, TPOT: 0.03}, // 30% KV cache - Positive headroom + "0.9": {TTFT: 1.5, TPOT: 0.08}, // 90% KV cache - Negative headroom + }, + }, + strategy: headroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod-positive", 0.3, 1, 0), // Low KV cache, positive 
headroom + createTestPod("pod-negative", 0.9, 6, 4), // High KV cache, negative headroom + }, + // With 99% probability, positive headroom pod should be selected + expectedScores: map[string]float64{}, + }, + { + name: "Prediction errors - fallback to composite scoring", + predictor: &mockPredictor{ + err: errors.New("prediction failed"), + }, + strategy: headroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), + createTestPod("pod2", 0.6, 3, 2), + }, + // Should fall back to composite-only scoring and select one pod + expectedScores: map[string]float64{ + // One pod should be selected with score 1, verified in general validation below + }, + }, + { + name: "Empty pod list", + predictor: &mockPredictor{}, + strategy: headroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{}, + // Should return empty scores map + expectedScores: map[string]float64{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var router *SLOAwareRouter + if tt.predictor == nil { + router = NewSLOAwareRouter(nil, tt.strategy) + } else { + router = NewSLOAwareRouter(tt.predictor, tt.strategy) + } + + scores := router.Score(context.Background(), schedulingtypes.NewCycleState(), tt.request, tt.pods) + + if tt.expectNil { + assert.Nil(t, scores, "Expected nil scores") + return + } + + assert.NotNil(t, scores, "Expected non-nil scores") + + // If we have specific expected scores, verify them + if len(tt.expectedScores) > 0 { + for _, pod := range tt.pods { + podName := pod.GetPod().NamespacedName.Name + if expectedScore, ok := tt.expectedScores[podName]; ok { + assert.InDelta(t, expectedScore, scores[pod], 0.0001, "Pod %s should have score %f", podName, expectedScore) + } + } + } + + // General validation: exactly one pod should have score 1 (selected), others should have score 0 + // This applies even when predictions fail because we fall back to composite scoring + if !tt.expectNil && len(tt.pods) > 0 && tt.predictor != nil { + selectedCount := 0 + for _, score := range scores { + if score == 1.0 { + selectedCount++ + } else { + assert.InDelta(t, 0.0, score, 0.0001, "Non-selected pods should have score 0") + } + } + assert.Equal(t, 1, selectedCount, "Exactly one pod should be selected with score 1") + } + }) + } +} + +func TestSLOAwareRouter_Strategies(t *testing.T) { + tests := []struct { + name string + strategy headroomStrategy + }{ + { + name: "HeadroomStrategyLeast", + strategy: headroomStrategyLeast, + }, + { + name: "HeadroomStrategyMost", + strategy: headroomStrategyMost, + }, + { + name: "HeadroomStrategyCompositeMost", + strategy: headroomStrategyCompositeMost, + }, + { + name: "HeadroomStrategyCompositeLeast", + strategy: headroomStrategyCompositeLeast, + }, + { + name: "HeadroomStrategyCompositeOnly", + strategy: headroomStrategyCompositeOnly, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.5": {TTFT: 0.5, TPOT: 0.03}, + "0.6": {TTFT: 0.6, TPOT: 0.04}, + "0.3": {TTFT: 0.4, TPOT: 0.02}, + }, + } + router := NewSLOAwareRouter(predictor, tt.strategy) + + request := createTestLLMRequest("test", 1.0, 0.05, true) + pods := []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), + createTestPod("pod2", 0.6, 3, 2), + createTestPod("pod3", 0.3, 1, 0), + } + + scores := router.Score(context.Background(), 
schedulingtypes.NewCycleState(), request, pods) + + assert.NotNil(t, scores, "Expected non-nil scores for strategy %s", tt.strategy) + + // Verify exactly one pod is selected + selectedCount := 0 + for _, score := range scores { + if score == 1.0 { + selectedCount++ + } + } + assert.Equal(t, 1, selectedCount, "Strategy %s should select exactly one pod", tt.strategy) + }) + } +} + +func TestSLOAwareRouter_TypedName(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, headroomStrategyLeast) + + tn := router.TypedName() + assert.Equal(t, "slo-aware-routing", tn.Type, "Type should be slo-aware-routing") + assert.Equal(t, "slo-aware-routing", tn.Name, "Default name should be slo-aware-routing") +} + +func TestSLOAwareRouter_WithName(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, headroomStrategyLeast) + + customName := "custom-router" + router = router.WithName(customName) + + tn := router.TypedName() + assert.Equal(t, "slo-aware-routing", tn.Type, "Type should remain slo-aware-routing") + assert.Equal(t, customName, tn.Name, "Name should be updated to custom name") +} + +func TestSLOAwareRouter_GetPodRunningRequestCount(t *testing.T) { + tests := []struct { + name string + setupRequests func(*SLOAwareRouter, schedulingtypes.Pod) + expectedCount int + }{ + { + name: "No running requests", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {}, + expectedCount: 0, + }, + { + name: "One running request", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName := types.NamespacedName{ + Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = newRequestPriorityQueue() + r.runningRequestLists[podName].Add("req1", 0.04) + }, + expectedCount: 1, + }, + { + name: "Multiple running requests", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName := types.NamespacedName{ + Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = newRequestPriorityQueue() + r.runningRequestLists[podName].Add("req1", 0.04) + r.runningRequestLists[podName].Add("req2", 0.03) + r.runningRequestLists[podName].Add("req3", 0.05) + }, + expectedCount: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, headroomStrategyLeast) + pod := createTestPod("test-pod", 0.5, 2, 1) + + tt.setupRequests(router, pod) + + count := router.getPodRunningRequestCount(pod) + assert.Equal(t, tt.expectedCount, count, "Running request count should match expected") + }) + } +} + +func TestSLOAwareRouter_GetPodMinTPOTSLO(t *testing.T) { + tests := []struct { + name string + setupRequests func(*SLOAwareRouter, schedulingtypes.Pod) + expectedSLO float64 + }{ + { + name: "No running requests", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {}, + expectedSLO: 0.0, + }, + { + name: "One running request", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName := types.NamespacedName{ + Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = newRequestPriorityQueue() + r.runningRequestLists[podName].Add("req1", 0.04) + }, + expectedSLO: 0.04, + }, + { + name: "Multiple running requests - should return minimum", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName 
:= types.NamespacedName{ + Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = newRequestPriorityQueue() + // Add in any order - heap will maintain minimum at top + r.runningRequestLists[podName].Add("req1", 0.05) + r.runningRequestLists[podName].Add("req2", 0.03) // This is the minimum + r.runningRequestLists[podName].Add("req3", 0.04) + }, + expectedSLO: 0.03, // Minimum TPOT (heap guarantees this is at items[0]) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, headroomStrategyLeast) + pod := createTestPod("test-pod", 0.5, 2, 1) + + tt.setupRequests(router, pod) + + minSLO := router.getPodMinTPOTSLO(pod) + assert.InDelta(t, tt.expectedSLO, minSLO, 0.0001, "Min TPOT SLO should match expected") + }) + } +} + +func TestSLOAwareRouter_GetPrefixCacheScoreForPod(t *testing.T) { + tests := []struct { + name string + setupState func(*schedulingtypes.CycleState) + expectedScore float64 + }{ + { + name: "No prefix cache state", + setupState: func(s *schedulingtypes.CycleState) {}, + expectedScore: 0.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, headroomStrategyLeast) + + state := schedulingtypes.NewCycleState() + tt.setupState(state) + + pod := createTestPod("test-pod", 0.5, 2, 1) + + score := router.getPrefixCacheScoreForPod(context.Background(), state, pod) + assert.InDelta(t, tt.expectedScore, score, 0.0001, "Prefix cache score should match expected") + }) + } +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go new file mode 100644 index 0000000000..02a99cc6e8 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go @@ -0,0 +1,310 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package requestcontrol contains helpers to decouple latency-predictor logic. +package slo_aware_router + +import ( + "context" + "math" + "math/rand" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/log" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// selectFromPositiveHeadroomPods selects a pod from positive headroom pods using headroom strategy +// Updated to incorporate TTFTHeadroom with a configurable blend vs TPOT headroom. 
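+// Headrooms are normalized to [0,1] across the candidate set and blended as
+// HeadroomTTFTWeight*ttftHeadroom + HeadroomTPOTWeight*tpotHeadroom before being
+// mapped to integer selection weights (see calculateWeightedChoices).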
+func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []podPredictionResult, r *rand.Rand) schedulingtypes.Pod { + + if len(posHeadroomPods) == 1 { + return posHeadroomPods[0].Pod + } + + // Apply perfect stickiness (with exploration) + candidates, sticky := s.epsilonGreedyAffinityGate(ctx, posHeadroomPods, r, "positive", AffinityGateTau) + + // If perfect stickiness collapsed us to a single pod, short-circuit + if sticky && len(candidates) == 1 { + return candidates[0].Pod + } + switch s.headroomStrategy { + case headroomStrategyCompositeMost: + return s.selectFromCompositeScores(ctx, candidates, r, headroomStrategyCompositeMost) + case headroomStrategyCompositeLeast: + return s.selectFromCompositeScores(ctx, candidates, r, headroomStrategyCompositeLeast) + } + + // Find min/max for TPOT (Headroom) and TTFTHeadroom across positive pods to normalize to [0,1] + minTPOTH, maxTPOTH, minTTFTH, maxTTFTH := s.calculateHeadroomRanges(candidates) + + // Calculate weights for weighted random selection + weightedChoices, total := s.calculateWeightedChoices(ctx, candidates, minTPOTH, maxTPOTH, minTTFTH, maxTTFTH) + + return s.performWeightedRandomSelection(weightedChoices, total, candidates, r) +} + +// selectFromNegativeHeadroomPods selects a pod from negative headroom pods using hierarchical TTFT/TPOT logic +// Modified to strictly prefer pods with 0 running requests +func (s *SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, negHeadroomPods []podPredictionResult, r *rand.Rand) schedulingtypes.Pod { + logger := log.FromContext(ctx) + + if len(negHeadroomPods) == 1 { + return negHeadroomPods[0].Pod + } + + // First, separate pods by running request count + var zeroRunningRequestPods, nonZeroRunningRequestPods []podPredictionResult + + for _, p := range negHeadroomPods { + runningRequestCount := s.getPodRunningRequestCount(p.Pod) + if runningRequestCount == 0 { + zeroRunningRequestPods = append(zeroRunningRequestPods, p) + } else { + nonZeroRunningRequestPods = append(nonZeroRunningRequestPods, p) + } + } + + logger.V(logutil.DEBUG).Info("Negative headroom pods by running request count", + "zeroRunningRequests", len(zeroRunningRequestPods), + "nonZeroRunningRequests", len(nonZeroRunningRequestPods)) + + // If we have pods with 0 running requests, strictly prefer them + if len(zeroRunningRequestPods) > 0 { + logger.V(logutil.DEBUG).Info("Selecting from pods with zero running requests") + return s.selectFromNegativeHeadroomPodsInternal(ctx, zeroRunningRequestPods, r) + } + + // Otherwise, fall back to pods with running requests + logger.V(logutil.DEBUG).Info("No pods with zero running requests, selecting from pods with running requests") + return s.selectFromNegativeHeadroomPodsInternal(ctx, nonZeroRunningRequestPods, r) +} + +// selectFromNegativeHeadroomPodsInternal handles the actual selection logic for negative headroom pods +func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Context, negHeadroomPods []podPredictionResult, r *rand.Rand) schedulingtypes.Pod { + if len(negHeadroomPods) == 1 { + return negHeadroomPods[0].Pod + } + + // Apply perfect stickiness (with exploration) + candidates, sticky := s.epsilonGreedyAffinityGate(ctx, negHeadroomPods, r, "negative", AffinityGateTau) + + // If perfect stickiness collapsed us to a single pod, short-circuit + if sticky && len(candidates) == 1 { + return candidates[0].Pod + } + + switch s.headroomStrategy { + case headroomStrategyCompositeMost: + return 
s.selectFromCompositeScores(ctx, candidates, r, headroomStrategyCompositeMost)
+	case headroomStrategyCompositeLeast:
+		return s.selectFromCompositeScores(ctx, candidates, r, headroomStrategyCompositeLeast)
+	}
+
+	// Build weighted choices for selection
+	weightedChoices := make([]choice, 0, len(candidates))
+	total := 0
+
+	s.handleNegativeHeadroomPodsHierarchical(ctx, candidates, &weightedChoices, &total, minWeight)
+
+	// Perform weighted random selection
+	return s.performWeightedRandomSelection(weightedChoices, total, candidates, r)
+}
+
+// weightPodsByBlendedDeficit applies blended weighting using TTFT and TPOT deficits.
+// Lower blended deficit => higher weight.
+func (ps *SLOAwareRouter) weightPodsByBlendedDeficit(
+	ctx context.Context,
+	pods []podPredictionResult,
+	choices *[]choice,
+	total *int,
+	minWeight int,
+	alpha, beta float64, // weights for TTFT and TPOT deficits
+	category string,
+) {
+	logger := log.FromContext(ctx)
+	if len(pods) == 0 {
+		return
+	}
+
+	const Wrange = 80
+	const eps = 1e-9
+
+	// Compute raw deficits (only when headroom is negative)
+	type deficits struct {
+		pod     podPredictionResult
+		ttftDef float64
+		tpotDef float64
+	}
+	defs := make([]deficits, 0, len(pods))
+
+	minTTFT, maxTTFT := math.MaxFloat64, -math.MaxFloat64
+	minTPOT, maxTPOT := math.MaxFloat64, -math.MaxFloat64
+
+	for _, p := range pods {
+		ttftDef := 0.0
+		if p.TTFTHeadroom < 0 {
+			ttftDef = -p.TTFTHeadroom
+		}
+		tpotDef := 0.0
+		if p.Headroom < 0 {
+			tpotDef = -p.Headroom
+		}
+		defs = append(defs, deficits{pod: p, ttftDef: ttftDef, tpotDef: tpotDef})
+
+		if ttftDef < minTTFT {
+			minTTFT = ttftDef
+		}
+		if ttftDef > maxTTFT {
+			maxTTFT = ttftDef
+		}
+		if tpotDef < minTPOT {
+			minTPOT = tpotDef
+		}
+		if tpotDef > maxTPOT {
+			maxTPOT = tpotDef
+		}
+	}
+
+	ttftRange := maxTTFT - minTTFT
+	tpotRange := maxTPOT - minTPOT
+
+	// Normalize alpha/beta
+	if alpha+beta <= 0 {
+		alpha, beta = 1.0, 0.0
+	} else {
+		sum := alpha + beta
+		alpha /= sum
+		beta /= sum
+	}
+
+	logger.V(logutil.DEBUG).Info("Negative headroom blended deficits",
+		"category", category,
+		"minTTFTDef", minTTFT, "maxTTFTDef", maxTTFT,
+		"minTPOTDef", minTPOT, "maxTPOTDef", maxTPOT,
+		"alphaTTFT", alpha, "betaTPOT", beta, "podCount", len(pods))
+
+	for _, d := range defs {
+		// Normalize deficits to [0,1] within this bucket (0 = best / least violation)
+		nTTFT := 0.0
+		if ttftRange > eps {
+			nTTFT = (d.ttftDef - minTTFT) / (ttftRange + eps)
+		}
+		nTPOT := 0.0
+		if tpotRange > eps {
+			nTPOT = (d.tpotDef - minTPOT) / (tpotRange + eps)
+		}
+
+		// Blended "badness": higher = worse violation
+		blended := alpha*nTTFT + beta*nTPOT
+
+		// Convert to selection weight: lower badness -> higher weight
+		// Ensure a floor so no pod is completely excluded within the bucket.
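+		// (Example: with Wrange=80 and minWeight=1, blended=0 maps to weight 82 and blended=1 maps to weight 2.)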
+ w := int((1.0-blended)*float64(Wrange)) + minWeight + 1 + + *choices = append(*choices, choice{podName: d.pod.Pod, weight: w}) + *total += w + + logger.V(logutil.TRACE).Info("Negative bucket blended weighting", + "pod", d.pod.Pod.GetPod().String(), + "ttftDef", d.ttftDef, "tpotDef", d.tpotDef, + "normTTFT", nTTFT, "normTPOT", nTPOT, + "blendedBadness", blended, "weight", w) + } +} + +func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical( + ctx context.Context, + negHeadroomPods []podPredictionResult, + choices *[]choice, + total *int, + minWeightForNegative int, +) { + logger := log.FromContext(ctx) + + // Categorize pods by their headroom status + var negTTFTNegTPOT, negTTFTNonNegTPOT, nonNegTTFTNegTPOT, nonNegTTFTNonNegTPOT []podPredictionResult + + for _, p := range negHeadroomPods { + switch { + case p.TTFTHeadroom < 0 && p.Headroom < 0: + negTTFTNegTPOT = append(negTTFTNegTPOT, p) + case p.TTFTHeadroom < 0 && p.Headroom >= 0: + negTTFTNonNegTPOT = append(negTTFTNonNegTPOT, p) + case p.TTFTHeadroom >= 0 && p.Headroom < 0: + nonNegTTFTNegTPOT = append(nonNegTTFTNegTPOT, p) + default: + nonNegTTFTNonNegTPOT = append(nonNegTTFTNonNegTPOT, p) + } + } + + logger.V(logutil.DEBUG).Info("Hierarchical negative headroom pod distribution", + "totalNegative", len(negHeadroomPods), + "negTTFT_negTPOT", len(negTTFTNegTPOT), + "negTTFT_nonNegTPOT", len(negTTFTNonNegTPOT), + "nonNegTTFT_negTPOT", len(nonNegTTFTNegTPOT), + "nonNegTTFT_nonNegTPOT", len(nonNegTTFTNonNegTPOT)) + + // Priority 1: both TTFT and TPOT negative -> blended deficits (both active) + if len(negTTFTNegTPOT) > 0 { + s.weightPodsByBlendedDeficit(ctx, negTTFTNegTPOT, choices, total, minWeightForNegative, + NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "both_negative") + } + + // Priority 2: TTFT negative, TPOT non-negative -> blended still works (TPOT deficit=0) + if len(negTTFTNonNegTPOT) > 0 { + s.weightPodsByBlendedDeficit(ctx, negTTFTNonNegTPOT, choices, total, minWeightForNegative, + NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "ttft_negative") + } + + // Priority 3: TTFT non-negative, TPOT negative -> blended (TTFT deficit=0) + if len(nonNegTTFTNegTPOT) > 0 { + s.weightPodsByBlendedDeficit(ctx, nonNegTTFTNegTPOT, choices, total, minWeightForNegative, + NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "tpot_negative") + } + + // Priority 4: edge-case bucket -> minimal weight + for _, p := range nonNegTTFTNonNegTPOT { + *choices = append(*choices, choice{podName: p.Pod, weight: minWeightForNegative}) + *total += minWeightForNegative + } +} + +func (s *SLOAwareRouter) getPodMinTPOTSLO(pod schedulingtypes.Pod) float64 { + podName := types.NamespacedName{ + Name: pod.GetPod().NamespacedName.Name, + Namespace: pod.GetPod().NamespacedName.Namespace, + } + if runningReqs, ok := s.runningRequestLists[podName]; ok && runningReqs.GetSize() > 0 { + if topReq := runningReqs.Peek(); topReq != nil { + return topReq.tpot + } + } + return 0 // no running requests or no TPOT SLOs +} + +func (s *SLOAwareRouter) getPodRunningRequestCount(pod schedulingtypes.Pod) int { + podName := types.NamespacedName{ + Name: pod.GetPod().NamespacedName.Name, + Namespace: pod.GetPod().NamespacedName.Namespace, + } + if runningReqs, ok := s.runningRequestLists[podName]; ok { + return runningReqs.GetSize() + } + return 0 // no running requests +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection_helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection_helpers.go new file mode 100644 index 
0000000000..cdf3d965a1 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection_helpers.go @@ -0,0 +1,114 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "math" + + "sigs.k8s.io/controller-runtime/pkg/log" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +func (s *SLOAwareRouter) calculateHeadroomRanges(candidates []podPredictionResult) (minTPOTH, maxTPOTH, minTTFTH, maxTTFTH float64) { + minTPOTH, maxTPOTH = math.MaxFloat64, -math.MaxFloat64 + minTTFTH, maxTTFTH = math.MaxFloat64, -math.MaxFloat64 + + for _, p := range candidates { + if p.Headroom < minTPOTH { + minTPOTH = p.Headroom + } + if p.Headroom > maxTPOTH { + maxTPOTH = p.Headroom + } + if p.TTFTHeadroom < minTTFTH { + minTTFTH = p.TTFTHeadroom + } + if p.TTFTHeadroom > maxTTFTH { + maxTTFTH = p.TTFTHeadroom + } + } + return +} + +func (s *SLOAwareRouter) calculateWeightedChoices( + ctx context.Context, + candidates []podPredictionResult, + minTPOTH, maxTPOTH, minTTFTH, maxTTFTH float64, +) ([]choice, int) { + logger := log.FromContext(ctx) + tpotRange := maxTPOTH - minTPOTH + ttftRange := maxTTFTH - minTTFTH + + // Precompute blend weights (renormalize if user sets both to 0) + alpha := HeadroomTTFTWeight + beta := HeadroomTPOTWeight + if alpha+beta <= 0 { + alpha = 1.0 + beta = 0.0 + } + sum := alpha + beta + alpha /= sum + beta /= sum + + logger.V(logutil.DEBUG).Info("Positive headroom normalization ranges", + "minTPOTHeadroom", minTPOTH, "maxTPOTHeadroom", maxTPOTH, + "minTTFTHeadroom", minTTFTH, "maxTTFTHeadroom", maxTTFTH, + "alphaTTFT", alpha, "betaTPOT", beta, "strategy", s.headroomStrategy) + + weightedChoices := make([]choice, 0, len(candidates)) + total := 0 + + for _, p := range candidates { + // Normalize to [0,1] within the cohort + nTPOTH := 0.5 + if tpotRange > eps { + nTPOTH = (p.Headroom - minTPOTH) / (tpotRange + eps) + } + nTTFTH := 0.5 + if ttftRange > eps { + nTTFTH = (p.TTFTHeadroom - minTTFTH) / (ttftRange + eps) + } + + // Blend: larger combined -> "safer"; smaller -> "tighter packing" + combined := alpha*nTTFTH + beta*nTPOTH + + // Map to integer weights + var w int + switch s.headroomStrategy { + case headroomStrategyLeast: + // prefer smaller combined headroom (pack closer to limits) + w = int((1.0-combined)*float64(wMax-minWeight)) + minWeight + 1 + case headroomStrategyMost: + // prefer larger combined headroom (more conservative / spread) + w = int(combined*float64(wMax-minWeight)) + minWeight + 1 + default: + // Fallback to least + w = int((1.0-combined)*float64(wMax-minWeight)) + minWeight + 1 + } + + weightedChoices = append(weightedChoices, choice{podName: p.Pod, weight: w}) + total += w + + logger.V(logutil.TRACE).Info("Positive headroom blended weight", + "pod", p.Pod.GetPod().String(), + "ttftHeadroom", p.TTFTHeadroom, "normTTFTHeadroom", nTTFTH, + "tpotHeadroom", p.Headroom, "normTPOTHeadroom", nTPOTH, + "combined", combined, "weight", w) + } + 
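+	// Callers feed these weights into performWeightedRandomSelection to pick the pod.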
return weightedChoices, total +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go new file mode 100644 index 0000000000..03844543f3 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go @@ -0,0 +1,57 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package requestcontrol contains helpers to decouple latency-predictor logic. +package slo_aware_router + +import schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + +type headroomStrategy string + +type choice struct { + podName schedulingtypes.Pod + weight int +} + +const ( + // headroomStrategyLeast prioritizes pods with least positive headroom (better packing) + headroomStrategyLeast headroomStrategy = "least" + // headroomStrategyMost prioritizes pods with most positive headroom (more conservative) + headroomStrategyMost headroomStrategy = "most" + + headroomStrategyCompositeLeast headroomStrategy = "composite-least" + headroomStrategyCompositeMost headroomStrategy = "composite-most" + headroomStrategyCompositeOnly headroomStrategy = "composite-only" + + // TTFT header string + ttftSLOHeaderKey = "x-slo-ttft-ms" + // TPOT header string + tpotSLOHeaderKey = "x-slo-tpot-ms" +) + +const ( + SLOAwareRouterPluginType = "slo-aware-routing" + eps = 1e-9 + wMax = 100 + minWeight = 1 +) + +type podSelectionMode string + +const ( + podSelectionLinear podSelectionMode = "linear" // weighted-random (current behavior) + podSelectionMax podSelectionMode = "max" // pick argmax weight +) diff --git a/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go b/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go new file mode 100644 index 0000000000..c4bd638e65 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go @@ -0,0 +1,152 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package profile + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +const ( + SLOAwareProfileHandlerType = "slo-aware-profile-handler" + DefaultProfileName = "default" + PrefixProfileName = "prefix" + SLOProfileName = "slo" + + // Boolean header string for whether to use predictor based scheduling + PreictionBasedSchedulingHeaderKey = "x-prediction-based-scheduling" +) + +// compile-time type assertion +var _ framework.ProfileHandler = &SLOAwareProfileHandler{} + +// SLOAwareProfileHandlerFactory defines the factory function for SLOAwareProfileHandler. +func SLOAwareProfileHandlerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + return NewSLOAwareProfileHandler().WithName(name), nil +} + +// NewSLOAwareProfileHandler initializes a new SLOAwareProfileHandler and returns its pointer. +func NewSLOAwareProfileHandler() *SLOAwareProfileHandler { + return &SLOAwareProfileHandler{ + typedName: plugins.TypedName{Type: SLOAwareProfileHandlerType, Name: SLOAwareProfileHandlerType}, + } +} + +// SLOAwareProfileHandler handles two profiles: the default profile and the SLO profile. +// When the request has PredictorBasedScheduling=true, it uses the SLO profile result to select +// the destination pod. Otherwise, it uses the default profile result. +type SLOAwareProfileHandler struct { + typedName plugins.TypedName +} + +// TypedName returns the type and name tuple of this plugin instance. +func (h *SLOAwareProfileHandler) TypedName() plugins.TypedName { + return h.typedName +} + +// WithName sets the name of the profile handler. +func (h *SLOAwareProfileHandler) WithName(name string) *SLOAwareProfileHandler { + h.typedName.Name = name + return h +} + +// Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the +// previously executed cycles along with their results. +func (h *SLOAwareProfileHandler) Pick(_ context.Context, _ *types.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile, + profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile { + if len(profiles) == len(profileResults) { // all profiles have been executed already in previous call + return map[string]*framework.SchedulerProfile{} + } + + if _, executed := profileResults[PrefixProfileName]; !executed { + // if prefix profile was not executed yet, first let the scheduler run the decode profile + return map[string]*framework.SchedulerProfile{ + PrefixProfileName: profiles[PrefixProfileName], + } + } + // otherwise, prefix was already executed. + + // return all profiles except prefix. + profilesToRun := make(map[string]*framework.SchedulerProfile) + for name, profile := range profiles { + if name != PrefixProfileName { + profilesToRun[name] = profile + } + } + return profilesToRun +} + +// ProcessResults handles the outcome of the profile runs after all profiles ran. +// It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the +// key of the primary profile that should be used to get the request selected destination. +// When a profile run fails, its result in the profileResults map is nil. 
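+// The SLO profile result is used as the primary result when the request opts in via
+// the x-prediction-based-scheduling header; otherwise the default profile result is used.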
+func (h *SLOAwareProfileHandler) ProcessResults(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
+	if len(profileResults) < 2 {
+		return nil, errors.New("SLOAwareProfileHandler requires at least two profiles to operate")
+	}
+
+	predictorBasedScheduling, err := parseBoolHeader(*request, PreictionBasedSchedulingHeaderKey)
+	if err != nil {
+		return nil, fmt.Errorf("failed to choose scheduling profile: the %s header must be a bool: %v", PreictionBasedSchedulingHeaderKey, err)
+	}
+
+	if predictorBasedScheduling { // TODO grab header directly from request.Headers instead of request field
+		if profileResults[SLOProfileName] == nil { // there was an error while running the SLO profile
+			return nil, fmt.Errorf("failed to run scheduler profile '%s'", SLOProfileName)
+		}
+		return &types.SchedulingResult{
+			ProfileResults:     profileResults,
+			PrimaryProfileName: SLOProfileName,
+		}, nil
+	}
+
+	if profileResults[DefaultProfileName] == nil { // there was an error while running the default profile
+		return nil, fmt.Errorf("failed to run scheduler profile '%s'", DefaultProfileName)
+	}
+
+	return &types.SchedulingResult{
+		ProfileResults:     profileResults,
+		PrimaryProfileName: DefaultProfileName,
+	}, nil
+}
+
+// parseBoolHeader retrieves a header by name and parses it as a bool.
+// A missing header is not an error and yields false; an invalid value returns an error.
+func parseBoolHeader(request types.LLMRequest, headerName string) (bool, error) {
+	// 1. Get header value from the map
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return false, nil // Header not found, default to false
+	}
+
+	// 2. Parse the header value to a bool
+	parsedBool, err := strconv.ParseBool(headerValue)
+	if err != nil {
+		return false, fmt.Errorf("header %q must be a bool: %v", headerName, err)
+	}
+
+	// 3. Return the successfully parsed value
+	return parsedBool, nil
+}
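The routing path above is driven entirely by request headers: the SLO values read by parseSLOHeaders (ttftSLOHeaderKey and tpotSLOHeaderKey) and the opt-in flag read by the profile handler. The sketch below shows how a client might set those headers; it is only an illustration, and the gateway URL, request body, and SLO values are assumptions, not part of this change.

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// Hypothetical gateway endpoint and payload; adjust to your deployment.
	body := bytes.NewBufferString(`{"model": "example-model", "prompt": "hello"}`)
	req, err := http.NewRequest(http.MethodPost, "http://inference-gateway.example/v1/completions", body)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")

	// SLO headers parsed as floats by parseSLOHeaders; the values here are placeholders.
	req.Header.Set("x-slo-ttft-ms", "2000")
	req.Header.Set("x-slo-tpot-ms", "50")
	// Opt the request in to predictor-based scheduling (parsed with strconv.ParseBool).
	req.Header.Set("x-prediction-based-scheduling", "true")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}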