diff --git a/pkg/epp/framework/interface/datalayer/endpoint.go b/pkg/epp/framework/interface/datalayer/endpoint.go
index 6b78e318c..e6370881a 100644
--- a/pkg/epp/framework/interface/datalayer/endpoint.go
+++ b/pkg/epp/framework/interface/datalayer/endpoint.go
@@ -90,6 +90,21 @@ func (srv *ModelServer) GetAttributes() AttributeMap {
 	return srv.attributes
 }
 
+// Get retrieves an attribute value by key, forwarding to the underlying AttributeMap.
+func (srv *ModelServer) Get(key string) (Cloneable, bool) {
+	return srv.attributes.Get(key)
+}
+
+// Put stores an attribute value by key, forwarding to the underlying AttributeMap.
+func (srv *ModelServer) Put(key string, value Cloneable) {
+	srv.attributes.Put(key, value)
+}
+
+// Keys returns all attribute keys, forwarding to the underlying AttributeMap.
+func (srv *ModelServer) Keys() []string {
+	return srv.attributes.Keys()
+}
+
 func (srv *ModelServer) Clone() *ModelServer {
 	clone := &ModelServer{
 		attributes: srv.attributes.Clone(),
diff --git a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper.go b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper.go
index db08d9748..89d00b683 100644
--- a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper.go
+++ b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper.go
@@ -28,6 +28,7 @@ import (
 
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
 	fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+	framework "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
 	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
@@ -120,7 +121,9 @@ func processFirstTokenForLatencyPrediction(
 	ctx context.Context,
 	predictor latencypredictor.PredictorInterface,
 	streamingMode bool,
+	requestBuilder PredictionRequestBuilder,
 	predictedLatencyCtx *predictedLatencyCtx,
+	pod framework.Endpoint,
 	now time.Time,
 	samplingMean float64,
 	maxSampledTokens int,
@@ -139,7 +142,7 @@ func processFirstTokenForLatencyPrediction(
 	targetPod := predictedLatencyCtx.targetMetadata
 	prefixCacheScore := predictedLatencyCtx.prefixCacheScoresForEndpoints[targetPod.NamespacedName.Name]
 	logger.V(logutil.DEBUG).Info("Recording TTFT training data", "ttft_ms", predictedLatencyCtx.ttft, "prefixCacheScore", prefixCacheScore)
-	recordTTFTTrainingData(ctx, predictor, predictedLatencyCtx, m, now, prefixCacheScore)
+	recordTTFTTrainingData(ctx, predictor, requestBuilder, predictedLatencyCtx, m, pod, now, prefixCacheScore)
 
 	if streamingMode {
 		predictFirstTPOT(ctx, predictor, predictedLatencyCtx)
@@ -163,24 +166,26 @@ func initializeSampler(ctx context.Context, predictedLatencyCtx *predictedLatenc
 func recordTTFTTrainingData(
 	ctx context.Context,
 	predictor latencypredictor.PredictorInterface,
+	requestBuilder PredictionRequestBuilder,
 	predictedLatencyCtx *predictedLatencyCtx,
 	m *fwkdl.Metrics,
+	pod framework.Endpoint,
 	now time.Time,
 	prefixCacheScore float64,
 ) {
 	logger := log.FromContext(ctx)
-	// Train TTFT
-	entry := latencypredictor.TrainingEntry{
-		KVCachePercentage:  m.KVCacheUsagePercent,
-		InputTokenLength:   len(strings.Fields(predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt)),
-		ActualTTFT:         predictedLatencyCtx.ttft,
-		ActualTPOT:         0,
-		Timestamp:          now,
-		NumRequestWaiting:  m.WaitingQueueSize,
-		NumRequestRunning:  m.RunningRequestsSize,
-		NumTokensGenerated: 0,
-		PrefixCacheScore:   prefixCacheScore,
-	}
+	// Build training entry using the builder
+	entry := requestBuilder.BuildTrainingEntry(
+		ctx,
+		pod,
+		m,
+		predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
+		predictedLatencyCtx.ttft,
+		0, // actualTPOT is zero for a TTFT training entry
+		now,
+		0, // no tokens generated yet at first token
+		prefixCacheScore,
+	)
 	if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
 		logger.V(logutil.DEBUG).Error(err, "record TTFT training failed")
 	}
@@ -227,7 +232,9 @@ func predictFirstTPOT(
 func processTokenForLatencyPrediction(
 	ctx context.Context,
 	predictor latencypredictor.PredictorInterface,
+	requestBuilder PredictionRequestBuilder,
 	predictedLatencyCtx *predictedLatencyCtx,
+	pod framework.Endpoint,
 	now time.Time,
 	samplingMean float64,
 	maxSampledTokens int,
@@ -257,18 +264,18 @@ func processTokenForLatencyPrediction(
 			"error", err)
 		return
 	}
-	// Record actual TPOT
-	entry := latencypredictor.TrainingEntry{
-		KVCachePercentage:  m.KVCacheUsagePercent,
-		InputTokenLength:   len(strings.Fields(predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt)),
-		ActualTTFT:         0,
-		ActualTPOT:         latencyMs,
-		Timestamp:          now,
-		NumRequestWaiting:  m.WaitingQueueSize,
-		NumRequestRunning:  m.RunningRequestsSize,
-		NumTokensGenerated: predictedLatencyCtx.generatedTokenCount - 1,
-		PrefixCacheScore:   0, // TPOT does not use prefix cache score
-	}
+	// Record actual TPOT using the builder
+	entry := requestBuilder.BuildTrainingEntry(
+		ctx,
+		pod,
+		m,
+		predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
+		0, // actualTTFT is zero for a TPOT training entry
+		latencyMs,
+		now,
+		predictedLatencyCtx.generatedTokenCount-1,
+		0, // TPOT does not use prefix cache score
+	)
 	if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
 		logger.V(logutil.DEBUG).Error(err, "record TPOT training failed")
 	}
@@ -312,6 +319,8 @@ func bulkPredictWithMetrics(
 	ctx context.Context,
 	predictor latencypredictor.PredictorInterface,
 	metricsStates []*fwkdl.Metrics,
+	requestBuilder PredictionRequestBuilder,
+	pods []framework.Endpoint,
 	prompts []string,
 	generatedTokenCounts []int,
 	prefixCacheScores []float64,
@@ -319,9 +328,9 @@
 	logger := log.FromContext(ctx)
 
 	// Validate input lengths
-	if len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) {
-		return nil, fmt.Errorf("input slice lengths must match: metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d",
-			len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores))
+	if len(pods) != len(metricsStates) || len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) {
+		return nil, fmt.Errorf("input slice lengths must match: pods=%d, metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d",
+			len(pods), len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores))
 	}
 
 	if len(metricsStates) == 0 {
@@ -335,17 +344,17 @@
 		}
 	}
 
-	// Build bulk prediction requests
+	// Build bulk prediction requests using the builder
 	bulkRequests := make([]latencypredictor.PredictionRequest, len(metricsStates))
 	for i := range metricsStates {
-		bulkRequests[i] = latencypredictor.PredictionRequest{
-			KVCachePercentage:  metricsStates[i].KVCacheUsagePercent,
-			InputTokenLength:   len(strings.Fields(prompts[i])),
-			NumRequestWaiting:  metricsStates[i].WaitingQueueSize,
-			NumRequestRunning:  metricsStates[i].RunningRequestsSize,
-			NumTokensGenerated: generatedTokenCounts[i],
-			PrefixCacheScore:   prefixCacheScores[i],
-		}
+		bulkRequests[i] = requestBuilder.BuildPredictionRequest(
+			ctx,
+			pods[i],
+			metricsStates[i],
+			prompts[i],
+			generatedTokenCounts[i],
+			prefixCacheScores[i],
+		)
 	}
 
 	// Perform bulk prediction
diff --git a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper_test.go b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper_test.go
index bef82920f..71f12f4c8 100644
--- a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper_test.go
+++ b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper_test.go
@@ -23,7 +23,9 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"k8s.io/apimachinery/pkg/types"
 	fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
 	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
 )
 
@@ -39,11 +41,20 @@ func TestBulkPredictWithMetrics(t *testing.T) {
 		{KVCacheUsagePercent: 0.5},
 		{KVCacheUsagePercent: 0.6},
 	}
+	requestBuilder := &DefaultPredictionRequestBuilder{}
+	pods := []schedulingtypes.Endpoint{
+		fwkdl.NewEndpoint(&fwkdl.EndpointMetadata{
+			NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+		}, nil),
+		fwkdl.NewEndpoint(&fwkdl.EndpointMetadata{
+			NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod2"},
+		}, nil),
+	}
 	prompts := []string{"prompt1", "prompt2"}
 	generatedTokenCounts := []int{1, 1}
 	prefixCacheScores := []float64{0.0, 0.0}
 
-	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
 
 	assert.NoError(t, err)
 	assert.Len(t, results, 2)
@@ -61,11 +72,17 @@ func TestBulkPredictWithMetrics_Error(t *testing.T) {
 	metricsStates := []*fwkdl.Metrics{
 		{KVCacheUsagePercent: 0.5},
 	}
+	requestBuilder := &DefaultPredictionRequestBuilder{}
+	pods := []schedulingtypes.Endpoint{
+		fwkdl.NewEndpoint(&fwkdl.EndpointMetadata{
+			NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+		}, nil),
+	}
 	prompts := []string{"prompt1"}
 	generatedTokenCounts := []int{1}
 	prefixCacheScores := []float64{0.0}
 
-	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
 
 	assert.Error(t, err)
 	assert.Nil(t, results)
@@ -74,11 +91,17 @@ func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
 	mockPredictor := &mockPredictor{}
 	metricsStates := []*fwkdl.Metrics{{}}
+	requestBuilder := &DefaultPredictionRequestBuilder{}
+	pods := []schedulingtypes.Endpoint{
+		fwkdl.NewEndpoint(&fwkdl.EndpointMetadata{
+			NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+		}, nil),
+	}
 	prompts := []string{"prompt1", "prompt2"} // Mismatch length
 	generatedTokenCounts := []int{1}
 	prefixCacheScores := []float64{0.0}
 
-	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
 
 	assert.Error(t, err)
 	assert.Nil(t, results)
@@ -88,11 +111,17 @@ func TestBulkPredictWithMetrics_NilMetricsState(t *testing.T) {
 	mockPredictor := &mockPredictor{}
 	metricsStates := []*fwkdl.Metrics{nil} // Nil metrics state
+	requestBuilder := &DefaultPredictionRequestBuilder{}
+	pods := []schedulingtypes.Endpoint{
+		fwkdl.NewEndpoint(&fwkdl.EndpointMetadata{
+			NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+		}, nil),
+	}
 	prompts := []string{"prompt1"}
 	generatedTokenCounts := []int{1}
 	prefixCacheScores := []float64{0.0}
 
-	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+	results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
 
 	assert.Error(t, err)
 	assert.Nil(t, results)
diff --git a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/prediction.go b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/prediction.go
index 71fe1fa60..3e7ddabb0 100644
--- a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/prediction.go
+++ b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/prediction.go
@@ -66,7 +66,7 @@ func (s *PredictedLatency) generatePredictions(ctx context.Context, request *sch
 	}
 
 	// Bulk predict
-	bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+	bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, s.requestBuilder, candidateEndpoints, prompts, generatedTokenCounts, prefixCacheScores)
 	if err != nil {
 		logger.V(logutil.DEBUG).Error(err, "Bulk prediction failed")
 		return nil, err
diff --git a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks.go b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks.go
index cdfceff48..a914be8cd 100644
--- a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks.go
+++ b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks.go
@@ -98,6 +98,57 @@ func (s *PredictedLatency) deletePredictedLatencyContextForRequest(request *sche
 	s.sloContextStore.Delete(id)
 }
 
+// GetSchedulingResultForRequest returns the scheduling result for a request.
+// This is exposed to allow wrapper implementations (e.g., P/D-aware routers)
+// to access scheduling information for custom hook logic.
+func (s *PredictedLatency) GetSchedulingResultForRequest(request *schedulingtypes.LLMRequest) (*schedulingtypes.SchedulingResult, error) {
+	predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
+	if err != nil {
+		return nil, err
+	}
+	return predictedLatencyCtx.schedulingResult, nil
+}
+
+// GetLastSeenMetricsForRequest returns the last seen metrics for all profiles in a request.
+// This is exposed to allow wrapper implementations to access metrics for custom training logic.
+func (s *PredictedLatency) GetLastSeenMetricsForRequest(request *schedulingtypes.LLMRequest) (map[string]*fwkdl.Metrics, error) {
+	predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
+	if err != nil {
+		return nil, err
+	}
+	return predictedLatencyCtx.lastSeenMetrics, nil
+}
+
+// GetPrefixCacheScoresForRequest returns the prefix cache scores for all pods in a request.
+func (s *PredictedLatency) GetPrefixCacheScoresForRequest(request *schedulingtypes.LLMRequest) (map[string]float64, error) {
+	predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
+	if err != nil {
+		return nil, err
+	}
+	return predictedLatencyCtx.prefixCacheScoresForEndpoints, nil
+}
+
+// GetRequestPrompt returns the prompt for a request.
+func (s *PredictedLatency) GetRequestPrompt(request *schedulingtypes.LLMRequest) (string, error) {
+	predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
+	if err != nil {
+		return "", err
+	}
+	return predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt, nil
+}
+
+// GetRequestBuilder returns the PredictionRequestBuilder used by this router.
+// This allows wrappers to use the same builder for consistency.
+func (s *PredictedLatency) GetRequestBuilder() PredictionRequestBuilder {
+	return s.requestBuilder
+}
+
+// GetLatencyPredictor returns the latency predictor client.
+// This allows wrappers to record training data using the same predictor.
+func (s *PredictedLatency) GetLatencyPredictor() interface{} {
+	return s.latencypredictor
+}
+
 // --- RequestControl Hooks ---
 
 func (t *PredictedLatency) PreRequest(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult) {
@@ -186,10 +237,16 @@ func (t *PredictedLatency) ResponseStreaming(ctx context.Context, request *sched
 		return
 	}
 
+	// Create a schedulingtypes.Endpoint wrapper for the metadata
+	podWrapper := fwkdl.NewEndpoint(
+		targetMetadata,
+		predictedLatencyCtx.lastSeenMetrics[predictedLatencyCtx.schedulingResult.PrimaryProfileName],
+	)
+
 	if predictedLatencyCtx.ttft == 0 {
-		processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
+		processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, t.requestBuilder, predictedLatencyCtx, podWrapper, now, t.config.SamplingMean, t.config.MaxSampledTokens)
 	} else {
-		processTokenForLatencyPrediction(ctx, t.latencypredictor, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
+		processTokenForLatencyPrediction(ctx, t.latencypredictor, t.requestBuilder, predictedLatencyCtx, podWrapper, now, t.config.SamplingMean, t.config.MaxSampledTokens)
 	}
 }
 
@@ -213,7 +270,12 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu
 	}
 	now := time.Now()
 	if !t.config.StreamingMode {
-		processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
+		// Create a schedulingtypes.Endpoint wrapper for non-streaming responses
+		podWrapper := fwkdl.NewEndpoint(
+			targetMetadata,
+			predictedLatencyCtx.lastSeenMetrics[predictedLatencyCtx.schedulingResult.PrimaryProfileName],
+		)
+		processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, t.requestBuilder, predictedLatencyCtx, podWrapper, now, t.config.SamplingMean, t.config.MaxSampledTokens)
 	}
 
 	if predictedLatencyCtx.ttft > 0 {
diff --git a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks_test.go b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks_test.go
index 8a310c900..631713479 100644
--- a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks_test.go
+++ b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks_test.go
@@ -62,6 +62,7 @@ func createTestRouter() *PredictedLatency {
 		sloContextStore:     sync.Map{},
 		runningRequestLists: make(map[types.NamespacedName]*requestPriorityQueue),
 		latencypredictor:    nil,
+		requestBuilder:      &DefaultPredictionRequestBuilder{},
 		config:              DefaultConfig,
 	}
 }
diff --git a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/scorer.go b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/scorer.go
index f5947b581..474d51336 100644
--- a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/scorer.go
+++ b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/scorer.go
@@ -39,6 +39,7 @@ import (
 type PredictedLatency struct {
 	typedName           plugin.TypedName
 	latencypredictor    latencypredictor.PredictorInterface
+	requestBuilder      PredictionRequestBuilder
 	runningRequestLists map[types.NamespacedName]*requestPriorityQueue
 	sloContextStore     sync.Map // map[string]*SLORequestContext
 	headroomStrategy    headroomStrategy
@@ -65,6 +66,11 @@ type Config struct {
 	AffinityGateTauGlobal float64 `json:"affinityGateTauGlobal,omitempty"`
 	SelectionMode         string  `json:"selectionMode,omitempty"`
 	StreamingMode         bool    `json:"streamingMode,omitempty"`
+
+	// RequestBuilder allows customization of prediction and training request construction.
+	// This field is not serialized and must be set programmatically.
+	// If nil, defaults to DefaultPredictionRequestBuilder.
+	RequestBuilder PredictionRequestBuilder `json:"-"`
 }
 
 var DefaultConfig = Config{
@@ -95,6 +101,11 @@ func PredictedLatencyFactory(name string, rawParameters json.RawMessage, handle
 		}
 	}
 
+	// Use provided builder or default to DefaultPredictionRequestBuilder
+	if parameters.RequestBuilder == nil {
+		parameters.RequestBuilder = &DefaultPredictionRequestBuilder{}
+	}
+
 	if err := parameters.validate(); err != nil {
 		return nil, fmt.Errorf("invalid PredictedLatency config: %w", err)
 	}
@@ -157,9 +168,16 @@ func NewPredictedLatency(config Config, predictor latencypredictor.PredictorInte
 		strategy = headroomStrategyLeast
 	}
 
+	// Ensure requestBuilder is set
+	requestBuilder := config.RequestBuilder
+	if requestBuilder == nil {
+		requestBuilder = &DefaultPredictionRequestBuilder{}
+	}
+
 	return &PredictedLatency{
 		typedName:           plugin.TypedName{Type: PredictedLatencyPluginType, Name: PredictedLatencyPluginType},
 		latencypredictor:    predictor,
+		requestBuilder:      requestBuilder,
 		runningRequestLists: make(map[types.NamespacedName]*requestPriorityQueue),
 		sloContextStore:     sync.Map{},
 		headroomStrategy:    strategy,
diff --git a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/types.go b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/types.go
index 5976751ad..bc7a06fae 100644
--- a/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/types.go
+++ b/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/types.go
@@ -17,7 +17,15 @@ limitations under the License.
 
 // Package requestcontrol contains helpers to decouple latency-predictor logic.
 package predictedlatency
 
-import schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
+import (
+	"context"
+	"strings"
+	"time"
+
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
+	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
+)
 
 type headroomStrategy string
 
@@ -55,3 +63,81 @@ const (
 	podSelectionLinear podSelectionMode = "linear" // weighted-random (current behavior)
 	podSelectionMax    podSelectionMode = "max"    // pick argmax weight
 )
+
+// PredictionRequestBuilder constructs prediction and training requests with optional customization.
+// This interface allows different implementations to customize how prediction requests are built,
+// for example to add pod type information for disaggregated serving scenarios.
+type PredictionRequestBuilder interface {
+	// BuildPredictionRequest constructs a prediction request for a pod.
+	BuildPredictionRequest(
+		ctx context.Context,
+		pod schedulingtypes.Endpoint,
+		metrics *datalayer.Metrics,
+		prompt string,
+		generatedTokens int,
+		prefixCacheScore float64,
+	) latencypredictor.PredictionRequest
+
+	// BuildTrainingEntry constructs a training entry for a pod.
+	BuildTrainingEntry(
+		ctx context.Context,
+		pod schedulingtypes.Endpoint,
+		metrics *datalayer.Metrics,
+		prompt string,
+		actualTTFT float64,
+		actualTPOT float64,
+		timestamp time.Time,
+		generatedTokens int,
+		prefixCacheScore float64,
+	) latencypredictor.TrainingEntry
+}
+
+// DefaultPredictionRequestBuilder provides the default monolithic behavior for building prediction requests.
+// This implementation leaves PodType empty, suitable for monolithic (non-disaggregated) deployments.
+type DefaultPredictionRequestBuilder struct{}
+
+// BuildPredictionRequest constructs a standard prediction request without pod type information.
+func (b *DefaultPredictionRequestBuilder) BuildPredictionRequest(
+	ctx context.Context,
+	pod schedulingtypes.Endpoint,
+	metrics *datalayer.Metrics,
+	prompt string,
+	generatedTokens int,
+	prefixCacheScore float64,
+) latencypredictor.PredictionRequest {
+	return latencypredictor.PredictionRequest{
+		KVCachePercentage:  metrics.KVCacheUsagePercent,
+		InputTokenLength:   len(strings.Fields(prompt)), // Simple word-based tokenization
+		NumRequestWaiting:  metrics.WaitingQueueSize,
+		NumRequestRunning:  metrics.RunningRequestsSize,
+		NumTokensGenerated: generatedTokens,
+		PrefixCacheScore:   prefixCacheScore,
+		PodType:            "", // Empty for monolithic deployments
+	}
+}
+
+// BuildTrainingEntry constructs a standard training entry without pod type information.
+func (b *DefaultPredictionRequestBuilder) BuildTrainingEntry(
+	ctx context.Context,
+	pod schedulingtypes.Endpoint,
+	metrics *datalayer.Metrics,
+	prompt string,
+	actualTTFT float64,
+	actualTPOT float64,
+	timestamp time.Time,
+	generatedTokens int,
+	prefixCacheScore float64,
+) latencypredictor.TrainingEntry {
+	return latencypredictor.TrainingEntry{
+		KVCachePercentage:  metrics.KVCacheUsagePercent,
+		InputTokenLength:   len(strings.Fields(prompt)), // Simple word-based tokenization
+		ActualTTFT:         actualTTFT,
+		ActualTPOT:         actualTPOT,
+		Timestamp:          timestamp,
+		NumRequestWaiting:  metrics.WaitingQueueSize,
+		NumRequestRunning:  metrics.RunningRequestsSize,
+		NumTokensGenerated: generatedTokens,
+		PrefixCacheScore:   prefixCacheScore,
+		PodType:            "", // Empty for monolithic deployments
+	}
+}
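
For context on how the new seam is meant to be used: the PredictionRequestBuilder interface above is what a disaggregated (prefill/decode) deployment would implement. The sketch below is illustrative only and not part of this change; the pdAwareRequestBuilder name, the classify callback, and the "prefill"/"decode" labels are hypothetical stand-ins for however a deployment identifies a pod's serving role. It relies only on types introduced in this patch, including the PodType field on PredictionRequest.

package predictedlatency

import (
	"context"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
)

// pdAwareRequestBuilder is a hypothetical builder for disaggregated serving.
// It embeds DefaultPredictionRequestBuilder, so BuildTrainingEntry is promoted
// unchanged; a complete implementation would override it the same way so that
// training entries carry the pod role as well.
type pdAwareRequestBuilder struct {
	DefaultPredictionRequestBuilder

	// classify reports a pod's serving role, e.g. "prefill" or "decode".
	// How the role is derived from an Endpoint is deployment-specific, so it
	// is injected rather than assumed here.
	classify func(pod schedulingtypes.Endpoint) string
}

// BuildPredictionRequest reuses the default (monolithic) request and then
// tags it with the pod's serving role via the PodType field.
func (b *pdAwareRequestBuilder) BuildPredictionRequest(
	ctx context.Context,
	pod schedulingtypes.Endpoint,
	metrics *datalayer.Metrics,
	prompt string,
	generatedTokens int,
	prefixCacheScore float64,
) latencypredictor.PredictionRequest {
	// Start from the default request, then add the role tag.
	req := b.DefaultPredictionRequestBuilder.BuildPredictionRequest(
		ctx, pod, metrics, prompt, generatedTokens, prefixCacheScore)
	if b.classify != nil {
		req.PodType = b.classify(pod)
	}
	return req
}

Because Config.RequestBuilder is tagged `json:"-"`, a builder like this cannot arrive through the serialized plugin parameters; it has to be set programmatically (e.g. cfg := DefaultConfig; cfg.RequestBuilder = &pdAwareRequestBuilder{classify: ...}) before the plugin is constructed. Both PredictedLatencyFactory and NewPredictedLatency fall back to DefaultPredictionRequestBuilder when the field is nil, so existing monolithic configurations are unaffected.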