Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ go.work.sum
# IDE files
.idea
.vscode

vendor
16 changes: 8 additions & 8 deletions pkg/plugins/filter/by_label.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,24 @@ import (
)

const (
// ByLabelFilterType is the type of the ByLabel filter
ByLabelFilterType = "by-label"
// ByLabelType is the type of the ByLabel filter
ByLabelType = "by-label"
)

type byLabelFilterParameters struct {
type byLabelParameters struct {
Label string `json:"label"`
ValidValues []string `json:"validValues"`
AllowsNoLabel bool `json:"allowsNoLabel"`
}

var _ framework.Filter = &ByLabel{} // validate interface conformance

// ByLabelFilterFactory defines the factory function for the ByLabelFilter
func ByLabelFilterFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
parameters := byLabelFilterParameters{}
// ByLabelFactory defines the factory function for the ByLabel filter.
func ByLabelFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
parameters := byLabelParameters{}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelFilterType, err)
return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelType, err)
}
}
return NewByLabel(name, parameters.Label, parameters.AllowsNoLabel, parameters.ValidValues...), nil
Expand All @@ -47,7 +47,7 @@ func NewByLabel(name string, labelName string, allowsNoLabel bool, validValues .
}

return &ByLabel{
typedName: plugins.TypedName{Type: ByLabelFilterType, Name: name},
typedName: plugins.TypedName{Type: ByLabelType, Name: name},
labelName: labelName,
allowsNoLabel: allowsNoLabel,
validValues: validValuesMap,
Expand Down
8 changes: 4 additions & 4 deletions pkg/plugins/filter/by_label_selector.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (
)

const (
// ByLabelSelectorFilterType is the type of the ByLabelsFilter
ByLabelSelectorFilterType = "by-label-selector"
// ByLabelSelectorType is the type of the ByLabelSelector filter
ByLabelSelectorType = "by-label-selector"
)

// compile-time type assertion
Expand All @@ -26,7 +26,7 @@ func ByLabelSelectorFactory(name string, rawParameters json.RawMessage, _ plugin
parameters := metav1.LabelSelector{}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelSelectorFilterType, err)
return nil, fmt.Errorf("failed to parse the parameters of the '%s' filter - %w", ByLabelSelectorType, err)
}
}
return NewByLabelSelector(name, &parameters)
Expand All @@ -44,7 +44,7 @@ func NewByLabelSelector(name string, selector *metav1.LabelSelector) (*ByLabelSe
}

return &ByLabelSelector{
typedName: plugins.TypedName{Type: ByLabelSelectorFilterType, Name: name},
typedName: plugins.TypedName{Type: ByLabelSelectorType, Name: name},
selector: labelSelector,
}, nil
}
Expand Down
43 changes: 43 additions & 0 deletions pkg/plugins/filter/pd_role.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package filter

import (
"encoding/json"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
)

const (
// RoleLabel is the name of the label that carries a worker's role.
RoleLabel = "llm-d.ai/role"
// RolePrefill is the RoleLabel value set for designated prefill workers.
RolePrefill = "prefill"
// RoleDecode is the RoleLabel value set for designated decode workers.
RoleDecode = "decode"
// RoleBoth is the RoleLabel value set for workers that can act as both
// prefill and decode.
RoleBoth = "both"

// DecodeRoleType is the plugin type of the decode role filter. It is also
// the initial name NewDecodeRole passes to NewByLabel.
DecodeRoleType = "decode-filter"
// PrefillRoleType is the plugin type of the prefill role filter. It is also
// the initial name NewPrefillRole passes to NewByLabel.
PrefillRoleType = "prefill-filter"
)

// PrefillRoleFactory is the plugin factory for the prefill role filter.
//
// It builds the filter with NewPrefillRole and then assigns it the given
// instance name. The raw parameters and the plugin handle are ignored, and
// the returned error is always nil.
func PrefillRoleFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	prefillFilter := NewPrefillRole()
	return prefillFilter.WithName(name), nil
}

// NewPrefillRole returns a ByLabel filter configured for the prefill role.
//
// The filter is created with PrefillRoleType as its initial name, keyed on
// RoleLabel, with RolePrefill as the only valid label value.
func NewPrefillRole() *ByLabel {
	// allowsNoLabel is passed as false; presumably this rejects workers that
	// lack the role label — see ByLabel for the exact semantics.
	const allowsNoLabel = false
	return NewByLabel(PrefillRoleType, RoleLabel, allowsNoLabel, RolePrefill)
}

