15 changes: 15 additions & 0 deletions pkg/epp/framework/interface/datalayer/endpoint.go
@@ -90,6 +90,21 @@ func (srv *ModelServer) GetAttributes() AttributeMap {
return srv.attributes
}

// Get retrieves an attribute value by key, forwarding to the underlying AttributeMap.
func (srv *ModelServer) Get(key string) (Cloneable, bool) {
return srv.attributes.Get(key)
}

// Put stores an attribute value by key, forwarding to the underlying AttributeMap.
func (srv *ModelServer) Put(key string, value Cloneable) {
srv.attributes.Put(key, value)
}

// Keys returns all attribute keys, forwarding to the underlying AttributeMap.
func (srv *ModelServer) Keys() []string {
return srv.attributes.Keys()
}

func (srv *ModelServer) Clone() *ModelServer {
clone := &ModelServer{
attributes: srv.attributes.Clone(),
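Reviewer note: with these forwarders, a *ModelServer satisfies the same Get/Put/Keys surface as its AttributeMap, so callers no longer need to reach through GetAttributes(). A minimal usage sketch — the labelAttr type, the "topology/zone" key, and the assumed Cloneable method set are illustrative only, not part of this change:

// Hypothetical attribute type; assumes Cloneable requires only Clone() Cloneable.
type labelAttr struct{ zone string }

func (l *labelAttr) Clone() datalayer.Cloneable { return &labelAttr{zone: l.zone} }

func tagEndpoint(srv *datalayer.ModelServer) {
	// Store and read an attribute directly on the server.
	srv.Put("topology/zone", &labelAttr{zone: "us-east1-b"})
	if v, ok := srv.Get("topology/zone"); ok {
		_ = v.(*labelAttr).zone
	}
	// Enumerate every attribute key attached to this endpoint.
	for _, key := range srv.Keys() {
		_ = key
	}
}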
@@ -28,6 +28,7 @@ import (

logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
framework "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
@@ -120,7 +121,9 @@ func processFirstTokenForLatencyPrediction(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
streamingMode bool,
requestBuilder PredictionRequestBuilder,
predictedLatencyCtx *predictedLatencyCtx,
pod framework.Endpoint,
now time.Time,
samplingMean float64,
maxSampledTokens int,
@@ -139,7 +142,7 @@
targetPod := predictedLatencyCtx.targetMetadata
prefixCacheScore := predictedLatencyCtx.prefixCacheScoresForEndpoints[targetPod.NamespacedName.Name]
logger.V(logutil.DEBUG).Info("Recording TTFT training data", "ttft_ms", predictedLatencyCtx.ttft, "prefixCacheScore", prefixCacheScore)
recordTTFTTrainingData(ctx, predictor, predictedLatencyCtx, m, now, prefixCacheScore)
recordTTFTTrainingData(ctx, predictor, requestBuilder, predictedLatencyCtx, m, pod, now, prefixCacheScore)

if streamingMode {
predictFirstTPOT(ctx, predictor, predictedLatencyCtx)
@@ -163,24 +166,26 @@ func initializeSampler(ctx context.Context, predictedLatencyCtx *predictedLatenc
func recordTTFTTrainingData(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
requestBuilder PredictionRequestBuilder,
predictedLatencyCtx *predictedLatencyCtx,
m *fwkdl.Metrics,
pod framework.Endpoint,
now time.Time,
prefixCacheScore float64,
) {
logger := log.FromContext(ctx)
// Train TTFT
entry := latencypredictor.TrainingEntry{
KVCachePercentage: m.KVCacheUsagePercent,
InputTokenLength: len(strings.Fields(predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt)),
ActualTTFT: predictedLatencyCtx.ttft,
ActualTPOT: 0,
Timestamp: now,
NumRequestWaiting: m.WaitingQueueSize,
NumRequestRunning: m.RunningRequestsSize,
NumTokensGenerated: 0,
PrefixCacheScore: prefixCacheScore,
}
// Build training entry using the builder
entry := requestBuilder.BuildTrainingEntry(
ctx,
pod,
m,
predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
predictedLatencyCtx.ttft,
0, // ActualTPOT: not applicable for a TTFT training sample
now,
0, // NumTokensGenerated: none yet at first token
prefixCacheScore,
)
if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
logger.V(logutil.DEBUG).Error(err, "record TTFT training failed")
}
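For orientation, the removed inline literal shows exactly what a builder implementation must produce. A plausible sketch of DefaultPredictionRequestBuilder.BuildTrainingEntry follows — the signature and field mapping are inferred from the call sites in this diff, not confirmed by the PR:

// Sketch only: signature inferred from the call sites above.
func (b *DefaultPredictionRequestBuilder) BuildTrainingEntry(
	ctx context.Context,
	pod framework.Endpoint,
	m *fwkdl.Metrics,
	prompt string,
	actualTTFT, actualTPOT float64,
	now time.Time,
	numTokensGenerated int,
	prefixCacheScore float64,
) latencypredictor.TrainingEntry {
	return latencypredictor.TrainingEntry{
		KVCachePercentage:  m.KVCacheUsagePercent,
		InputTokenLength:   len(strings.Fields(prompt)),
		ActualTTFT:         actualTTFT,
		ActualTPOT:         actualTPOT,
		Timestamp:          now,
		NumRequestWaiting:  m.WaitingQueueSize,
		NumRequestRunning:  m.RunningRequestsSize,
		NumTokensGenerated: numTokensGenerated,
		PrefixCacheScore:   prefixCacheScore,
	}
}

The pod argument would go unused in this default mapping; it presumably exists so pod-aware builders can fold endpoint attributes into the feature set.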
@@ -227,7 +232,9 @@ func predictFirstTPOT(
func processTokenForLatencyPrediction(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
requestBuilder PredictionRequestBuilder,
predictedLatencyCtx *predictedLatencyCtx,
pod framework.Endpoint,
now time.Time,
samplingMean float64,
maxSampledTokens int,
@@ -257,18 +264,18 @@
"error", err)
return
}
// Record actual TPOT
entry := latencypredictor.TrainingEntry{
KVCachePercentage: m.KVCacheUsagePercent,
InputTokenLength: len(strings.Fields(predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt)),
ActualTTFT: 0,
ActualTPOT: latencyMs,
Timestamp: now,
NumRequestWaiting: m.WaitingQueueSize,
NumRequestRunning: m.RunningRequestsSize,
NumTokensGenerated: predictedLatencyCtx.generatedTokenCount - 1,
PrefixCacheScore: 0, // TPOT does not use prefix cache score
}
// Record actual TPOT using the builder
entry := requestBuilder.BuildTrainingEntry(
ctx,
pod,
m,
predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
0, // TTFT not recorded for TPOT
latencyMs,
now,
predictedLatencyCtx.generatedTokenCount-1,
0, // TPOT does not use prefix cache score
)
if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
logger.V(logutil.DEBUG).Error(err, "record TPOT training failed")
}
@@ -312,16 +319,18 @@ func bulkPredictWithMetrics(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
metricsStates []*fwkdl.Metrics,
requestBuilder PredictionRequestBuilder,
pods []framework.Endpoint,
prompts []string,
generatedTokenCounts []int,
prefixCacheScores []float64,
) ([]*latencypredictor.PredictionResponse, error) {
logger := log.FromContext(ctx)

// Validate input lengths
if len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) {
return nil, fmt.Errorf("input slice lengths must match: metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d",
len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores))
if len(pods) != len(metricsStates) || len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) {
return nil, fmt.Errorf("input slice lengths must match: pods=%d, metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d",
len(pods), len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores))
}

if len(metricsStates) == 0 {
@@ -335,17 +344,17 @@
}
}

// Build bulk prediction requests
// Build bulk prediction requests using the builder
bulkRequests := make([]latencypredictor.PredictionRequest, len(metricsStates))
for i := range metricsStates {
bulkRequests[i] = latencypredictor.PredictionRequest{
KVCachePercentage: metricsStates[i].KVCacheUsagePercent,
InputTokenLength: len(strings.Fields(prompts[i])),
NumRequestWaiting: metricsStates[i].WaitingQueueSize,
NumRequestRunning: metricsStates[i].RunningRequestsSize,
NumTokensGenerated: generatedTokenCounts[i],
PrefixCacheScore: prefixCacheScores[i],
}
bulkRequests[i] = requestBuilder.BuildPredictionRequest(
ctx,
pods[i],
metricsStates[i],
prompts[i],
generatedTokenCounts[i],
prefixCacheScores[i],
)
}

// Perform bulk prediction
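The prediction side mirrors this. A sketch of the corresponding BuildPredictionRequest, again inferred from the replaced literal and the loop above rather than taken from the PR:

// Sketch only: signature inferred from the call site above.
func (b *DefaultPredictionRequestBuilder) BuildPredictionRequest(
	ctx context.Context,
	pod framework.Endpoint,
	m *fwkdl.Metrics,
	prompt string,
	generatedTokenCount int,
	prefixCacheScore float64,
) latencypredictor.PredictionRequest {
	return latencypredictor.PredictionRequest{
		KVCachePercentage:  m.KVCacheUsagePercent,
		InputTokenLength:   len(strings.Fields(prompt)),
		NumRequestWaiting:  m.WaitingQueueSize,
		NumRequestRunning:  m.RunningRequestsSize,
		NumTokensGenerated: generatedTokenCount,
		PrefixCacheScore:   prefixCacheScore,
	}
}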
@@ -23,7 +23,9 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/types"
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
)

@@ -39,11 +41,20 @@ func TestBulkPredictWithMetrics(t *testing.T) {
{KVCacheUsagePercent: 0.5},
{KVCacheUsagePercent: 0.6},
}
requestBuilder := &DefaultPredictionRequestBuilder{}
pods := []schedulingtypes.Endpoint{
fwkdl.NewEndpoint(&fwkdl.EndpointMetadata{
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
}, nil),
fwkdl.NewEndpoint(&fwkdl.EndpointMetadata{
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod2"},
}, nil),
}
prompts := []string{"prompt1", "prompt2"}
generatedTokenCounts := []int{1, 1}
prefixCacheScores := []float64{0.0, 0.0}

results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)

assert.NoError(t, err)
assert.Len(t, results, 2)
@@ -61,11 +72,17 @@ func TestBulkPredictWithMetrics_Error(t *testing.T) {
metricsStates := []*fwkdl.Metrics{
{KVCacheUsagePercent: 0.5},
}
requestBuilder := &DefaultPredictionRequestBuilder{}
pods := []schedulingtypes.Endpoint{
fwkdl.NewEndpoint(&fwkdl.EndpointMetadata{
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
}, nil),
}
prompts := []string{"prompt1"}
generatedTokenCounts := []int{1}
prefixCacheScores := []float64{0.0}

results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)

assert.Error(t, err)
assert.Nil(t, results)
@@ -74,11 +91,17 @@ func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
mockPredictor := &mockPredictor{}
metricsStates := []*fwkdl.Metrics{{}}
requestBuilder := &DefaultPredictionRequestBuilder{}
pods := []schedulingtypes.Endpoint{
fwkdl.NewEndpoint(&fwkdl.EndpointMetadata{
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
}, nil),
}
prompts := []string{"prompt1", "prompt2"} // Mismatch length
generatedTokenCounts := []int{1}
prefixCacheScores := []float64{0.0}

results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)

assert.Error(t, err)
assert.Nil(t, results)
@@ -88,11 +111,17 @@ func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
func TestBulkPredictWithMetrics_NilMetricsState(t *testing.T) {
mockPredictor := &mockPredictor{}
metricsStates := []*fwkdl.Metrics{nil} // Nil metrics state
requestBuilder := &DefaultPredictionRequestBuilder{}
pods := []schedulingtypes.Endpoint{
fwkdl.NewEndpoint(&fwkdl.EndpointMetadata{
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
}, nil),
}
prompts := []string{"prompt1"}
generatedTokenCounts := []int{1}
prefixCacheScores := []float64{0.0}

results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)

assert.Error(t, err)
assert.Nil(t, results)
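Assuming a standard Go module layout, the updated tests run with the usual toolchain, e.g. go test -run 'TestBulkPredictWithMetrics' ./pkg/epp/...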
@@ -66,7 +66,7 @@ func (s *PredictedLatency) generatePredictions(ctx context.Context, request *sch
}

// Bulk predict
bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, s.requestBuilder, candidateEndpoints, prompts, generatedTokenCounts, prefixCacheScores)
if err != nil {
logger.V(logutil.DEBUG).Error(err, "Bulk prediction failed")
return nil, err
@@ -98,6 +98,57 @@ func (s *PredictedLatency) deletePredictedLatencyContextForRequest(request *sche
s.sloContextStore.Delete(id)
}

// GetSchedulingResultForRequest returns the scheduling result for a request.
// This is exposed to allow wrapper implementations (e.g., P/D-aware routers)
// to access scheduling information for custom hook logic.
func (s *PredictedLatency) GetSchedulingResultForRequest(request *schedulingtypes.LLMRequest) (*schedulingtypes.SchedulingResult, error) {
predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
if err != nil {
return nil, err
}
return predictedLatencyCtx.schedulingResult, nil
}

// GetLastSeenMetricsForRequest returns the last seen metrics for all profiles in a request.
// This is exposed to allow wrapper implementations to access metrics for custom training logic.
func (s *PredictedLatency) GetLastSeenMetricsForRequest(request *schedulingtypes.LLMRequest) (map[string]*fwkdl.Metrics, error) {
predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
if err != nil {
return nil, err
}
return predictedLatencyCtx.lastSeenMetrics, nil
}

// GetPrefixCacheScoresForRequest returns the prefix cache scores for all pods in a request.
func (s *PredictedLatency) GetPrefixCacheScoresForRequest(request *schedulingtypes.LLMRequest) (map[string]float64, error) {
predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
if err != nil {
return nil, err
}
return predictedLatencyCtx.prefixCacheScoresForEndpoints, nil
}

// GetRequestPrompt returns the prompt for a request.
func (s *PredictedLatency) GetRequestPrompt(request *schedulingtypes.LLMRequest) (string, error) {
predictedLatencyCtx, err := s.getPredictedLatencyContextForRequest(request)
if err != nil {
return "", err
}
return predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt, nil
}

// GetRequestBuilder returns the PredictionRequestBuilder used by this router.
// This allows wrappers to use the same builder for consistency.
func (s *PredictedLatency) GetRequestBuilder() PredictionRequestBuilder {
return s.requestBuilder
}

// GetLatencyPredictor returns the latency predictor client.
// This allows wrappers to record training data using the same predictor.
func (s *PredictedLatency) GetLatencyPredictor() interface{} {
return s.latencypredictor
}
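Taken together, these getters expose everything the hooks below use internally, which is what makes external wrappers viable. A hypothetical wrapper using them for custom training logic — pdWrapper and customTrainingInputs are illustrative names, not part of this change:

// Hypothetical P/D-aware wrapper; sketch only.
type pdWrapper struct {
	inner *PredictedLatency
}

func (w *pdWrapper) customTrainingInputs(req *schedulingtypes.LLMRequest) error {
	prompt, err := w.inner.GetRequestPrompt(req)
	if err != nil {
		return err
	}
	metricsByProfile, err := w.inner.GetLastSeenMetricsForRequest(req)
	if err != nil {
		return err
	}
	prefixScores, err := w.inner.GetPrefixCacheScoresForRequest(req)
	if err != nil {
		return err
	}
	// Reuse the router's builder so custom entries stay feature-compatible
	// with the ones recorded by the hooks below.
	builder := w.inner.GetRequestBuilder()
	_, _, _, _ = prompt, metricsByProfile, prefixScores, builder
	return nil
}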

// --- RequestControl Hooks ---

func (t *PredictedLatency) PreRequest(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult) {
@@ -186,10 +237,16 @@ func (t *PredictedLatency) ResponseStreaming(ctx context.Context, request *sched
return
}

// Create a schedulingtypes.Endpoint wrapper for the metadata
podWrapper := fwkdl.NewEndpoint(
targetMetadata,
predictedLatencyCtx.lastSeenMetrics[predictedLatencyCtx.schedulingResult.PrimaryProfileName],
)

if predictedLatencyCtx.ttft == 0 {
processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, t.requestBuilder, predictedLatencyCtx, podWrapper, now, t.config.SamplingMean, t.config.MaxSampledTokens)
} else {
processTokenForLatencyPrediction(ctx, t.latencypredictor, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
processTokenForLatencyPrediction(ctx, t.latencypredictor, t.requestBuilder, predictedLatencyCtx, podWrapper, now, t.config.SamplingMean, t.config.MaxSampledTokens)
}

}
@@ -213,7 +270,12 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu
}
now := time.Now()
if !t.config.StreamingMode {
processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
// Create a schedulingtypes.Endpoint wrapper for non-streaming responses
podWrapper := fwkdl.NewEndpoint(
targetMetadata,
predictedLatencyCtx.lastSeenMetrics[predictedLatencyCtx.schedulingResult.PrimaryProfileName],
)
processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, t.requestBuilder, predictedLatencyCtx, podWrapper, now, t.config.SamplingMean, t.config.MaxSampledTokens)
}

if predictedLatencyCtx.ttft > 0 {
@@ -62,6 +62,7 @@ func createTestRouter() *PredictedLatency {
sloContextStore: sync.Map{},
runningRequestLists: make(map[types.NamespacedName]*requestPriorityQueue),
latencypredictor: nil,
requestBuilder: &DefaultPredictionRequestBuilder{},
config: DefaultConfig,
}
}