Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func (r *Runner) Run(ctx context.Context) error {
runtime.SetBlockProfileRate(1)
}

err = r.parsePluginsConfiguration(ctx)
err = r.parsePluginsConfiguration(ctx, datastore)
if err != nil {
setupLog.Error(err, "Failed to parse plugins configuration")
return err
Expand Down Expand Up @@ -321,7 +321,7 @@ func (r *Runner) registerInTreePlugins() {
plugins.Register(testfilter.HeaderBasedTestingFilterType, testfilter.HeaderBasedTestingFilterFactory)
}

func (r *Runner) parsePluginsConfiguration(ctx context.Context) error {
func (r *Runner) parsePluginsConfiguration(ctx context.Context, datastore datastore.Datastore) error {
if *configText == "" && *configFile == "" {
return nil // configuring through code, not through file
}
Expand All @@ -340,7 +340,7 @@ func (r *Runner) parsePluginsConfiguration(ctx context.Context) error {
}

r.registerInTreePlugins()
handle := plugins.NewEppHandle(ctx)
handle := plugins.NewEppHandle(ctx, datastore)
config, err := loader.LoadConfig(configBytes, handle, logger)
if err != nil {
return fmt.Errorf("failed to load the configuration - %w", err)
Expand Down
8 changes: 4 additions & 4 deletions pkg/epp/config/loader/configloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ func TestLoadRawConfigurationWithDefaults(t *testing.T) {
}

for _, test := range tests {
handle := utils.NewTestHandle(context.Background())
handle := utils.NewTestHandle(context.Background(), nil)

configBytes := []byte(test.configText)
if test.configFile != "" {
Expand Down Expand Up @@ -346,7 +346,7 @@ func checkError(t *testing.T, function string, test testStruct, err error) {
}

func TestInstantiatePlugins(t *testing.T) {
handle := utils.NewTestHandle(context.Background())
handle := utils.NewTestHandle(context.Background(), nil)
_, err := LoadConfig([]byte(successConfigText), handle, logging.NewTestLogger())
if err != nil {
t.Fatalf("LoadConfig returned unexpected error - %v", err)
Expand All @@ -360,7 +360,7 @@ func TestInstantiatePlugins(t *testing.T) {
t.Fatalf("loaded plugins returned test1 has the wrong type %#v", t1)
}

handle = utils.NewTestHandle(context.Background())
handle = utils.NewTestHandle(context.Background(), nil)
_, err = LoadConfig([]byte(errorBadPluginReferenceParametersText), handle, logging.NewTestLogger())
if err == nil {
t.Fatalf("LoadConfig did not return error as expected ")
Expand Down Expand Up @@ -426,7 +426,7 @@ func TestLoadConfig(t *testing.T) {

logger := logging.NewTestLogger()
for _, test := range tests {
handle := utils.NewTestHandle(context.Background())
handle := utils.NewTestHandle(context.Background(), nil)
_, err := LoadConfig([]byte(test.configText), handle, logger)
if err != nil {
if !test.wantErr {
Expand Down
17 changes: 16 additions & 1 deletion pkg/epp/plugins/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ package plugins
import (
"context"
"fmt"

v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
)

// Handle provides plugins a set of standard data and tools to work with
Expand All @@ -27,6 +31,8 @@ type Handle interface {
Context() context.Context

HandlePlugins

Datastore
}

// HandlePlugins defines a set of APIs to work with instantiated plugins
Expand All @@ -44,10 +50,18 @@ type HandlePlugins interface {
GetAllPluginsWithNames() map[string]Plugin
}

// Datastore defines the interface required by the Director.
type Datastore interface {
PoolGet() (*v1.InferencePool, error)
ObjectiveGet(modelName string) *v1alpha2.InferenceObjective
PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
}

// eppHandle is an implementation of the interface plugins.Handle
type eppHandle struct {
ctx context.Context
HandlePlugins
Datastore
}

// Context returns a context the plugins can use, if they need one
Expand Down Expand Up @@ -84,12 +98,13 @@ func (h *eppHandlePlugins) GetAllPluginsWithNames() map[string]Plugin {
return h.plugins
}

func NewEppHandle(ctx context.Context) Handle {
func NewEppHandle(ctx context.Context, datastore Datastore) Handle {
return &eppHandle{
ctx: ctx,
HandlePlugins: &eppHandlePlugins{
plugins: map[string]Plugin{},
},
Datastore: datastore,
}
}

Expand Down
32 changes: 26 additions & 6 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func NewDirectorWithConfig(datastore Datastore, scheduler Scheduler, saturationD
datastore: datastore,
scheduler: scheduler,
saturationDetector: saturationDetector,
preSchedulePlugins: config.preSchedulePlugins,
preRequestPlugins: config.preRequestPlugins,
postResponsePlugins: config.postResponsePlugins,
defaultPriority: 0, // define default priority explicitly
Expand All @@ -76,6 +77,7 @@ type Director struct {
datastore Datastore
scheduler Scheduler
saturationDetector SaturationDetector
preSchedulePlugins []PreSchedule
preRequestPlugins []PreRequest
postResponsePlugins []PostResponse
// we just need a pointer to an int variable since priority is a pointer in InferenceObjective
Expand Down Expand Up @@ -135,7 +137,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
logger.V(logutil.DEBUG).Info("LLM request assembled")

// Get candidate pods for scheduling
candidatePods := d.getCandidatePodsForScheduling(ctx, reqCtx.Request.Metadata)
candidatePods := d.runPreSchedulePlugins(ctx, reqCtx.Request)
if len(candidatePods) == 0 {
return reqCtx, errutil.Error{Code: errutil.ServiceUnavailable, Msg: "failed to find candidate pods for serving the request"}
}
Expand All @@ -161,24 +163,24 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
return reqCtx, nil
}

// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
// GetCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
// according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies
// a subset of endpoints, only these endpoints will be considered as candidates for the scheduler.
// Snapshot pod metrics from the datastore to:
// 1. Reduce concurrent access to the datastore.
// 2. Ensure consistent data during the scheduling operation of a request between all scheduling cycles.
func (d *Director) getCandidatePodsForScheduling(ctx context.Context, requestMetadata map[string]any) []backendmetrics.PodMetrics {
func GetCandidatePodsForScheduling(ctx context.Context, datastore Datastore, requestMetadata map[string]any) []backendmetrics.PodMetrics {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)

subsetMap, found := requestMetadata[metadata.SubsetFilterNamespace].(map[string]any)
if !found {
return d.datastore.PodList(backendmetrics.AllPodsPredicate)
return datastore.PodList(backendmetrics.AllPodsPredicate)
}

// Check if endpoint key is present in the subset map and ensure there is at least one value
endpointSubsetList, found := subsetMap[metadata.SubsetFilterKey].([]any)
if !found {
return d.datastore.PodList(backendmetrics.AllPodsPredicate)
return datastore.PodList(backendmetrics.AllPodsPredicate)
} else if len(endpointSubsetList) == 0 {
loggerTrace.Info("found empty subset filter in request metadata, filtering all pods")
return []backendmetrics.PodMetrics{}
Expand All @@ -194,7 +196,7 @@ func (d *Director) getCandidatePodsForScheduling(ctx context.Context, requestMet
}

podTotalCount := 0
podFilteredList := d.datastore.PodList(func(pm backendmetrics.PodMetrics) bool {
podFilteredList := datastore.PodList(func(pm backendmetrics.PodMetrics) bool {
podTotalCount++
if _, found := endpoints[pm.GetPod().Address]; found {
return true
Expand Down Expand Up @@ -301,6 +303,24 @@ func (d *Director) GetRandomPod() *backend.Pod {
return pod.GetPod()
}

func (d *Director) runPreSchedulePlugins(ctx context.Context, request *handlers.Request) []backendmetrics.PodMetrics {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
candidates := []backendmetrics.PodMetrics{}
if len(d.preSchedulePlugins) > 0 {
for _, plugin := range d.preSchedulePlugins {
loggerDebug.Info("Running pre-schedule plugin", "plugin", plugin.TypedName())
before := time.Now()
candidates = append(candidates, plugin.GetCandidatePods(ctx, request)...)
metrics.RecordPluginProcessingLatency(PreScheduleExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
loggerDebug.Info("Completed running pre-schedule plugin successfully", "plugin", plugin.TypedName())
}
} else {
// There are no PreSchedule plugins, fallback...
candidates = GetCandidatePodsForScheduling(ctx, d.datastore, request.Metadata)
}
return candidates
}

func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest,
schedulingResult *schedulingtypes.SchedulingResult, targetPort int) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
Expand Down
4 changes: 1 addition & 3 deletions pkg/epp/requestcontrol/director_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,9 +529,7 @@ func TestGetCandidatePodsForScheduling(t *testing.T) {
ds := &mockDatastore{pods: testInput}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
director := NewDirectorWithConfig(ds, &mockScheduler{}, &mockSaturationDetector{}, NewConfig())

got := director.getCandidatePodsForScheduling(context.Background(), test.metadata)
got := GetCandidatePodsForScheduling(context.Background(), ds, test.metadata)

diff := cmp.Diff(test.output, got, cmpopts.SortSlices(func(a, b backendmetrics.PodMetrics) bool {
return a.GetPod().NamespacedName.String() < b.GetPod().NamespacedName.String()
Expand Down
10 changes: 10 additions & 0 deletions pkg/epp/requestcontrol/plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,25 @@ import (
"context"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

const (
PreScheduleExtensionPoint = "PreSchedule"
PreRequestExtensionPoint = "PreRequest"
PostResponseExtensionPoint = "PostResponse"
)

// PreSchedule is called by the director before sending the request to the scheduler.
// It gets the set of candidate pods to be filtered and scored.
type PreSchedule interface {
plugins.Plugin
GetCandidatePods(ctx context.Context, request *handlers.Request) []backendmetrics.PodMetrics
}

// PreRequest is called by the director after a getting result from scheduling layer and
// before a request is sent to the selected model server.
type PreRequest interface {
Expand Down
12 changes: 12 additions & 0 deletions pkg/epp/requestcontrol/request_control_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,26 @@ import (
// NewConfig creates a new Config object and returns its pointer.
func NewConfig() *Config {
return &Config{
preSchedulePlugins: []PreSchedule{},
preRequestPlugins: []PreRequest{},
postResponsePlugins: []PostResponse{},
}
}

// Config provides a configuration for the requestcontrol plugins.
type Config struct {
preSchedulePlugins []PreSchedule
preRequestPlugins []PreRequest
postResponsePlugins []PostResponse
}

// WithPreSchedulePlugins sets the given plugins as the PreSchedule plugins.
// If the Config has PreSchedule plugins already, this call replaces the existing plugins with the given ones.
func (c *Config) WithPreSchedulePlugins(plugins ...PreSchedule) *Config {
c.preSchedulePlugins = plugins
return c
}

// WithPreRequestPlugins sets the given plugins as the PreRequest plugins.
// If the Config has PreRequest plugins already, this call replaces the existing plugins with the given ones.
func (c *Config) WithPreRequestPlugins(plugins ...PreRequest) *Config {
Expand All @@ -50,6 +59,9 @@ func (c *Config) WithPostResponsePlugins(plugins ...PostResponse) *Config {

func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) {
for _, plugin := range pluginObjects {
if preSchedulePlugin, ok := plugin.(PreSchedule); ok {
c.preSchedulePlugins = append(c.preSchedulePlugins, preSchedulePlugin)
}
if preRequestPlugin, ok := plugin.(PreRequest); ok {
c.preRequestPlugins = append(c.preRequestPlugins, preRequestPlugin)
}
Expand Down
4 changes: 3 additions & 1 deletion test/utils/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
type testHandle struct {
ctx context.Context
plugins.HandlePlugins
plugins.Datastore
}

// Context returns a context the plugins can use, if they need one
Expand Down Expand Up @@ -57,11 +58,12 @@ func (h *testHandlePlugins) GetAllPluginsWithNames() map[string]plugins.Plugin {
return h.plugins
}

func NewTestHandle(ctx context.Context) plugins.Handle {
func NewTestHandle(ctx context.Context, datastore plugins.Datastore) plugins.Handle {
return &testHandle{
ctx: ctx,
HandlePlugins: &testHandlePlugins{
plugins: map[string]plugins.Plugin{},
},
Datastore: datastore,
}
}