// DecodeRoleFactory is the plugin factory for the decode role filter.
//
// It builds the filter with NewDecodeRole and then assigns it the given
// instance name. The raw parameters and the plugin handle are ignored, and
// the returned error is always nil.
func DecodeRoleFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	decodeFilter := NewDecodeRole()
	return decodeFilter.WithName(name), nil
}

// NewDecodeRole returns a ByLabel filter configured for the decode role.
//
// The filter is created with DecodeRoleType as its initial name, keyed on
// RoleLabel, with RoleDecode and RoleBoth as the valid label values.
func NewDecodeRole() *ByLabel {
	// allowsNoLabel is passed as true; presumably this lets workers without
	// the role label through as decode-capable — see ByLabel for the exact
	// semantics.
	const allowsNoLabel = true
	return NewByLabel(DecodeRoleType, RoleLabel, allowsNoLabel, RoleDecode, RoleBoth)
}
43 changes: 0 additions & 43 deletions pkg/plugins/filter/pd_role_filter.go

This file was deleted.

14 changes: 7 additions & 7 deletions pkg/plugins/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ import (

// RegisterAllPlugins registers the factory functions of all plugins in this repository.
func RegisterAllPlugins() {
plugins.Register(filter.ByLabelFilterType, filter.ByLabelFilterFactory)
plugins.Register(filter.ByLabelSelectorFilterType, filter.ByLabelSelectorFactory)
plugins.Register(filter.DecodeFilterType, filter.DecodeFilterFactory)
plugins.Register(filter.PrefillFilterType, filter.PrefillFilterFactory)
plugins.Register(filter.ByLabelType, filter.ByLabelFactory)
plugins.Register(filter.ByLabelSelectorType, filter.ByLabelSelectorFactory)
plugins.Register(filter.DecodeRoleType, filter.DecodeRoleFactory)
plugins.Register(filter.PrefillRoleType, filter.PrefillRoleFactory)
plugins.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory)
plugins.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory)
plugins.Register(prefix.PrefixCachePluginType, scorer.PrefixCachePluginFactory)
plugins.Register(scorer.LoadAwareScorerType, scorer.LoadAwareScorerFactory)
plugins.Register(scorer.SessionAffinityScorerType, scorer.SessionAffinityScorerFactory)
plugins.Register(scorer.ActiveRequestScorerType, scorer.ActiveRequestScorerFactory)
plugins.Register(scorer.LoadAwareType, scorer.LoadAwareFactory)
plugins.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory)
plugins.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory)
}
50 changes: 25 additions & 25 deletions pkg/plugins/scorer/active_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ import (
)

const (
// ActiveRequestScorerType is the type of the ActiveRequestScorer
ActiveRequestScorerType = "active-request-scorer"
// ActiveRequestType is the type of the ActiveRequest scorer.
ActiveRequestType = "active-request-scorer"

// defaultRequestTimeout defines the default timeout for open requests to be
// considered stale and removed from the cache.
defaultRequestTimeout = 2 * time.Minute
)

// ActiveRequestScorerParameters defines the parameters for the
// ActiveRequestScorer.
type ActiveRequestScorerParameters struct {
// ActiveRequestParameters defines the parameters for the
// ActiveRequest scorer.
type ActiveRequestParameters struct {
// RequestTimeout defines the timeout for requests in seconds.
// Once the request is "in-flight" for this duration, it is considered to
// be timed out and dropped.
Expand All @@ -48,22 +48,22 @@ func (r *requestEntry) String() string {
}

// compile-time type assertion
var _ framework.Scorer = &ActiveRequestScorer{}
var _ framework.Scorer = &ActiveRequest{}

// ActiveRequestScorerFactory defines the factory function for the ActiveRequestScorer.
func ActiveRequestScorerFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
parameters := ActiveRequestScorerParameters{}
// ActiveRequestFactory defines the factory function for the ActiveRequest scorer.
func ActiveRequestFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
parameters := ActiveRequestParameters{}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
return nil, fmt.Errorf("failed to parse the parameters of the '%s' scorer - %w", ActiveRequestScorerType, err)
return nil, fmt.Errorf("failed to parse the parameters of the '%s' scorer - %w", ActiveRequestType, err)
}
}

return NewActiveRequestScorer(handle.Context(), &parameters).WithName(name), nil
return NewActiveRequest(handle.Context(), &parameters).WithName(name), nil
}

