diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index 2842704bb..47b5f8418 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -244,7 +244,7 @@ func (r *Runner) Run(ctx context.Context) error { runtime.SetBlockProfileRate(1) } - err = r.parsePluginsConfiguration(ctx) + err = r.parsePluginsConfiguration(ctx, datastore) if err != nil { setupLog.Error(err, "Failed to parse plugins configuration") return err @@ -321,7 +321,7 @@ func (r *Runner) registerInTreePlugins() { plugins.Register(testfilter.HeaderBasedTestingFilterType, testfilter.HeaderBasedTestingFilterFactory) } -func (r *Runner) parsePluginsConfiguration(ctx context.Context) error { +func (r *Runner) parsePluginsConfiguration(ctx context.Context, datastore datastore.Datastore) error { if *configText == "" && *configFile == "" { return nil // configuring through code, not through file } @@ -340,7 +340,7 @@ func (r *Runner) parsePluginsConfiguration(ctx context.Context) error { } r.registerInTreePlugins() - handle := plugins.NewEppHandle(ctx) + handle := plugins.NewEppHandle(ctx, datastore) config, err := loader.LoadConfig(configBytes, handle, logger) if err != nil { return fmt.Errorf("failed to load the configuration - %w", err) diff --git a/pkg/epp/config/loader/configloader_test.go b/pkg/epp/config/loader/configloader_test.go index b0b0741c1..02ee8d054 100644 --- a/pkg/epp/config/loader/configloader_test.go +++ b/pkg/epp/config/loader/configloader_test.go @@ -303,7 +303,7 @@ func TestLoadRawConfigurationWithDefaults(t *testing.T) { } for _, test := range tests { - handle := utils.NewTestHandle(context.Background()) + handle := utils.NewTestHandle(context.Background(), nil) configBytes := []byte(test.configText) if test.configFile != "" { @@ -346,7 +346,7 @@ func checkError(t *testing.T, function string, test testStruct, err error) { } func TestInstantiatePlugins(t *testing.T) { - handle := utils.NewTestHandle(context.Background()) + handle := 
utils.NewTestHandle(context.Background(), nil) _, err := LoadConfig([]byte(successConfigText), handle, logging.NewTestLogger()) if err != nil { t.Fatalf("LoadConfig returned unexpected error - %v", err) @@ -360,7 +360,7 @@ func TestInstantiatePlugins(t *testing.T) { t.Fatalf("loaded plugins returned test1 has the wrong type %#v", t1) } - handle = utils.NewTestHandle(context.Background()) + handle = utils.NewTestHandle(context.Background(), nil) _, err = LoadConfig([]byte(errorBadPluginReferenceParametersText), handle, logging.NewTestLogger()) if err == nil { t.Fatalf("LoadConfig did not return error as expected ") @@ -426,7 +426,7 @@ func TestLoadConfig(t *testing.T) { logger := logging.NewTestLogger() for _, test := range tests { - handle := utils.NewTestHandle(context.Background()) + handle := utils.NewTestHandle(context.Background(), nil) _, err := LoadConfig([]byte(test.configText), handle, logger) if err != nil { if !test.wantErr { diff --git a/pkg/epp/plugins/handle.go b/pkg/epp/plugins/handle.go index 8c9153cf1..db4606546 100644 --- a/pkg/epp/plugins/handle.go +++ b/pkg/epp/plugins/handle.go @@ -19,6 +19,10 @@ package plugins import ( "context" "fmt" + + v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" + "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" ) // Handle provides plugins a set of standard data and tools to work with @@ -27,6 +31,8 @@ type Handle interface { Context() context.Context HandlePlugins + + Datastore } // HandlePlugins defines a set of APIs to work with instantiated plugins @@ -44,10 +50,18 @@ type HandlePlugins interface { GetAllPluginsWithNames() map[string]Plugin } +// Datastore defines the datastore operations that Handle exposes to plugins. 
+type Datastore interface { + PoolGet() (*v1.InferencePool, error) + ObjectiveGet(modelName string) *v1alpha2.InferenceObjective + PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics +} + // eppHandle is an implementation of the interface plugins.Handle type eppHandle struct { ctx context.Context HandlePlugins + Datastore } // Context returns a context the plugins can use, if they need one @@ -84,12 +98,13 @@ func (h *eppHandlePlugins) GetAllPluginsWithNames() map[string]Plugin { return h.plugins } -func NewEppHandle(ctx context.Context) Handle { +func NewEppHandle(ctx context.Context, datastore Datastore) Handle { return &eppHandle{ ctx: ctx, HandlePlugins: &eppHandlePlugins{ plugins: map[string]Plugin{}, }, + Datastore: datastore, } } diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index a3e2d6d13..be60c22ae 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -65,6 +65,7 @@ func NewDirectorWithConfig(datastore Datastore, scheduler Scheduler, saturationD datastore: datastore, scheduler: scheduler, saturationDetector: saturationDetector, + preSchedulePlugins: config.preSchedulePlugins, preRequestPlugins: config.preRequestPlugins, postResponsePlugins: config.postResponsePlugins, defaultPriority: 0, // define default priority explicitly @@ -76,6 +77,7 @@ type Director struct { datastore Datastore scheduler Scheduler saturationDetector SaturationDetector + preSchedulePlugins []PreSchedule preRequestPlugins []PreRequest postResponsePlugins []PostResponse // we just need a pointer to an int variable since priority is a pointer in InferenceObjective @@ -135,7 +137,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo logger.V(logutil.DEBUG).Info("LLM request assembled") // Get candidate pods for scheduling - candidatePods := d.getCandidatePodsForScheduling(ctx, reqCtx.Request.Metadata) + candidatePods := d.runPreSchedulePlugins(ctx, 
reqCtx.Request) if len(candidatePods) == 0 { return reqCtx, errutil.Error{Code: errutil.ServiceUnavailable, Msg: "failed to find candidate pods for serving the request"} } @@ -161,24 +163,24 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo return reqCtx, nil } -// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore. +// GetCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore. // according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies // a subset of endpoints, only these endpoints will be considered as candidates for the scheduler. // Snapshot pod metrics from the datastore to: // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request between all scheduling cycles. -func (d *Director) getCandidatePodsForScheduling(ctx context.Context, requestMetadata map[string]any) []backendmetrics.PodMetrics { +func GetCandidatePodsForScheduling(ctx context.Context, datastore Datastore, requestMetadata map[string]any) []backendmetrics.PodMetrics { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) subsetMap, found := requestMetadata[metadata.SubsetFilterNamespace].(map[string]any) if !found { - return d.datastore.PodList(backendmetrics.AllPodsPredicate) + return datastore.PodList(backendmetrics.AllPodsPredicate) } // Check if endpoint key is present in the subset map and ensure there is at least one value endpointSubsetList, found := subsetMap[metadata.SubsetFilterKey].([]any) if !found { - return d.datastore.PodList(backendmetrics.AllPodsPredicate) + return datastore.PodList(backendmetrics.AllPodsPredicate) } else if len(endpointSubsetList) == 0 { loggerTrace.Info("found empty subset filter in request metadata, filtering all pods") return []backendmetrics.PodMetrics{} @@ -194,7 +196,7 @@ func (d *Director) 
getCandidatePodsForScheduling(ctx context.Context, requestMet } podTotalCount := 0 - podFilteredList := d.datastore.PodList(func(pm backendmetrics.PodMetrics) bool { + podFilteredList := datastore.PodList(func(pm backendmetrics.PodMetrics) bool { podTotalCount++ if _, found := endpoints[pm.GetPod().Address]; found { return true @@ -301,6 +303,24 @@ func (d *Director) GetRandomPod() *backend.Pod { return pod.GetPod() } +func (d *Director) runPreSchedulePlugins(ctx context.Context, request *handlers.Request) []backendmetrics.PodMetrics { + loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) + candidates := []backendmetrics.PodMetrics{} + if len(d.preSchedulePlugins) > 0 { + for _, plugin := range d.preSchedulePlugins { + loggerDebug.Info("Running pre-schedule plugin", "plugin", plugin.TypedName()) + before := time.Now() + candidates = append(candidates, plugin.GetCandidatePods(ctx, request)...) + metrics.RecordPluginProcessingLatency(PreScheduleExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before)) + loggerDebug.Info("Completed running pre-schedule plugin successfully", "plugin", plugin.TypedName()) + } + } else { + // No PreSchedule plugins are configured; fall back to the default candidate selection from the datastore. 
+ candidates = GetCandidatePodsForScheduling(ctx, d.datastore, request.Metadata) + } + return candidates +} + func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult, targetPort int) { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index a0cb7c325..c11200154 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -529,9 +529,7 @@ func TestGetCandidatePodsForScheduling(t *testing.T) { ds := &mockDatastore{pods: testInput} for _, test := range tests { t.Run(test.name, func(t *testing.T) { - director := NewDirectorWithConfig(ds, &mockScheduler{}, &mockSaturationDetector{}, NewConfig()) - - got := director.getCandidatePodsForScheduling(context.Background(), test.metadata) + got := GetCandidatePodsForScheduling(context.Background(), ds, test.metadata) diff := cmp.Diff(test.output, got, cmpopts.SortSlices(func(a, b backendmetrics.PodMetrics) bool { return a.GetPod().NamespacedName.String() < b.GetPod().NamespacedName.String() diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go index ca823a670..e993925fd 100644 --- a/pkg/epp/requestcontrol/plugins.go +++ b/pkg/epp/requestcontrol/plugins.go @@ -20,15 +20,25 @@ import ( "context" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) const ( + PreScheduleExtensionPoint = "PreSchedule" PreRequestExtensionPoint = "PreRequest" PostResponseExtensionPoint = "PostResponse" ) +// PreSchedule is called by the director before sending the request to the scheduler. 
+// It gets the set of candidate pods to be filtered and scored. +type PreSchedule interface { + plugins.Plugin + GetCandidatePods(ctx context.Context, request *handlers.Request) []backendmetrics.PodMetrics +} + // PreRequest is called by the director after a getting result from scheduling layer and // before a request is sent to the selected model server. type PreRequest interface { diff --git a/pkg/epp/requestcontrol/request_control_config.go b/pkg/epp/requestcontrol/request_control_config.go index 2d6dc95e7..4eba65bc7 100644 --- a/pkg/epp/requestcontrol/request_control_config.go +++ b/pkg/epp/requestcontrol/request_control_config.go @@ -23,6 +23,7 @@ import ( // NewConfig creates a new Config object and returns its pointer. func NewConfig() *Config { return &Config{ + preSchedulePlugins: []PreSchedule{}, preRequestPlugins: []PreRequest{}, postResponsePlugins: []PostResponse{}, } @@ -30,10 +31,18 @@ func NewConfig() *Config { // Config provides a configuration for the requestcontrol plugins. type Config struct { + preSchedulePlugins []PreSchedule preRequestPlugins []PreRequest postResponsePlugins []PostResponse } +// WithPreSchedulePlugins sets the given plugins as the PreSchedule plugins. +// If the Config has PreSchedule plugins already, this call replaces the existing plugins with the given ones. +func (c *Config) WithPreSchedulePlugins(plugins ...PreSchedule) *Config { + c.preSchedulePlugins = plugins + return c +} + // WithPreRequestPlugins sets the given plugins as the PreRequest plugins. // If the Config has PreRequest plugins already, this call replaces the existing plugins with the given ones. 
func (c *Config) WithPreRequestPlugins(plugins ...PreRequest) *Config { @@ -50,6 +59,9 @@ func (c *Config) WithPostResponsePlugins(plugins ...PostResponse) *Config { func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) { for _, plugin := range pluginObjects { + if preSchedulePlugin, ok := plugin.(PreSchedule); ok { + c.preSchedulePlugins = append(c.preSchedulePlugins, preSchedulePlugin) + } if preRequestPlugin, ok := plugin.(PreRequest); ok { c.preRequestPlugins = append(c.preRequestPlugins, preRequestPlugin) } diff --git a/test/utils/handle.go b/test/utils/handle.go index 4a29dda87..5d110a672 100644 --- a/test/utils/handle.go +++ b/test/utils/handle.go @@ -26,6 +26,7 @@ import ( type testHandle struct { ctx context.Context plugins.HandlePlugins + plugins.Datastore } // Context returns a context the plugins can use, if they need one @@ -57,11 +58,12 @@ func (h *testHandlePlugins) GetAllPluginsWithNames() map[string]plugins.Plugin { return h.plugins } -func NewTestHandle(ctx context.Context) plugins.Handle { +func NewTestHandle(ctx context.Context, datastore plugins.Datastore) plugins.Handle { return &testHandle{ ctx: ctx, HandlePlugins: &testHandlePlugins{ plugins: map[string]plugins.Plugin{}, }, + Datastore: datastore, } }