diff --git a/.gitignore b/.gitignore index cc9485b9..dbb7ca44 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,5 @@ go.work.sum # IDE files .idea .vscode + +vendor diff --git a/pkg/plugins/filter/by_label.go b/pkg/plugins/filter/by_label.go index b053b3c2..37bbbe34 100644 --- a/pkg/plugins/filter/by_label.go +++ b/pkg/plugins/filter/by_label.go @@ -11,11 +11,11 @@ import ( ) const ( - // ByLabelFilterType is the type of the ByLabel filter - ByLabelFilterType = "by-label" + // ByLabelType is the type of the ByLabel filter + ByLabelType = "by-label" ) -type byLabelFilterParameters struct { +type byLabelParameters struct { Label string `json:"label"` ValidValues []string `json:"validValues"` AllowsNoLabel bool `json:"allowsNoLabel"` @@ -23,12 +23,12 @@ type byLabelFilterParameters struct { var _ framework.Filter = &ByLabel{} // validate interface conformance -// ByLabelFilterFactory defines the factory function for the ByLabelFilter -func ByLabelFilterFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { - parameters := byLabelFilterParameters{} +// ByLabelFactory defines the factory function for the ByLabel filter. +func ByLabelFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + parameters := byLabelParameters{} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { - return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelFilterType, err) + return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelType, err) } } return NewByLabel(name, parameters.Label, parameters.AllowsNoLabel, parameters.ValidValues...), nil @@ -47,7 +47,7 @@ func NewByLabel(name string, labelName string, allowsNoLabel bool, validValues . } return &ByLabel{ - typedName: plugins.TypedName{Type: ByLabelFilterType, Name: name}, + typedName: plugins.TypedName{Type: ByLabelType, Name: name}, labelName: labelName, allowsNoLabel: allowsNoLabel, validValues: validValuesMap, diff --git a/pkg/plugins/filter/by_label_selector.go b/pkg/plugins/filter/by_label_selector.go index 15b13520..98b95d41 100644 --- a/pkg/plugins/filter/by_label_selector.go +++ b/pkg/plugins/filter/by_label_selector.go @@ -14,8 +14,8 @@ import ( ) const ( - // ByLabelSelectorFilterType is the type of the ByLabelsFilter - ByLabelSelectorFilterType = "by-label-selector" + // ByLabelSelectorType is the type of the ByLabelSelector filter + ByLabelSelectorType = "by-label-selector" ) // compile-time type assertion @@ -26,7 +26,7 @@ func ByLabelSelectorFactory(name string, rawParameters json.RawMessage, _ plugin parameters := metav1.LabelSelector{} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { - return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelSelectorFilterType, err) + return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelSelectorType, err) } } return NewByLabelSelector(name, ¶meters) @@ -44,7 +44,7 @@ func NewByLabelSelector(name string, selector *metav1.LabelSelector) (*ByLabelSe } return &ByLabelSelector{ - typedName: plugins.TypedName{Type: ByLabelSelectorFilterType, Name: name}, + typedName: plugins.TypedName{Type: ByLabelSelectorType, Name: name}, selector: labelSelector, }, nil } diff --git a/pkg/plugins/filter/pd_role.go b/pkg/plugins/filter/pd_role.go new file mode 100644 index 00000000..cc4cf744 --- /dev/null +++ b/pkg/plugins/filter/pd_role.go @@ -0,0 +1,43 @@ +package filter + +import ( + "encoding/json" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" +) + +const ( + // RoleLabel name + RoleLabel = "llm-d.ai/role" + // RolePrefill set for designated prefill workers + RolePrefill = "prefill" + // RoleDecode set for designated decode workers + RoleDecode = "decode" + // RoleBoth set for workers that can act as both prefill and decode + RoleBoth = "both" + + // DecodeRoleType is the type of the DecodeFilter + DecodeRoleType = "decode-filter" + // PrefillRoleType is the type of the PrefillFilter + PrefillRoleType = "prefill-filter" +) + +// PrefillRoleFactory defines the factory function for the Prefill filter. +func PrefillRoleFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + return NewPrefillRole().WithName(name), nil +} + +// NewPrefillRole creates and returns an instance of the Filter configured for prefill role +func NewPrefillRole() *ByLabel { + return NewByLabel(PrefillRoleType, RoleLabel, false, RolePrefill) +} + +// DecodeRoleFactory defines the factory function for the Decode filter. +func DecodeRoleFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + return NewDecodeRole().WithName(name), nil +} + +// NewDecodeRole creates and returns an instance of the Filter configured for decode role +func NewDecodeRole() *ByLabel { + return NewByLabel(DecodeRoleType, RoleLabel, true, RoleDecode, RoleBoth) +} diff --git a/pkg/plugins/filter/pd_role_filter.go b/pkg/plugins/filter/pd_role_filter.go deleted file mode 100644 index 3ac50737..00000000 --- a/pkg/plugins/filter/pd_role_filter.go +++ /dev/null @@ -1,43 +0,0 @@ -package filter - -import ( - "encoding/json" - - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" -) - -const ( - // RoleLabel name - RoleLabel = "llm-d.ai/role" - // RolePrefill set for designated prefill workers - RolePrefill = "prefill" - // RoleDecode set for designated decode workers - RoleDecode = "decode" - // RoleBoth set for workers that can act as both prefill and decode - RoleBoth = "both" - - // DecodeFilterType is the type of the DecodeFilter - DecodeFilterType = "decode-filter" - // PrefillFilterType is the type of the PrefillFilter - PrefillFilterType = "prefill-filter" -) - -// PrefillFilterFactory defines the factory function for the PrefillFilter -func PrefillFilterFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { - return NewPrefillFilter().WithName(name), nil -} - -// NewPrefillFilter creates and returns an instance of the Filter configured for prefill role -func NewPrefillFilter() *ByLabel { - return NewByLabel(PrefillFilterType, RoleLabel, false, RolePrefill) -} - -// DecodeFilterFactory defines the factory function for the DecodeFilter -func DecodeFilterFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { - return NewDecodeFilter().WithName(name), nil -} - -// NewDecodeFilter creates and returns an instance of the Filter configured for decode role -func NewDecodeFilter() *ByLabel { - return NewByLabel(DecodeFilterType, RoleLabel, true, RoleDecode, RoleBoth) -} diff --git a/pkg/plugins/register.go b/pkg/plugins/register.go index bb635726..774e9b27 100644 --- a/pkg/plugins/register.go +++ b/pkg/plugins/register.go @@ -12,14 +12,14 @@ import ( // RegisterAllPlugins registers the factory functions of all plugins in this repository. func RegisterAllPlugins() { - plugins.Register(filter.ByLabelFilterType, filter.ByLabelFilterFactory) - plugins.Register(filter.ByLabelSelectorFilterType, filter.ByLabelSelectorFactory) - plugins.Register(filter.DecodeFilterType, filter.DecodeFilterFactory) - plugins.Register(filter.PrefillFilterType, filter.PrefillFilterFactory) + plugins.Register(filter.ByLabelType, filter.ByLabelFactory) + plugins.Register(filter.ByLabelSelectorType, filter.ByLabelSelectorFactory) + plugins.Register(filter.DecodeRoleType, filter.DecodeRoleFactory) + plugins.Register(filter.PrefillRoleType, filter.PrefillRoleFactory) plugins.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory) plugins.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory) plugins.Register(prefix.PrefixCachePluginType, scorer.PrefixCachePluginFactory) - plugins.Register(scorer.LoadAwareScorerType, scorer.LoadAwareScorerFactory) - plugins.Register(scorer.SessionAffinityScorerType, scorer.SessionAffinityScorerFactory) - plugins.Register(scorer.ActiveRequestScorerType, scorer.ActiveRequestScorerFactory) + plugins.Register(scorer.LoadAwareType, scorer.LoadAwareFactory) + plugins.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory) + plugins.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory) } diff --git a/pkg/plugins/scorer/active_request.go b/pkg/plugins/scorer/active_request.go index 11b2a9ad..f4018d96 100644 --- a/pkg/plugins/scorer/active_request.go +++ b/pkg/plugins/scorer/active_request.go @@ -18,17 +18,17 @@ import ( ) const ( - // ActiveRequestScorerType is the type of the ActiveRequestScorer - ActiveRequestScorerType = "active-request-scorer" + // ActiveRequestType is the type of the ActiveRequest scorer. + ActiveRequestType = "active-request-scorer" // defaultRequestTimeout defines the default timeout for open requests to be // considered stale and removed from the cache. defaultRequestTimeout = 2 * time.Minute ) -// ActiveRequestScorerParameters defines the parameters for the -// ActiveRequestScorer. -type ActiveRequestScorerParameters struct { +// ActiveRequestParameters defines the parameters for the +// ActiveRequest. +type ActiveRequestParameters struct { // RequestTimeout defines the timeout for requests in seconds. // Once the request is "in-flight" for this duration, it is considered to // be timed out and dropped. @@ -48,22 +48,22 @@ func (r *requestEntry) String() string { } // compile-time type assertion -var _ framework.Scorer = &ActiveRequestScorer{} +var _ framework.Scorer = &ActiveRequest{} -// ActiveRequestScorerFactory defines the factory function for the ActiveRequestScorer. -func ActiveRequestScorerFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { - parameters := ActiveRequestScorerParameters{} +// ActiveRequestFactory defines the factory function for the ActiveRequest scorer. +func ActiveRequestFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { + parameters := ActiveRequestParameters{} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { - return nil, fmt.Errorf("failed to parse the parameters of the '%s' scorer - %w", ActiveRequestScorerType, err) + return nil, fmt.Errorf("failed to parse the parameters of the '%s' scorer - %w", ActiveRequestType, err) } } - return NewActiveRequestScorer(handle.Context(), ¶meters).WithName(name), nil + return NewActiveRequest(handle.Context(), ¶meters).WithName(name), nil } -// NewActiveRequestScorer creates a new ActiveRequestScorer scorer. -func NewActiveRequestScorer(ctx context.Context, params *ActiveRequestScorerParameters) *ActiveRequestScorer { +// NewActiveRequest creates a new ActiveRequest scorer. +func NewActiveRequest(ctx context.Context, params *ActiveRequestParameters) *ActiveRequest { requestTimeout := defaultRequestTimeout logger := log.FromContext(ctx) @@ -83,8 +83,8 @@ func NewActiveRequestScorer(ctx context.Context, params *ActiveRequestScorerPara ttlcache.WithDisableTouchOnHit[string, *requestEntry](), ) - scorer := &ActiveRequestScorer{ - typedName: plugins.TypedName{Type: ActiveRequestScorerType}, + scorer := &ActiveRequest{ + typedName: plugins.TypedName{Type: ActiveRequestType}, requestCache: requestCache, podCounts: make(map[string]int), mutex: &sync.RWMutex{}, @@ -104,9 +104,9 @@ func NewActiveRequestScorer(ctx context.Context, params *ActiveRequestScorerPara return scorer } -// ActiveRequestScorer keeps track of individual requests being served +// ActiveRequest keeps track of individual requests being served // per pod. -type ActiveRequestScorer struct { +type ActiveRequest struct { typedName plugins.TypedName // requestCache stores individual request entries with unique composite keys (podName.requestID) @@ -118,19 +118,19 @@ type ActiveRequestScorer struct { } // TypedName returns the typed name of the plugin. -func (s *ActiveRequestScorer) TypedName() plugins.TypedName { +func (s *ActiveRequest) TypedName() plugins.TypedName { return s.typedName } // WithName sets the name of the plugin. -func (s *ActiveRequestScorer) WithName(name string) *ActiveRequestScorer { +func (s *ActiveRequest) WithName(name string) *ActiveRequest { s.typedName.Name = name return s } // Score scores the given pods based on the number of active requests // being served by each pod. The score is normalized to a range of 0-1. -func (s *ActiveRequestScorer) Score(ctx context.Context, _ *types.CycleState, _ *types.LLMRequest, +func (s *ActiveRequest) Score(ctx context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { scoredPods := make(map[string]int) maxCount := 0 @@ -164,7 +164,7 @@ func (s *ActiveRequestScorer) Score(ctx context.Context, _ *types.CycleState, _ // PreRequest is called before a request is sent to the target pod. // It creates a new request entry in the cache with its own TTL and // increments the pod count for fast lookup. -func (s *ActiveRequestScorer) PreRequest(ctx context.Context, request *types.LLMRequest, +func (s *ActiveRequest) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, _ int) { debugLogger := log.FromContext(ctx).V(logutil.DEBUG) @@ -190,9 +190,9 @@ func (s *ActiveRequestScorer) PreRequest(ctx context.Context, request *types.LLM // PostResponse is called after a response is sent to the client. // It removes the specific request entry from the cache and decrements // the pod count. -func (s *ActiveRequestScorer) PostResponse(ctx context.Context, request *types.LLMRequest, +func (s *ActiveRequest) PostResponse(ctx context.Context, request *types.LLMRequest, _ *requestcontrol.Response, targetPod *backend.Pod) { - debugLogger := log.FromContext(ctx).V(logutil.DEBUG).WithName("ActiveRequestScorer.PostResponse") + debugLogger := log.FromContext(ctx).V(logutil.DEBUG).WithName("ActiveRequest.PostResponse") if targetPod == nil { debugLogger.Info("Skipping PostResponse because targetPod is nil") return @@ -209,7 +209,7 @@ func (s *ActiveRequestScorer) PostResponse(ctx context.Context, request *types.L } // incrementPodCount increments the request count for a pod. -func (s *ActiveRequestScorer) incrementPodCount(podName string) { +func (s *ActiveRequest) incrementPodCount(podName string) { s.mutex.Lock() defer s.mutex.Unlock() @@ -218,7 +218,7 @@ func (s *ActiveRequestScorer) incrementPodCount(podName string) { // decrementPodCount decrements the request count for a pod and removes // the entry if count reaches zero. -func (s *ActiveRequestScorer) decrementPodCount(podName string) { +func (s *ActiveRequest) decrementPodCount(podName string) { s.mutex.Lock() defer s.mutex.Unlock() diff --git a/pkg/plugins/scorer/active_request_test.go b/pkg/plugins/scorer/active_request_test.go index 25f8f3c7..72ea0655 100644 --- a/pkg/plugins/scorer/active_request_test.go +++ b/pkg/plugins/scorer/active_request_test.go @@ -35,13 +35,13 @@ func TestActiveRequestScorer_Score(t *testing.T) { tests := []struct { name string - setupCache func(*ActiveRequestScorer) + setupCache func(*ActiveRequest) input []types.Pod wantScores map[types.Pod]float64 }{ { name: "no pods in cache", - setupCache: func(_ *ActiveRequestScorer) { + setupCache: func(_ *ActiveRequest) { // Cache is empty }, input: []types.Pod{podA, podB, podC}, @@ -53,7 +53,7 @@ func TestActiveRequestScorer_Score(t *testing.T) { }, { name: "all pods in cache with different request counts", - setupCache: func(s *ActiveRequestScorer) { + setupCache: func(s *ActiveRequest) { s.mutex.Lock() s.podCounts["default/pod-a"] = 3 s.podCounts["default/pod-b"] = 0 @@ -69,7 +69,7 @@ func TestActiveRequestScorer_Score(t *testing.T) { }, { name: "some pods in cache", - setupCache: func(s *ActiveRequestScorer) { + setupCache: func(s *ActiveRequest) { s.mutex.Lock() s.podCounts["default/pod-a"] = 4 s.podCounts["default/pod-c"] = 1 @@ -87,7 +87,7 @@ func TestActiveRequestScorer_Score(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - scorer := NewActiveRequestScorer(context.Background(), nil) + scorer := NewActiveRequest(context.Background(), nil) test.setupCache(scorer) got := scorer.Score(context.Background(), nil, nil, test.input) @@ -102,7 +102,7 @@ func TestActiveRequestScorer_Score(t *testing.T) { func TestActiveRequestScorer_PreRequest(t *testing.T) { ctx := context.Background() - scorer := NewActiveRequestScorer(ctx, nil) + scorer := NewActiveRequest(ctx, nil) podA := &types.PodMetrics{ Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, @@ -171,7 +171,7 @@ func TestActiveRequestScorer_PreRequest(t *testing.T) { func TestActiveRequestScorer_PostResponse(t *testing.T) { ctx := context.Background() - scorer := NewActiveRequestScorer(ctx, nil) + scorer := NewActiveRequest(ctx, nil) request := &types.LLMRequest{ RequestId: "test-request-1", @@ -228,8 +228,8 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) { ctx := context.Background() // Use very short timeout for test - params := &ActiveRequestScorerParameters{RequestTimeout: "1s"} - scorer := NewActiveRequestScorer(ctx, params) // 1 second timeout + params := &ActiveRequestParameters{RequestTimeout: "1s"} + scorer := NewActiveRequest(ctx, params) // 1 second timeout request := &types.LLMRequest{ RequestId: "test-request-ttl", @@ -274,8 +274,8 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) { } func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) { - params := &ActiveRequestScorerParameters{RequestTimeout: "invalid"} - scorer := NewActiveRequestScorer(context.Background(), params) + params := &ActiveRequestParameters{RequestTimeout: "invalid"} + scorer := NewActiveRequest(context.Background(), params) // Should use default timeout when invalid value is provided if scorer == nil { @@ -284,16 +284,16 @@ func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) { } func TestActiveRequestScorer_TypedName(t *testing.T) { - scorer := NewActiveRequestScorer(context.Background(), nil) + scorer := NewActiveRequest(context.Background(), nil) typedName := scorer.TypedName() - if typedName.Type != ActiveRequestScorerType { - t.Errorf("Expected type %s, got %s", ActiveRequestScorerType, typedName.Type) + if typedName.Type != ActiveRequestType { + t.Errorf("Expected type %s, got %s", ActiveRequestType, typedName.Type) } } func TestActiveRequestScorer_WithName(t *testing.T) { - scorer := NewActiveRequestScorer(context.Background(), nil) + scorer := NewActiveRequest(context.Background(), nil) testName := "test-scorer" scorer = scorer.WithName(testName) diff --git a/pkg/plugins/scorer/load_aware_scorer.go b/pkg/plugins/scorer/load_aware.go similarity index 65% rename from pkg/plugins/scorer/load_aware_scorer.go rename to pkg/plugins/scorer/load_aware.go index 0e5a7656..c4f86d0b 100644 --- a/pkg/plugins/scorer/load_aware_scorer.go +++ b/pkg/plugins/scorer/load_aware.go @@ -13,58 +13,58 @@ import ( ) const ( - // LoadAwareScorerType is the type of the LoadAwareScorer - LoadAwareScorerType = "load-aware-scorer" + // LoadAwareType is the type of the LoadAware scorer + LoadAwareType = "load-aware-scorer" // QueueThresholdDefault defines the default queue threshold value QueueThresholdDefault = 128 ) -type loadAwareScorerParameters struct { +type loadAwareParameters struct { Threshold int `json:"threshold"` } // compile-time type assertion -var _ framework.Scorer = &LoadAwareScorer{} +var _ framework.Scorer = &LoadAware{} -// LoadAwareScorerFactory defines the factory function for the LoadAwareScorer -func LoadAwareScorerFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { - parameters := loadAwareScorerParameters{Threshold: QueueThresholdDefault} +// LoadAwareFactory defines the factory function for the LoadAware +func LoadAwareFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { + parameters := loadAwareParameters{Threshold: QueueThresholdDefault} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { - return nil, fmt.Errorf("failed to parse the parameters of the '%s' scorer - %w", LoadAwareScorerType, err) + return nil, fmt.Errorf("failed to parse the parameters of the '%s' scorer - %w", LoadAwareType, err) } } - return NewLoadAwareScorer(handle.Context(), parameters.Threshold).WithName(name), nil + return NewLoadAware(handle.Context(), parameters.Threshold).WithName(name), nil } -// NewLoadAwareScorer creates a new load based scorer -func NewLoadAwareScorer(ctx context.Context, queueThreshold int) *LoadAwareScorer { +// NewLoadAware creates a new load based scorer +func NewLoadAware(ctx context.Context, queueThreshold int) *LoadAware { if queueThreshold <= 0 { queueThreshold = QueueThresholdDefault log.FromContext(ctx).V(logutil.DEFAULT).Info(fmt.Sprintf("queueThreshold %d should be positive, using default queue threshold %d", queueThreshold, QueueThresholdDefault)) } - return &LoadAwareScorer{ - typedName: plugins.TypedName{Type: LoadAwareScorerType}, + return &LoadAware{ + typedName: plugins.TypedName{Type: LoadAwareType}, queueThreshold: float64(queueThreshold), } } -// LoadAwareScorer scorer that is based on load -type LoadAwareScorer struct { +// LoadAware scorer that is based on load +type LoadAware struct { typedName plugins.TypedName queueThreshold float64 } // TypedName returns the typed name of the plugin. -func (s *LoadAwareScorer) TypedName() plugins.TypedName { +func (s *LoadAware) TypedName() plugins.TypedName { return s.typedName } // WithName sets the name of the plugin. -func (s *LoadAwareScorer) WithName(name string) *LoadAwareScorer { +func (s *LoadAware) WithName(name string) *LoadAware { s.typedName.Name = name return s } @@ -76,7 +76,7 @@ func (s *LoadAwareScorer) WithName(name string) *LoadAwareScorer { // Pod with requests in the queue will get score between 0.5 and 0. // Score 0 will get pod with number of requests in the queue equal to the threshold used in load-based filter // In the future, pods with additional capacity will get score higher than 0.5 -func (s *LoadAwareScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (s *LoadAware) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { scoredPods := make(map[types.Pod]float64) for _, pod := range pods { diff --git a/pkg/plugins/scorer/load_aware_scorer_test.go b/pkg/plugins/scorer/load_aware_test.go similarity index 96% rename from pkg/plugins/scorer/load_aware_scorer_test.go rename to pkg/plugins/scorer/load_aware_test.go index 71aeec5d..e454d22f 100644 --- a/pkg/plugins/scorer/load_aware_scorer_test.go +++ b/pkg/plugins/scorer/load_aware_test.go @@ -44,7 +44,7 @@ func TestLoadBasedScorer(t *testing.T) { }{ { name: "load based scorer", - scorer: scorer.NewLoadAwareScorer(context.Background(), 10), + scorer: scorer.NewLoadAware(context.Background(), 10), req: &types.LLMRequest{ TargetModel: "critical", }, diff --git a/pkg/plugins/scorer/prefix_cache_tracking.go b/pkg/plugins/scorer/prefix_cache_tracking.go index e6ce0a40..09f64652 100644 --- a/pkg/plugins/scorer/prefix_cache_tracking.go +++ b/pkg/plugins/scorer/prefix_cache_tracking.go @@ -17,7 +17,7 @@ import ( ) // PrefixCacheTrackingConfig holds the configuration for the -// PrefixCacheTrackingScorer. +// PrefixCacheTracking. type PrefixCacheTrackingConfig struct { // IndexerConfig holds the configuration for the `kvcache.Indexer` which is // used to score pods based on the KV-cache index state. @@ -29,7 +29,7 @@ type PrefixCacheTrackingConfig struct { } // compile-time type assertion -var _ framework.Scorer = &PrefixCacheTrackingScorer{} +var _ framework.Scorer = &PrefixCacheTracking{} // PrefixCacheTrackingPluginFactory defines the factory function for creating // a new instance of the PrefixCacheTrackingPlugin. @@ -68,7 +68,7 @@ func PrefixCacheTrackingPluginFactory(name string, rawParameters json.RawMessage // // If the configuration is invalid or if the indexer fails to initialize, // an error is returned. -func New(ctx context.Context, config PrefixCacheTrackingConfig) (*PrefixCacheTrackingScorer, error) { +func New(ctx context.Context, config PrefixCacheTrackingConfig) (*PrefixCacheTracking, error) { // initialize the indexer kvCacheIndexer, err := kvcache.NewKVCacheIndexer(ctx, config.IndexerConfig) if err != nil { @@ -81,36 +81,36 @@ func New(ctx context.Context, config PrefixCacheTrackingConfig) (*PrefixCacheTra pool := kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex()) pool.Start(ctx) - return &PrefixCacheTrackingScorer{ + return &PrefixCacheTracking{ typedName: plugins.TypedName{Type: prefix.PrefixCachePluginType}, kvCacheIndexer: kvCacheIndexer, }, nil } -// PrefixCacheTrackingScorer implements the framework.Scorer interface. +// PrefixCacheTracking implements the framework.Scorer interface. // The scorer implements the `cache_tracking` mode of the prefix cache plugin. // It uses the `kvcache.Indexer` to score pods based on the KV-cache index // state, and the `kvevents.Pool` to subscribe to KV-cache events // to update the internal KV-cache index state. -type PrefixCacheTrackingScorer struct { +type PrefixCacheTracking struct { typedName plugins.TypedName kvCacheIndexer *kvcache.Indexer } // TypedName returns the typed name of the plugin. -func (s *PrefixCacheTrackingScorer) TypedName() plugins.TypedName { +func (s *PrefixCacheTracking) TypedName() plugins.TypedName { return s.typedName } // WithName sets the name of the plugin. -func (s *PrefixCacheTrackingScorer) WithName(name string) *PrefixCacheTrackingScorer { +func (s *PrefixCacheTracking) WithName(name string) *PrefixCacheTracking { s.typedName.Name = name return s } // Score scores the provided pod based on the KVCache index state. // The returned scores are normalized to a range of 0-1. -func (s *PrefixCacheTrackingScorer) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (s *PrefixCacheTracking) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { loggerDebug := log.FromContext(ctx).WithName(s.typedName.String()).V(logutil.DEBUG) if request == nil { loggerDebug.Info("Request is nil, skipping scoring") diff --git a/pkg/plugins/scorer/session_affinity.go b/pkg/plugins/scorer/session_affinity.go index ab6ead87..a20de574 100644 --- a/pkg/plugins/scorer/session_affinity.go +++ b/pkg/plugins/scorer/session_affinity.go @@ -15,8 +15,8 @@ import ( ) const ( - // SessionAffinityScorerType is the type of the SessionAffinityScorer - SessionAffinityScorerType = "session-affinity-scorer" + // SessionAffinityType is the type of the SessionAffinity scorer. + SessionAffinityType = "session-affinity-scorer" sessionTokenHeader = "x-session-token" // name of the session header in request ) @@ -25,15 +25,15 @@ const ( var _ framework.Scorer = &SessionAffinity{} var _ requestcontrol.PostResponse = &SessionAffinity{} -// SessionAffinityScorerFactory defines the factory function for SessionAffinityScorer. -func SessionAffinityScorerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +// SessionAffinityFactory defines the factory function for SessionAffinity scorer. +func SessionAffinityFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { return NewSessionAffinity().WithName(name), nil } // NewSessionAffinity returns a scorer func NewSessionAffinity() *SessionAffinity { return &SessionAffinity{ - typedName: plugins.TypedName{Type: SessionAffinityScorerType}, + typedName: plugins.TypedName{Type: SessionAffinityType}, } } diff --git a/pkg/scheduling/pd/scheduler_test.go b/pkg/scheduling/pd/scheduler_test.go index 6fe7a0bd..fe030796 100644 --- a/pkg/scheduling/pd/scheduler_test.go +++ b/pkg/scheduling/pd/scheduler_test.go @@ -198,14 +198,14 @@ func TestPDSchedule(t *testing.T) { prefixScorer := prefix.New(prefix.Config{HashBlockSize: 5, MaxPrefixBlocksToMatch: 256, LRUCapacityPerServer: 31250}) prefillSchedulerProfile := framework.NewSchedulerProfile(). - WithFilters(filter.NewPrefillFilter()). + WithFilters(filter.NewPrefillRole()). WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints)) err := prefillSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 50)) assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error") decodeSchedulerProfile := framework.NewSchedulerProfile(). - WithFilters(filter.NewDecodeFilter()). - WithScorers(framework.NewWeightedScorer(scorer.NewLoadAwareScorer(ctx, scorer.QueueThresholdDefault), 1)). + WithFilters(filter.NewDecodeRole()). + WithScorers(framework.NewWeightedScorer(scorer.NewLoadAware(ctx, scorer.QueueThresholdDefault), 1)). WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints)) err = decodeSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 0)) assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error") diff --git a/test/config/prefix_cache_mode_test.go b/test/config/prefix_cache_mode_test.go index 1b7cf682..39e5ca04 100644 --- a/test/config/prefix_cache_mode_test.go +++ b/test/config/prefix_cache_mode_test.go @@ -129,9 +129,9 @@ schedulingProfiles: } } else { - _, err := giePlugins.PluginByType[*scorer.PrefixCacheTrackingScorer](handle, test.pluginName) + _, err := giePlugins.PluginByType[*scorer.PrefixCacheTracking](handle, test.pluginName) if err != nil { - t.Fatalf("expected PrefixCacheTrackingScorer, but got error: %v", err) + t.Fatalf("expected PrefixCacheTracking, but got error: %v", err) } }