// NewActiveRequestScorer creates a new ActiveRequestScorer scorer.
func NewActiveRequestScorer(ctx context.Context, params *ActiveRequestScorerParameters) *ActiveRequestScorer {
// NewActiveRequest creates a new ActiveRequest scorer.
func NewActiveRequest(ctx context.Context, params *ActiveRequestParameters) *ActiveRequest {
requestTimeout := defaultRequestTimeout
logger := log.FromContext(ctx)

Expand All @@ -83,8 +83,8 @@ func NewActiveRequestScorer(ctx context.Context, params *ActiveRequestScorerPara
ttlcache.WithDisableTouchOnHit[string, *requestEntry](),
)

scorer := &ActiveRequestScorer{
typedName: plugins.TypedName{Type: ActiveRequestScorerType},
scorer := &ActiveRequest{
typedName: plugins.TypedName{Type: ActiveRequestType},
requestCache: requestCache,
podCounts: make(map[string]int),
mutex: &sync.RWMutex{},
Expand All @@ -104,9 +104,9 @@ func NewActiveRequestScorer(ctx context.Context, params *ActiveRequestScorerPara
return scorer
}

// ActiveRequestScorer keeps track of individual requests being served
// ActiveRequest keeps track of individual requests being served
// per pod.
type ActiveRequestScorer struct {
type ActiveRequest struct {
typedName plugins.TypedName

// requestCache stores individual request entries with unique composite keys (podName.requestID)
Expand All @@ -118,19 +118,19 @@ type ActiveRequestScorer struct {
}

// TypedName returns the typed name of the plugin.
func (s *ActiveRequestScorer) TypedName() plugins.TypedName {
func (s *ActiveRequest) TypedName() plugins.TypedName {
return s.typedName
}

// WithName sets the name of the plugin.
func (s *ActiveRequestScorer) WithName(name string) *ActiveRequestScorer {
func (s *ActiveRequest) WithName(name string) *ActiveRequest {
s.typedName.Name = name
return s
}

// Score scores the given pods based on the number of active requests
// being served by each pod. The score is normalized to a range of 0-1.
func (s *ActiveRequestScorer) Score(ctx context.Context, _ *types.CycleState, _ *types.LLMRequest,
func (s *ActiveRequest) Score(ctx context.Context, _ *types.CycleState, _ *types.LLMRequest,
pods []types.Pod) map[types.Pod]float64 {
scoredPods := make(map[string]int)
maxCount := 0
Expand Down Expand Up @@ -164,7 +164,7 @@ func (s *ActiveRequestScorer) Score(ctx context.Context, _ *types.CycleState, _
// PreRequest is called before a request is sent to the target pod.
// It creates a new request entry in the cache with its own TTL and
// increments the pod count for fast lookup.
func (s *ActiveRequestScorer) PreRequest(ctx context.Context, request *types.LLMRequest,
func (s *ActiveRequest) PreRequest(ctx context.Context, request *types.LLMRequest,
schedulingResult *types.SchedulingResult, _ int) {
debugLogger := log.FromContext(ctx).V(logutil.DEBUG)

Expand All @@ -190,9 +190,9 @@ func (s *ActiveRequestScorer) PreRequest(ctx context.Context, request *types.LLM
// PostResponse is called after a response is sent to the client.
// It removes the specific request entry from the cache and decrements
// the pod count.
func (s *ActiveRequestScorer) PostResponse(ctx context.Context, request *types.LLMRequest,
func (s *ActiveRequest) PostResponse(ctx context.Context, request *types.LLMRequest,
_ *requestcontrol.Response, targetPod *backend.Pod) {
debugLogger := log.FromContext(ctx).V(logutil.DEBUG).WithName("ActiveRequestScorer.PostResponse")
debugLogger := log.FromContext(ctx).V(logutil.DEBUG).WithName("ActiveRequest.PostResponse")
if targetPod == nil {
debugLogger.Info("Skipping PostResponse because targetPod is nil")
return
Expand All @@ -209,7 +209,7 @@ func (s *ActiveRequestScorer) PostResponse(ctx context.Context, request *types.L
}

// incrementPodCount increments the request count for a pod.
func (s *ActiveRequestScorer) incrementPodCount(podName string) {
func (s *ActiveRequest) incrementPodCount(podName string) {
s.mutex.Lock()
defer s.mutex.Unlock()

Expand All @@ -218,7 +218,7 @@ func (s *ActiveRequestScorer) incrementPodCount(podName string) {

// decrementPodCount decrements the request count for a pod and removes
// the entry if count reaches zero.
func (s *ActiveRequestScorer) decrementPodCount(podName string) {
func (s *ActiveRequest) decrementPodCount(podName string) {
s.mutex.Lock()
defer s.mutex.Unlock()

Expand Down
Loading