diff --git a/pkg/plugins/scorer/prefix_cache_tracking.go b/pkg/plugins/scorer/prefix_cache_tracking.go
index e6ce0a40..bea0dc14 100644
--- a/pkg/plugins/scorer/prefix_cache_tracking.go
+++ b/pkg/plugins/scorer/prefix_cache_tracking.go
@@ -16,6 +16,11 @@ import (
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
+// PodScorer is the KV-cache score lookup the scorer depends on; New supplies the kvcache.Indexer, tests can supply a stub.
+type PodScorer interface {
+	GetPodScores(ctx context.Context, prompt, model string, pods []string) (map[string]int, error)
+}
+
 // PrefixCacheTrackingConfig holds the configuration for the
 // PrefixCacheTrackingScorer.
 type PrefixCacheTrackingConfig struct {
@@ -84,9 +89,19 @@ func New(ctx context.Context, config PrefixCacheTrackingConfig) (*PrefixCacheTra
 	return &PrefixCacheTrackingScorer{
 		typedName:      plugins.TypedName{Type: prefix.PrefixCachePluginType},
 		kvCacheIndexer: kvCacheIndexer,
+		podScorer:      kvCacheIndexer,
 	}, nil
 }
 
+// NewWithPodScorer returns a PrefixCacheTrackingScorer that uses the given PodScorer; kvCacheIndexer is left nil.
+// It is mainly intended for tests, which can inject a mock PodScorer instead of the indexer wired up by New.
+func NewWithPodScorer(podScorer PodScorer) *PrefixCacheTrackingScorer {
+	return &PrefixCacheTrackingScorer{
+		typedName: plugins.TypedName{Type: prefix.PrefixCachePluginType},
+		podScorer: podScorer,
+	}
+}
+
 // PrefixCacheTrackingScorer implements the framework.Scorer interface.
 // The scorer implements the `cache_tracking` mode of the prefix cache plugin.
 // It uses the `kvcache.Indexer` to score pods based on the KV-cache index
@@ -95,6 +110,7 @@ func New(ctx context.Context, config PrefixCacheTrackingConfig) (*PrefixCacheTra
 type PrefixCacheTrackingScorer struct {
 	typedName      plugins.TypedName
 	kvCacheIndexer *kvcache.Indexer
+	podScorer      PodScorer
 }
 
 // TypedName returns the typed name of the plugin.
@@ -114,13 +130,13 @@ func (s *PrefixCacheTrackingScorer) Score(ctx context.Context, _ *types.CycleSta loggerDebug := log.FromContext(ctx).WithName(s.typedName.String()).V(logutil.DEBUG) if request == nil { loggerDebug.Info("Request is nil, skipping scoring") - return nil + return make(map[types.Pod]float64) } - scores, err := s.kvCacheIndexer.GetPodScores(ctx, request.Prompt, request.TargetModel, nil) + scores, err := s.podScorer.GetPodScores(ctx, request.Prompt, request.TargetModel, nil) if err != nil { loggerDebug.Error(err, "Failed to get pod scores") - return nil + return make(map[types.Pod]float64) } loggerDebug.Info("Got pod scores", "scores", scores) diff --git a/pkg/plugins/scorer/prefix_cache_tracking_test.go b/pkg/plugins/scorer/prefix_cache_tracking_test.go new file mode 100644 index 00000000..68561577 --- /dev/null +++ b/pkg/plugins/scorer/prefix_cache_tracking_test.go @@ -0,0 +1,136 @@ +package scorer_test + +import ( + "context" + "errors" + + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" + "github.com/stretchr/testify/require" + k8stypes "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +// mockPodScorer is a mock implementation of the scorer.PodScorer interface for testing. 
+type mockPodScorer struct {
+	scores map[string]int
+	err    error
+}
+
+// GetPodScores returns the configured error when one is set, otherwise the configured scores; all arguments are ignored.
+func (m *mockPodScorer) GetPodScores(_ context.Context, _, _ string, _ []string) (map[string]int, error) {
+	if m.err != nil {
+		return nil, m.err
+	}
+	return m.scores, nil
+}
+
+// TestPrefixCacheTracking_Score runs Score against a mocked PodScorer: expected scores, a nil request, a scorer error.
+func TestPrefixCacheTracking_Score(t *testing.T) {
+	testcases := []struct {
+		name                string
+		pods                []types.Pod
+		request             *types.LLMRequest
+		mockScores          map[string]int
+		mockError           error
+		wantScoresByAddress map[string]float64 // Use address as key instead of Pod objects
+	}{
+		{
+			name: "test normalized scores",
+			pods: []types.Pod{
+				&types.PodMetrics{
+					Pod: &backend.Pod{
+						NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
+						Address:        "10.0.0.1:8080",
+					},
+					MetricsState: &backendmetrics.MetricsState{
+						WaitingQueueSize: 0,
+					},
+				},
+				&types.PodMetrics{
+					Pod: &backend.Pod{
+						NamespacedName: k8stypes.NamespacedName{Name: "pod-b"},
+						Address:        "10.0.0.2:8080",
+					},
+					MetricsState: &backendmetrics.MetricsState{
+						WaitingQueueSize: 1,
+					},
+				},
+				&types.PodMetrics{
+					Pod: &backend.Pod{
+						NamespacedName: k8stypes.NamespacedName{Name: "pod-c"},
+						Address:        "10.0.0.3:8080",
+					},
+					MetricsState: &backendmetrics.MetricsState{
+						WaitingQueueSize: 2,
+					},
+				},
+			},
+			request: &types.LLMRequest{
+				TargetModel: "gpt-4",
+				Prompt:      "what is meaning of life?",
+			},
+			mockScores: map[string]int{
+				"10.0.0.1:8080": 10,
+				"10.0.0.2:8080": 20,
+				"10.0.0.3:8080": 30,
+			},
+			wantScoresByAddress: map[string]float64{
+				"10.0.0.1:8080": 0.0, // (10-10)/(30-10) = 0.0
+				"10.0.0.2:8080": 0.5, // (20-10)/(30-10) = 0.5
+				"10.0.0.3:8080": 1.0, // (30-10)/(30-10) = 1.0
+			},
+		},
+		{
+			name: "test nil request",
+			pods: []types.Pod{
+				&types.PodMetrics{
+					Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
+				},
+			},
+			request:             nil,
+			wantScoresByAddress: make(map[string]float64), // empty map instead of nil
+		},
+		{
+			name: "test pod scorer error",
+			pods: []types.Pod{
+				&types.PodMetrics{
+					Pod: &backend.Pod{
+						NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
+						Address:        "10.0.0.1:8080",
+					},
+				},
+			},
+			request: &types.LLMRequest{
+				TargetModel: "gpt-4",
+				Prompt:      "test prompt",
+			},
+			mockError:           errors.New("test error"),
+			wantScoresByAddress: make(map[string]float64), // empty map instead of nil
+		},
+	}
+
+	for _, tt := range testcases {
+		t.Run(tt.name, func(t *testing.T) {
+			prefixCacheScorer := scorer.NewWithPodScorer(&mockPodScorer{scores: tt.mockScores, err: tt.mockError})
+			require.NotNil(t, prefixCacheScorer)
+			got := prefixCacheScorer.Score(context.Background(), nil, tt.request, tt.pods)
+			// gotByAddress below is never nil, so a nil result would compare equal to an empty want; check got itself.
+			require.NotNil(t, got, "Score must return a non-nil map")
+			// Convert the result to address-based map for easier comparison
+			gotByAddress := make(map[string]float64)
+			for pod, score := range got {
+				if podMetrics, ok := pod.(*types.PodMetrics); ok && podMetrics.GetPod() != nil {
+					gotByAddress[podMetrics.GetPod().Address] = score
+				}
+			}
+			if diff := cmp.Diff(tt.wantScoresByAddress, gotByAddress); diff != "" {
+				t.Errorf("Unexpected output (-want +got): %v", diff)
+			}
+		})
+	}
+}