Skip to content

feat: Add external plugins support - DO NOT MERGE #183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ require (
github.com/prometheus/client_golang v1.22.0
github.com/redis/go-redis/v9 v9.7.3
github.com/stretchr/testify v1.10.0
github.com/valyala/fasthttp v1.62.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

q: what's the benefit of using a non standard HTTP implementation? If performance, then we should first prove that HTTP handling is a bottleneck.

go.uber.org/zap v1.27.0
google.golang.org/grpc v1.72.0
k8s.io/apimachinery v0.33.1
Expand All @@ -26,6 +27,7 @@ require (

require (
cel.dev/expr v0.20.0 // indirect
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/antlr4-go/antlr/v4 v4.13.0 // indirect
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a // indirect
github.com/beorn7/perks v1.0.1 // indirect
Expand Down Expand Up @@ -56,6 +58,7 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
Expand All @@ -69,6 +72,7 @@ require (
github.com/spf13/cobra v1.9.1 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/stoewer/go-strcase v1.3.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect
Expand Down
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
cel.dev/expr v0.20.0 h1:OunBvVCfvpWlt4dN7zg3FM6TDkzOePe1+foGJ9AXeeI=
cel.dev/expr v0.20.0/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI=
github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4B6AGu/h5Sxe66HYVdqdGu2l9Iebqhi/AEoA=
Expand Down Expand Up @@ -148,8 +150,14 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0=
github.com/valyala/fasthttp v1.62.0/go.mod h1:FCINgr4GKdKqV8Q0xv8b+UxPV+H/O5nNFo3D+r54Htg=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
Expand Down
96 changes: 86 additions & 10 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
package config

import (
"encoding/json"

"github.com/go-logr/logr"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
)
Expand Down Expand Up @@ -52,15 +54,49 @@ const (

prefixScorerBlockSizeEnvKey = "PREFIX_SCORER_BLOCK_SIZE"
prefixScorerBlockSizeDefault = 256

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point the external process plugin is for demo only, so I would refrain from introducing and using all these environment variables (which would be eliminated by the config changes anyway).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a demo, you can just enable (or not) a fixed type plugins (e.g., filter or scorer) at a well known address (e.g., localhost:8000/8088 and run their container(s) in the same Pod)

externalPrefix = "EXTERNAL_"
httpPrefix = "HTTP_"

preSchedulers = "PRE_SCHEDULERS"
filters = "FILTERS"
scorers = "SCORERS"
postSchedulers = "POST_SCHEDULERS"

// EXTERNAL_HTTP_PRE_SCHEDULERS
// EXTERNAL_PREFILL_HTTP_PRE_SCHEDULERS
// EXTERNAL_HTTP_FILTERS
// EXTERNAL_PREFILL_HTTP_FILTERS
// EXTERNAL_HTTP_SCORERS
// EXTERNAL_PREFILL_HTTP_SCORERS
// EXTERNAL_HTTP_POST_SCHEDULERS
// EXTERNAL_PREFILL_HTTP_POST_SCHEDULERS
)

// ExternalPluginInfo configuration of an external plugin
type ExternalPluginInfo struct {
Name string `json:"name"`
URL string `json:"url"`
Weight int `json:"weight"`
}

// ExternalPlugins contains all types of external plugins configuration
type ExternalPlugins struct {
PreSchedulers []ExternalPluginInfo
Filters []ExternalPluginInfo
Scorers []ExternalPluginInfo
PostSchedulers []ExternalPluginInfo
}

// Config contains scheduler configuration, currently configuration is loaded from environment variables
type Config struct {
DecodeSchedulerPlugins map[string]int
PrefillSchedulerPlugins map[string]int
PDEnabled bool
PDThreshold int
PrefixBlockSize int
DecodeSchedulerPlugins map[string]int
PrefillSchedulerPlugins map[string]int
DecodeSchedulerExternalPlugins ExternalPlugins
PrefillSchedulerExternalPlugins ExternalPlugins
PDEnabled bool
PDThreshold int
PrefixBlockSize int
}

// LoadConfig loads configuration from environment variables and returns a new instance of Config
Expand All @@ -72,12 +108,29 @@ func LoadConfig(logger logr.Logger) *Config {
GIEKVCacheUtilizationScorerName, GIEQueueScorerName, GIEPrefixScorerName,
}

// load external plugins for decode and prefill schedulers
prefillSchedulerExternalPlugins := ExternalPlugins{
PreSchedulers: loadExternalPluginsInfo(logger, httpPrefix, "", preSchedulers),
Filters: loadExternalPluginsInfo(logger, httpPrefix, "", filters),
Scorers: loadExternalPluginsInfo(logger, httpPrefix, "", scorers),
PostSchedulers: loadExternalPluginsInfo(logger, httpPrefix, "", postSchedulers),
}

decodeSchedulerExternalPlugins := ExternalPlugins{
PreSchedulers: loadExternalPluginsInfo(logger, httpPrefix, prefillPrefix, preSchedulers),
Filters: loadExternalPluginsInfo(logger, httpPrefix, prefillPrefix, filters),
Scorers: loadExternalPluginsInfo(logger, httpPrefix, prefillPrefix, scorers),
PostSchedulers: loadExternalPluginsInfo(logger, httpPrefix, prefillPrefix, postSchedulers),
}

return &Config{
DecodeSchedulerPlugins: loadPluginInfo(logger, false, pluginNames),
PrefillSchedulerPlugins: loadPluginInfo(logger, true, pluginNames),
PDEnabled: env.GetEnvString(pdEnabledEnvKey, "false", logger) == "true",
PDThreshold: env.GetEnvInt(pdPromptLenThresholdEnvKey, pdPromptLenThresholdDefault, logger),
PrefixBlockSize: env.GetEnvInt(prefixScorerBlockSizeEnvKey, prefixScorerBlockSizeDefault, logger),
DecodeSchedulerPlugins: loadPluginInfo(logger, false, pluginNames),
PrefillSchedulerPlugins: loadPluginInfo(logger, true, pluginNames),
DecodeSchedulerExternalPlugins: prefillSchedulerExternalPlugins,
PrefillSchedulerExternalPlugins: decodeSchedulerExternalPlugins,
PDEnabled: env.GetEnvString(pdEnabledEnvKey, "false", logger) == "true",
PDThreshold: env.GetEnvInt(pdPromptLenThresholdEnvKey, pdPromptLenThresholdDefault, logger),
PrefixBlockSize: env.GetEnvInt(prefixScorerBlockSizeEnvKey, prefixScorerBlockSizeDefault, logger),
}
}

Expand Down Expand Up @@ -107,3 +160,26 @@ func loadPluginInfo(logger logr.Logger, prefill bool, pluginNames []string) map[

return result
}

// loadExternalPluginsInfo loads configuration of external plugins for the given scheduler type and the given plugins type
//
//nolint:unparam // future: protocol will support more values (grpc, wasm, etc.)
func loadExternalPluginsInfo(logger logr.Logger, protocol string, schedulerType string, pluginType string) []ExternalPluginInfo {
var plugins []ExternalPluginInfo

envVarName := externalPrefix + protocol + schedulerType + pluginType
envVarRawValue := env.GetEnvString(envVarName, "", logger)

if envVarRawValue == "" {
logger.Info("Environment variable is not defined", "var", envVarName)
return plugins
}

if err := json.Unmarshal([]byte(envVarRawValue), &plugins); err != nil {
logger.Info("Error in environment variable unmarshaling", "error", err, "variable", envVarName, "value", envVarRawValue)
return plugins
}

logger.Info("External plugin loaded", "type", pluginType, "plugins", plugins)
return plugins
}
64 changes: 61 additions & 3 deletions pkg/scheduling/pd/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"

"github.com/llm-d/llm-d-inference-scheduler/pkg/config"
externalhttp "github.com/llm-d/llm-d-inference-scheduler/pkg/scheduling/plugins/external/http"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import order seems wrong

"github.com/llm-d/llm-d-inference-scheduler/pkg/scheduling/plugins/filter"
"github.com/llm-d/llm-d-inference-scheduler/pkg/scheduling/plugins/scorer"
)
Expand Down Expand Up @@ -70,13 +71,13 @@ func NewScheduler(ctx context.Context, schedulerConfig *config.Config, ds Datast

scheduler.prefill = scheduling.NewSchedulerWithConfig(
ds,
scheduler.generateSchedulerConfig(ctx, schedulerConfig.PrefillSchedulerPlugins,
scheduler.generateSchedulerConfig(ctx, schedulerConfig.PrefillSchedulerPlugins, schedulerConfig.PrefillSchedulerExternalPlugins,
&filter.PrefillFilter{}),
)

scheduler.decode = scheduling.NewSchedulerWithConfig(
ds,
scheduler.generateSchedulerConfig(ctx, schedulerConfig.DecodeSchedulerPlugins,
scheduler.generateSchedulerConfig(ctx, schedulerConfig.DecodeSchedulerPlugins, schedulerConfig.DecodeSchedulerExternalPlugins,
&filter.DecodeFilter{}),
)

Expand Down Expand Up @@ -212,7 +213,57 @@ func (s *Scheduler) pluginsFromConfig(ctx context.Context, pluginsConfig map[str
return plugins
}

func (s *Scheduler) generateSchedulerConfig(ctx context.Context, pluginsConfig map[string]int, extraFilters ...plugins.Filter) *scheduling.SchedulerConfig {
func externalFiltersFromConfig(ctx context.Context, info []config.ExternalPluginInfo) []plugins.Filter {
logger := log.FromContext(ctx)
filters := make([]plugins.Filter, 0)

for _, extPluginInfo := range info {
filters = append(filters, externalhttp.NewFilter(ctx, extPluginInfo.Name, extPluginInfo.URL))
}

logger.Info(fmt.Sprintf("Created %d external filters", len(filters)))
return filters
}

func externalPreSchedulesFromConfig(ctx context.Context, info []config.ExternalPluginInfo) []plugins.PreSchedule {
logger := log.FromContext(ctx)
preSchedules := []plugins.PreSchedule{}

for _, extPluginInfo := range info {
preSchedules = append(preSchedules, externalhttp.NewPreSchedule(ctx, extPluginInfo.Name, extPluginInfo.URL))
}

logger.Info(fmt.Sprintf("Created %d external pre-schedules", len(preSchedules)))
return preSchedules
}

func externalPostSchedulesFromConfig(ctx context.Context, info []config.ExternalPluginInfo) []plugins.PostSchedule {
logger := log.FromContext(ctx)
postSchedules := []plugins.PostSchedule{}

for _, extPluginInfo := range info {
postSchedules = append(postSchedules, externalhttp.NewPostSchedule(ctx, extPluginInfo.Name, extPluginInfo.URL))
}

logger.Info(fmt.Sprintf("Created %d external post-schedules", len(postSchedules)))
return postSchedules
}

func externalScorersFromConfig(ctx context.Context, info []config.ExternalPluginInfo) []*giescorer.WeightedScorer {
logger := log.FromContext(ctx)
scorers := []*giescorer.WeightedScorer{}

for _, extPluginInfo := range info {
scorers = append(scorers, giescorer.NewWeightedScorer(externalhttp.NewScorer(ctx, extPluginInfo.Name, extPluginInfo.URL), extPluginInfo.Weight))
}

logger.Info(fmt.Sprintf("Created %d external scorers", len(scorers)))
return scorers
}

func (s *Scheduler) generateSchedulerConfig(ctx context.Context, pluginsConfig map[string]int,
externalPlugins config.ExternalPlugins, extraFilters ...plugins.Filter) *scheduling.SchedulerConfig {

thePlugins := s.pluginsFromConfig(ctx, pluginsConfig)
preSchedulePlugins := []plugins.PreSchedule{}
filters := []plugins.Filter{}
Expand Down Expand Up @@ -240,6 +291,13 @@ func (s *Scheduler) generateSchedulerConfig(ctx context.Context, pluginsConfig m
}
}

// add external plugins
preSchedulePlugins = append(preSchedulePlugins, externalPreSchedulesFromConfig(ctx, externalPlugins.PreSchedulers)...)
filters = append(filters, externalFiltersFromConfig(ctx, externalPlugins.Filters)...)
scorers = append(scorers, externalScorersFromConfig(ctx, externalPlugins.Scorers)...)
postSchedulePlugins = append(postSchedulePlugins, externalPostSchedulesFromConfig(ctx, externalPlugins.PostSchedulers)...)
// postResponsePlugins = append(postResponsePlugins, postResponse)

return scheduling.NewSchedulerConfig().
WithPreSchedulePlugins(preSchedulePlugins...).
WithFilters(filters...).
Expand Down
92 changes: 92 additions & 0 deletions pkg/scheduling/plugins/external/http/filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Package http contains all types of external http plugins
package http

import (
"context"
"encoding/json"
"fmt"
"path"

"github.com/valyala/fasthttp"

"sigs.k8s.io/controller-runtime/pkg/log"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// Filter implementation of the external http filter
type Filter struct {
Plugin
}

var _ plugins.Filter = &Filter{} // validate interface conformance

// NewFilter creates a new instance of external http filter based on the given parameters
func NewFilter(_ context.Context, name string, url string) plugins.Filter {
return &Filter{Plugin{name: name, url: url}}
}

// Filter filters the given list of pods
func (f *Filter) Filter(schedContext *types.SchedulingContext, pods []types.Pod) []types.Pod {
logger := log.FromContext(schedContext).WithName(f.Name())

// Create filter http request payload based on the given data
filterPayload := newFilterPayload(schedContext, pods)
payload, err := json.Marshal(filterPayload)

if err != nil {
logger.Error(err, "Failed to marshal scheduling context, filter will be skipped")
return pods
}

// Create a new fasthttp request and response
req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)

req.SetRequestURI(path.Join(f.url, "filter"))
req.Header.SetMethod(fasthttp.MethodPost)
req.Header.SetContentType("application/json")
req.SetBody(payload)

// Execute the request
client := &fasthttp.Client{}
if err := client.Do(req, resp); err != nil {
logger.Error(err, "request failed")
return pods
}

// Optionally check status code
if resp.StatusCode() != fasthttp.StatusOK {
logger.Error(nil, fmt.Sprintf("bad response status: %d, body: %s", resp.StatusCode(), resp.Body()))
return pods
}

var filteredPodNames []namespacedName

// filter plugin response is an array of pod full names
if err := json.Unmarshal(resp.Body(), &filteredPodNames); err != nil {
logger.Error(err, "external filter's response body unmarshal failed", "name", f.Name(), "resp body", resp.Body())
return pods
}

// filter list of given pods based on the returned list of pods
podsNamesSet := map[string]bool{}

for _, nn := range filteredPodNames {
podsNamesSet[namespacedNameToString(nn.Name, nn.Namespace)] = true
}

filteredPods := make([]types.Pod, 0)

for _, p := range pods {
nn := p.GetPod().NamespacedName
if _, exists := podsNamesSet[namespacedNameToString(nn.Name, nn.Namespace)]; exists {
filteredPods = append(filteredPods, p)
}
}

return filteredPods
}
12 changes: 12 additions & 0 deletions pkg/scheduling/plugins/external/http/plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package http

// Plugin common information for external http plugin
type Plugin struct {
name string
url string
}

// Name returns the plugin's name
func (s *Plugin) Name() string {
return s.name
}
Loading