-
Notifications
You must be signed in to change notification settings - Fork 60
feat: Add external plugins support - DO NOT MERGE #183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
package config | ||
|
||
import ( | ||
"encoding/json" | ||
|
||
"github.com/go-logr/logr" | ||
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" | ||
) | ||
|
@@ -52,15 +54,49 @@ const ( | |
|
||
prefixScorerBlockSizeEnvKey = "PREFIX_SCORER_BLOCK_SIZE" | ||
prefixScorerBlockSizeDefault = 256 | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. At this point the external process plugin is for demo only, so I would refrain from introducing and using all these environment variables (which would be eliminated by the config changes anyway). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For a demo, you can just enable (or not) fixed plugin types (e.g., filter or scorer) at a well-known address (e.g., |
||
externalPrefix = "EXTERNAL_" | ||
httpPrefix = "HTTP_" | ||
|
||
preSchedulers = "PRE_SCHEDULERS" | ||
filters = "FILTERS" | ||
scorers = "SCORERS" | ||
postSchedulers = "POST_SCHEDULERS" | ||
|
||
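// External plugin environment variable names are composed as
// externalPrefix + protocol prefix + scheduler prefix + plugin type
// (see loadExternalPluginsInfo). The scheduler prefix is empty for the
// decode scheduler and prefillPrefix for the prefill scheduler:
// EXTERNAL_HTTP_PRE_SCHEDULERS
// EXTERNAL_HTTP_<prefillPrefix>PRE_SCHEDULERS
// EXTERNAL_HTTP_FILTERS
// EXTERNAL_HTTP_<prefillPrefix>FILTERS
// EXTERNAL_HTTP_SCORERS
// EXTERNAL_HTTP_<prefillPrefix>SCORERS
// EXTERNAL_HTTP_POST_SCHEDULERS
// EXTERNAL_HTTP_<prefillPrefix>POST_SCHEDULERS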
) | ||
|
||
// ExternalPluginInfo is the configuration of one external plugin, as decoded
// from a JSON object with the keys "name", "url" and "weight".
type ExternalPluginInfo struct {
	// Name is the plugin's name.
	Name string `json:"name"`
	// URL is the address of the plugin.
	URL string `json:"url"`
	// Weight is the weight configured for the plugin.
	Weight int `json:"weight"`
}
|
||
// ExternalPlugins contains all types of external plugins configuration | ||
type ExternalPlugins struct { | ||
PreSchedulers []ExternalPluginInfo | ||
Filters []ExternalPluginInfo | ||
Scorers []ExternalPluginInfo | ||
PostSchedulers []ExternalPluginInfo | ||
} | ||
|
||
// Config contains scheduler configuration, currently configuration is loaded from environment variables | ||
type Config struct { | ||
DecodeSchedulerPlugins map[string]int | ||
PrefillSchedulerPlugins map[string]int | ||
PDEnabled bool | ||
PDThreshold int | ||
PrefixBlockSize int | ||
DecodeSchedulerPlugins map[string]int | ||
PrefillSchedulerPlugins map[string]int | ||
DecodeSchedulerExternalPlugins ExternalPlugins | ||
PrefillSchedulerExternalPlugins ExternalPlugins | ||
PDEnabled bool | ||
PDThreshold int | ||
PrefixBlockSize int | ||
} | ||
|
||
// LoadConfig loads configuration from environment variables and returns a new instance of Config | ||
|
@@ -72,12 +108,29 @@ func LoadConfig(logger logr.Logger) *Config { | |
GIEKVCacheUtilizationScorerName, GIEQueueScorerName, GIEPrefixScorerName, | ||
} | ||
|
||
// load external plugins for decode and prefill schedulers | ||
prefillSchedulerExternalPlugins := ExternalPlugins{ | ||
PreSchedulers: loadExternalPluginsInfo(logger, httpPrefix, "", preSchedulers), | ||
Filters: loadExternalPluginsInfo(logger, httpPrefix, "", filters), | ||
Scorers: loadExternalPluginsInfo(logger, httpPrefix, "", scorers), | ||
PostSchedulers: loadExternalPluginsInfo(logger, httpPrefix, "", postSchedulers), | ||
} | ||
|
||
decodeSchedulerExternalPlugins := ExternalPlugins{ | ||
PreSchedulers: loadExternalPluginsInfo(logger, httpPrefix, prefillPrefix, preSchedulers), | ||
Filters: loadExternalPluginsInfo(logger, httpPrefix, prefillPrefix, filters), | ||
Scorers: loadExternalPluginsInfo(logger, httpPrefix, prefillPrefix, scorers), | ||
PostSchedulers: loadExternalPluginsInfo(logger, httpPrefix, prefillPrefix, postSchedulers), | ||
} | ||
|
||
return &Config{ | ||
DecodeSchedulerPlugins: loadPluginInfo(logger, false, pluginNames), | ||
PrefillSchedulerPlugins: loadPluginInfo(logger, true, pluginNames), | ||
PDEnabled: env.GetEnvString(pdEnabledEnvKey, "false", logger) == "true", | ||
PDThreshold: env.GetEnvInt(pdPromptLenThresholdEnvKey, pdPromptLenThresholdDefault, logger), | ||
PrefixBlockSize: env.GetEnvInt(prefixScorerBlockSizeEnvKey, prefixScorerBlockSizeDefault, logger), | ||
DecodeSchedulerPlugins: loadPluginInfo(logger, false, pluginNames), | ||
PrefillSchedulerPlugins: loadPluginInfo(logger, true, pluginNames), | ||
DecodeSchedulerExternalPlugins: prefillSchedulerExternalPlugins, | ||
PrefillSchedulerExternalPlugins: decodeSchedulerExternalPlugins, | ||
PDEnabled: env.GetEnvString(pdEnabledEnvKey, "false", logger) == "true", | ||
PDThreshold: env.GetEnvInt(pdPromptLenThresholdEnvKey, pdPromptLenThresholdDefault, logger), | ||
PrefixBlockSize: env.GetEnvInt(prefixScorerBlockSizeEnvKey, prefixScorerBlockSizeDefault, logger), | ||
} | ||
} | ||
|
||
|
@@ -107,3 +160,26 @@ func loadPluginInfo(logger logr.Logger, prefill bool, pluginNames []string) map[ | |
|
||
return result | ||
} | ||
|
||
// loadExternalPluginsInfo loads configuration of external plugins for the given scheduler type and the given plugins type | ||
// | ||
//nolint:unparam // future: protocol will support more values (grpc, wasm, etc.) | ||
func loadExternalPluginsInfo(logger logr.Logger, protocol string, schedulerType string, pluginType string) []ExternalPluginInfo { | ||
var plugins []ExternalPluginInfo | ||
|
||
envVarName := externalPrefix + protocol + schedulerType + pluginType | ||
envVarRawValue := env.GetEnvString(envVarName, "", logger) | ||
|
||
if envVarRawValue == "" { | ||
logger.Info("Environment variable is not defined", "var", envVarName) | ||
return plugins | ||
} | ||
|
||
if err := json.Unmarshal([]byte(envVarRawValue), &plugins); err != nil { | ||
logger.Info("Error in environment variable unmarshaling", "error", err, "variable", envVarName, "value", envVarRawValue) | ||
return plugins | ||
} | ||
|
||
logger.Info("External plugin loaded", "type", pluginType, "plugins", plugins) | ||
return plugins | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ import ( | |
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" | ||
|
||
"github.com/llm-d/llm-d-inference-scheduler/pkg/config" | ||
externalhttp "github.com/llm-d/llm-d-inference-scheduler/pkg/scheduling/plugins/external/http" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The import order seems wrong. |
||
"github.com/llm-d/llm-d-inference-scheduler/pkg/scheduling/plugins/filter" | ||
"github.com/llm-d/llm-d-inference-scheduler/pkg/scheduling/plugins/scorer" | ||
) | ||
|
@@ -70,13 +71,13 @@ func NewScheduler(ctx context.Context, schedulerConfig *config.Config, ds Datast | |
|
||
scheduler.prefill = scheduling.NewSchedulerWithConfig( | ||
ds, | ||
scheduler.generateSchedulerConfig(ctx, schedulerConfig.PrefillSchedulerPlugins, | ||
scheduler.generateSchedulerConfig(ctx, schedulerConfig.PrefillSchedulerPlugins, schedulerConfig.PrefillSchedulerExternalPlugins, | ||
&filter.PrefillFilter{}), | ||
) | ||
|
||
scheduler.decode = scheduling.NewSchedulerWithConfig( | ||
ds, | ||
scheduler.generateSchedulerConfig(ctx, schedulerConfig.DecodeSchedulerPlugins, | ||
scheduler.generateSchedulerConfig(ctx, schedulerConfig.DecodeSchedulerPlugins, schedulerConfig.DecodeSchedulerExternalPlugins, | ||
&filter.DecodeFilter{}), | ||
) | ||
|
||
|
@@ -212,7 +213,57 @@ func (s *Scheduler) pluginsFromConfig(ctx context.Context, pluginsConfig map[str | |
return plugins | ||
} | ||
|
||
func (s *Scheduler) generateSchedulerConfig(ctx context.Context, pluginsConfig map[string]int, extraFilters ...plugins.Filter) *scheduling.SchedulerConfig { | ||
func externalFiltersFromConfig(ctx context.Context, info []config.ExternalPluginInfo) []plugins.Filter { | ||
logger := log.FromContext(ctx) | ||
filters := make([]plugins.Filter, 0) | ||
|
||
for _, extPluginInfo := range info { | ||
filters = append(filters, externalhttp.NewFilter(ctx, extPluginInfo.Name, extPluginInfo.URL)) | ||
} | ||
|
||
logger.Info(fmt.Sprintf("Created %d external filters", len(filters))) | ||
return filters | ||
} | ||
|
||
func externalPreSchedulesFromConfig(ctx context.Context, info []config.ExternalPluginInfo) []plugins.PreSchedule { | ||
logger := log.FromContext(ctx) | ||
preSchedules := []plugins.PreSchedule{} | ||
|
||
for _, extPluginInfo := range info { | ||
preSchedules = append(preSchedules, externalhttp.NewPreSchedule(ctx, extPluginInfo.Name, extPluginInfo.URL)) | ||
} | ||
|
||
logger.Info(fmt.Sprintf("Created %d external pre-schedules", len(preSchedules))) | ||
return preSchedules | ||
} | ||
|
||
func externalPostSchedulesFromConfig(ctx context.Context, info []config.ExternalPluginInfo) []plugins.PostSchedule { | ||
logger := log.FromContext(ctx) | ||
postSchedules := []plugins.PostSchedule{} | ||
|
||
for _, extPluginInfo := range info { | ||
postSchedules = append(postSchedules, externalhttp.NewPostSchedule(ctx, extPluginInfo.Name, extPluginInfo.URL)) | ||
} | ||
|
||
logger.Info(fmt.Sprintf("Created %d external post-schedules", len(postSchedules))) | ||
return postSchedules | ||
} | ||
|
||
func externalScorersFromConfig(ctx context.Context, info []config.ExternalPluginInfo) []*giescorer.WeightedScorer { | ||
logger := log.FromContext(ctx) | ||
scorers := []*giescorer.WeightedScorer{} | ||
|
||
for _, extPluginInfo := range info { | ||
scorers = append(scorers, giescorer.NewWeightedScorer(externalhttp.NewScorer(ctx, extPluginInfo.Name, extPluginInfo.URL), extPluginInfo.Weight)) | ||
} | ||
|
||
logger.Info(fmt.Sprintf("Created %d external scorers", len(scorers))) | ||
return scorers | ||
} | ||
|
||
func (s *Scheduler) generateSchedulerConfig(ctx context.Context, pluginsConfig map[string]int, | ||
externalPlugins config.ExternalPlugins, extraFilters ...plugins.Filter) *scheduling.SchedulerConfig { | ||
|
||
thePlugins := s.pluginsFromConfig(ctx, pluginsConfig) | ||
preSchedulePlugins := []plugins.PreSchedule{} | ||
filters := []plugins.Filter{} | ||
|
@@ -240,6 +291,13 @@ func (s *Scheduler) generateSchedulerConfig(ctx context.Context, pluginsConfig m | |
} | ||
} | ||
|
||
// add external plugins | ||
preSchedulePlugins = append(preSchedulePlugins, externalPreSchedulesFromConfig(ctx, externalPlugins.PreSchedulers)...) | ||
filters = append(filters, externalFiltersFromConfig(ctx, externalPlugins.Filters)...) | ||
scorers = append(scorers, externalScorersFromConfig(ctx, externalPlugins.Scorers)...) | ||
postSchedulePlugins = append(postSchedulePlugins, externalPostSchedulesFromConfig(ctx, externalPlugins.PostSchedulers)...) | ||
// postResponsePlugins = append(postResponsePlugins, postResponse) | ||
|
||
return scheduling.NewSchedulerConfig(). | ||
WithPreSchedulePlugins(preSchedulePlugins...). | ||
WithFilters(filters...). | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// Package http contains all types of external http plugins | ||
package http | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"path" | ||
|
||
"github.com/valyala/fasthttp" | ||
|
||
"sigs.k8s.io/controller-runtime/pkg/log" | ||
|
||
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" | ||
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" | ||
) | ||
|
||
// Filter implementation of the external http filter | ||
type Filter struct { | ||
Plugin | ||
} | ||
|
||
var _ plugins.Filter = &Filter{} // validate interface conformance | ||
|
||
// NewFilter creates a new instance of external http filter based on the given parameters | ||
func NewFilter(_ context.Context, name string, url string) plugins.Filter { | ||
return &Filter{Plugin{name: name, url: url}} | ||
} | ||
|
||
// Filter filters the given list of pods | ||
func (f *Filter) Filter(schedContext *types.SchedulingContext, pods []types.Pod) []types.Pod { | ||
logger := log.FromContext(schedContext).WithName(f.Name()) | ||
|
||
// Create filter http request payload based on the given data | ||
filterPayload := newFilterPayload(schedContext, pods) | ||
payload, err := json.Marshal(filterPayload) | ||
|
||
if err != nil { | ||
logger.Error(err, "Failed to marshal scheduling context, filter will be skipped") | ||
return pods | ||
} | ||
|
||
// Create a new fasthttp request and response | ||
req := fasthttp.AcquireRequest() | ||
defer fasthttp.ReleaseRequest(req) | ||
resp := fasthttp.AcquireResponse() | ||
defer fasthttp.ReleaseResponse(resp) | ||
|
||
req.SetRequestURI(path.Join(f.url, "filter")) | ||
req.Header.SetMethod(fasthttp.MethodPost) | ||
req.Header.SetContentType("application/json") | ||
req.SetBody(payload) | ||
|
||
// Execute the request | ||
client := &fasthttp.Client{} | ||
if err := client.Do(req, resp); err != nil { | ||
logger.Error(err, "request failed") | ||
return pods | ||
} | ||
|
||
// Optionally check status code | ||
if resp.StatusCode() != fasthttp.StatusOK { | ||
logger.Error(nil, fmt.Sprintf("bad response status: %d, body: %s", resp.StatusCode(), resp.Body())) | ||
return pods | ||
} | ||
|
||
var filteredPodNames []namespacedName | ||
|
||
// filter plugin response is an array of pod full names | ||
if err := json.Unmarshal(resp.Body(), &filteredPodNames); err != nil { | ||
logger.Error(err, "external filter's response body unmarshal failed", "name", f.Name(), "resp body", resp.Body()) | ||
return pods | ||
} | ||
|
||
// filter list of given pods based on the returned list of pods | ||
podsNamesSet := map[string]bool{} | ||
|
||
for _, nn := range filteredPodNames { | ||
podsNamesSet[namespacedNameToString(nn.Name, nn.Namespace)] = true | ||
} | ||
|
||
filteredPods := make([]types.Pod, 0) | ||
|
||
for _, p := range pods { | ||
nn := p.GetPod().NamespacedName | ||
if _, exists := podsNamesSet[namespacedNameToString(nn.Name, nn.Namespace)]; exists { | ||
filteredPods = append(filteredPods, p) | ||
} | ||
} | ||
|
||
return filteredPods | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
package http | ||
|
||
// Plugin holds the information shared by all external http plugins.
type Plugin struct {
	// name is the plugin's name, as returned by Name.
	name string
	// url is the address of the external plugin.
	url string
}

// Name reports the name the plugin was created with.
func (p *Plugin) Name() string {
	return p.name
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
q: What's the benefit of using a non-standard HTTP implementation? If it is performance, then we should first prove that HTTP handling is a bottleneck.