2 changes: 1 addition & 1 deletion cmd/epp/runner/runner.go
@@ -354,7 +354,7 @@ func (r *Runner) Run(ctx context.Context) error {

scheduler := scheduling.NewSchedulerWithConfig(r.schedulerConfig)

-saturationDetector := saturationdetector.NewDetector(sdConfig, datastore, setupLog)
+saturationDetector := saturationdetector.NewDetector(sdConfig, setupLog)

director := requestcontrol.NewDirectorWithConfig(datastore, scheduler, saturationDetector, r.requestControlConfig)
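Note: the datastore argument to NewDetector is gone because saturation is now evaluated against the candidate pods that the director passes to IsSaturated on each request (see the director.go changes below), so the detector no longer needs its own view of the datastore.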

63 changes: 34 additions & 29 deletions pkg/epp/requestcontrol/director.go
@@ -29,10 +29,10 @@ import (

"sigs.k8s.io/controller-runtime/pkg/log"

v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
@@ -42,30 +42,38 @@ import (
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
)

+// Datastore defines the interface required by the Director.
+type Datastore interface {
+PoolGet() (*v1.InferencePool, error)
+ObjectiveGet(modelName string) *v1alpha2.InferenceObjective
+PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
+}
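An implementation only needs these three methods; the mockDatastore added to director_test.go below is a minimal in-memory example.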

// Scheduler defines the interface required by the Director for scheduling.
type Scheduler interface {
Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) (result *schedulingtypes.SchedulingResult, err error)
}

// SaturationDetector provides a signal indicating whether the backends are considered saturated.
type SaturationDetector interface {
-IsSaturated(ctx context.Context) bool
+IsSaturated(ctx context.Context, candidatePods []backendmetrics.PodMetrics) bool
}
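An implementation of this interface now works purely off the supplied candidate pods instead of querying the datastore, which is what lets runner.go drop the datastore argument above. A minimal sketch, assuming PodMetrics exposes a GetMetrics() snapshot with a WaitingQueueSize field; the queueDepthDetector type and its threshold are hypothetical, not the actual saturationdetector implementation:

// Sketch only: "saturated" here means no candidate pod has queue headroom.
// queueDepthDetector and maxQueueDepth are illustrative assumptions.
type queueDepthDetector struct {
	maxQueueDepth int
}

func (d *queueDepthDetector) IsSaturated(ctx context.Context, candidatePods []backendmetrics.PodMetrics) bool {
	for _, pod := range candidatePods {
		if m := pod.GetMetrics(); m != nil && m.WaitingQueueSize < d.maxQueueDepth {
			return false // at least one pod can still absorb work
		}
	}
	return true // no pod under the threshold (or no pods at all)
}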

// NewDirectorWithConfig creates a new Director instance with all dependencies.
-func NewDirectorWithConfig(datastore datastore.Datastore, scheduler Scheduler, saturationDetector SaturationDetector, config *Config) *Director {
+func NewDirectorWithConfig(datastore Datastore, scheduler Scheduler, saturationDetector SaturationDetector, config *Config) *Director {
return &Director{
datastore: datastore,
scheduler: scheduler,
saturationDetector: saturationDetector,
preRequestPlugins: config.preRequestPlugins,
postResponsePlugins: config.postResponsePlugins,
+defaultPriority: 0, // define default priority explicitly
}
}

// Director orchestrates the request handling flow, including scheduling.
type Director struct {
-datastore datastore.Datastore
+datastore Datastore
scheduler Scheduler
saturationDetector SaturationDetector
preRequestPlugins []PreRequest
@@ -76,17 +84,12 @@ type Director struct {
defaultPriority int
}

-// HandleRequest orchestrates the request lifecycle:
-// 1. Parses request details.
-// 2. Calls admitRequest for admission control.
-// 3. Calls Scheduler.Schedule if request is approved.
-// 4. Calls prepareRequest to populate RequestContext with result and call PreRequest plugins.
-//
+// HandleRequest orchestrates the request lifecycle.
// It always returns the requestContext even in the error case, as the request context is used in error handling.
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx)

-// --- 1. Parse Request, Resolve Target Models, and Determine Parameters ---
+// Parse Request, Resolve Target Models, and Determine Parameters
requestBodyMap := reqCtx.Request.Body
var ok bool
reqCtx.IncomingModelName, ok = requestBodyMap["model"].(string)
@@ -130,22 +133,23 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
ctx = log.IntoContext(ctx, logger)
logger.V(logutil.DEBUG).Info("LLM request assembled")

-// --- 2. Admission Control check --
-if err := d.admitRequest(ctx, *infObjective.Spec.Priority, reqCtx.FairnessID); err != nil {
-return reqCtx, err
-}
-
-// --- 3. Call Scheduler (with the relevant candidate pods) ---
+// Get candidate pods for scheduling
candidatePods := d.getCandidatePodsForScheduling(ctx, reqCtx.Request.Metadata)
if len(candidatePods) == 0 {
return reqCtx, errutil.Error{Code: errutil.ServiceUnavailable, Msg: "failed to find candidate pods for serving the request"}
}
-result, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, candidatePods)

+// Admission Control check
+if err := d.admitRequest(ctx, candidatePods, *infObjective.Spec.Priority, reqCtx.FairnessID); err != nil {
[Review comment from a Collaborator: "I like that this dedupes the candidatePods == 0 check that was in the admitRequest func."]

+return reqCtx, err
+}

+result, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, d.toSchedulerPodMetrics(candidatePods))
if err != nil {
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
}

-// --- 4. Prepare Request (Populates RequestContext and call PreRequest plugins) ---
+// Prepare Request (Populates RequestContext and call PreRequest plugins)
// Insert target endpoint to instruct Envoy to route requests to the specified target pod and attach the port number.
// Invoke PreRequest registered plugins.
reqCtx, err = d.prepareRequest(ctx, reqCtx, result)
@@ -158,20 +162,21 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {

// admitRequest handles admission control to decide whether or not to accept the request
// based on the request priority and system saturation state.
-func (d *Director) admitRequest(ctx context.Context, requestPriority int, fairnessID string) error {
-logger := log.FromContext(ctx)
+func (d *Director) admitRequest(ctx context.Context, candidatePods []backendmetrics.PodMetrics,
+requestPriority int, fairnessID string) error {
+loggerTrace := log.FromContext(ctx).V(logutil.TRACE)

logger.V(logutil.TRACE).Info("Entering Flow Control", "priority", requestPriority, "fairnessID", fairnessID)
loggerTrace.Info("Entering Flow Control", "priority", requestPriority, "fairnessID", fairnessID)

// This will be removed in favor of a more robust implementation (Flow Control) in the very near future.
// TODO: Make this a configurable value.
// Tracking issue https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1347
if requestPriority >= 0 {
logger.V(logutil.TRACE).Info("Non-sheddable request bypassing saturation check.")
loggerTrace.Info("Non-sheddable request bypassing saturation check.")
return nil
}

-if d.saturationDetector.IsSaturated(ctx) { // Assuming non-nil Saturation Detector
+if d.saturationDetector.IsSaturated(ctx, candidatePods) {
return errutil.Error{
Code: errutil.InferencePoolResourceExhausted,
Msg: "system saturated, sheddable request dropped",
@@ -187,21 +192,21 @@ func (d *Director) getCandidatePodsForScheduling(ctx context.Context, requestMetadata map[string]any) []schedulingtypes.Pod {
// Snapshot pod metrics from the datastore to:
// 1. Reduce concurrent access to the datastore.
// 2. Ensure consistent data during the scheduling operation of a request between all scheduling cycles.
-func (d *Director) getCandidatePodsForScheduling(ctx context.Context, requestMetadata map[string]any) []schedulingtypes.Pod {
+func (d *Director) getCandidatePodsForScheduling(ctx context.Context, requestMetadata map[string]any) []backendmetrics.PodMetrics {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)

subsetMap, found := requestMetadata[metadata.SubsetFilterNamespace].(map[string]any)
if !found {
-return d.toSchedulerPodMetrics(d.datastore.PodList(backendmetrics.AllPodsPredicate))
+return d.datastore.PodList(backendmetrics.AllPodsPredicate)
}

// Check if endpoint key is present in the subset map and ensure there is at least one value
endpointSubsetList, found := subsetMap[metadata.SubsetFilterKey].([]any)
if !found {
-return d.toSchedulerPodMetrics(d.datastore.PodList(backendmetrics.AllPodsPredicate))
+return d.datastore.PodList(backendmetrics.AllPodsPredicate)
} else if len(endpointSubsetList) == 0 {
loggerTrace.Info("found empty subset filter in request metadata, filtering all pods")
-return []schedulingtypes.Pod{}
+return []backendmetrics.PodMetrics{}
}

// Create a map of endpoint addresses for easy lookup
@@ -224,7 +229,7 @@

loggerTrace.Info("filtered candidate pods by subset filtering", "podTotalCount", podTotalCount, "filteredCount", len(podFitleredList))

-return d.toSchedulerPodMetrics(podFitleredList)
+return podFitleredList
}
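For reference, the subset filter consumed above arrives in the request metadata shaped roughly as follows, mirroring the test helper makeFilterMetadata below; the addresses are illustrative:

	requestMetadata := map[string]any{
		metadata.SubsetFilterNamespace: map[string]any{
			metadata.SubsetFilterKey: []any{"10.0.0.1", "10.0.0.2"}, // endpoint addresses to keep
		},
	}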

// prepareRequest populates the RequestContext and calls the registered PreRequest plugins
103 changes: 37 additions & 66 deletions pkg/epp/requestcontrol/director_test.go
@@ -34,6 +34,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client/fake"

v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
@@ -53,7 +54,7 @@ type mockSaturationDetector struct {
isSaturated bool
}

-func (m *mockSaturationDetector) IsSaturated(_ context.Context) bool {
+func (m *mockSaturationDetector) IsSaturated(_ context.Context, _ []backendmetrics.PodMetrics) bool {
return m.isSaturated
}

@@ -66,6 +67,23 @@ func (m *mockScheduler) Schedule(_ context.Context, _ *schedulingtypes.LLMRequest, _ []schedulingtypes.Pod) (*schedulingtypes.SchedulingResult, error) {
return m.scheduleResults, m.scheduleErr
}

+type mockDatastore struct {
+pods []backendmetrics.PodMetrics
+}
+
+func (ds *mockDatastore) PoolGet() (*v1.InferencePool, error) { return nil, nil }
+func (ds *mockDatastore) ObjectiveGet(_ string) *v1alpha2.InferenceObjective { return nil }
+func (ds *mockDatastore) PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics {
+res := []backendmetrics.PodMetrics{}
+for _, pod := range ds.pods {
+if predicate(pod) {
+res = append(res, pod)
+}
+}
+
+return res
+}

func TestDirector_HandleRequest(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())

@@ -450,119 +468,72 @@ func TestGetCandidatePodsForScheduling(t *testing.T) {
}
}

-testInput := []*corev1.Pod{
-{
-ObjectMeta: metav1.ObjectMeta{
-Name: "pod1",
-},
-Status: corev1.PodStatus{
-PodIP: "10.0.0.1",
-},
-},
-{
-ObjectMeta: metav1.ObjectMeta{
-Name: "pod2",
-},
-Status: corev1.PodStatus{
-PodIP: "10.0.0.2",
-},
-},
-}

-outputPod1 := &backend.Pod{
+pod1 := &backend.Pod{
NamespacedName: types.NamespacedName{Name: "pod1"},
Address: "10.0.0.1",
Labels: map[string]string{},
}

-outputPod2 := &backend.Pod{
+pod2 := &backend.Pod{
NamespacedName: types.NamespacedName{Name: "pod2"},
Address: "10.0.0.2",
Labels: map[string]string{},
}

+testInput := []backendmetrics.PodMetrics{
+&backendmetrics.FakePodMetrics{Pod: pod1},
+&backendmetrics.FakePodMetrics{Pod: pod2},
+}

tests := []struct {
name string
metadata map[string]any
-output []schedulingtypes.Pod
+output []backendmetrics.PodMetrics
}{
{
name: "SubsetFilter, filter not present — return all pods",
metadata: map[string]any{},
-output: []schedulingtypes.Pod{
-&schedulingtypes.PodMetrics{
-Pod: outputPod1,
-MetricsState: backendmetrics.NewMetricsState(),
-},
-&schedulingtypes.PodMetrics{
-Pod: outputPod2,
-MetricsState: backendmetrics.NewMetricsState(),
-},
-},
+output: testInput,
},
{
name: "SubsetFilter, namespace present filter not present — return all pods",
metadata: map[string]any{metadata.SubsetFilterNamespace: map[string]any{}},
-output: []schedulingtypes.Pod{
-&schedulingtypes.PodMetrics{
-Pod: outputPod1,
-MetricsState: backendmetrics.NewMetricsState(),
-},
-&schedulingtypes.PodMetrics{
-Pod: outputPod2,
-MetricsState: backendmetrics.NewMetricsState(),
-},
-},
+output: testInput,
},
{
name: "SubsetFilter, filter present with empty list — return error",
metadata: makeFilterMetadata([]any{}),
-output: []schedulingtypes.Pod{},
+output: []backendmetrics.PodMetrics{},
},
{
name: "SubsetFilter, subset with one matching pod",
metadata: makeFilterMetadata([]any{"10.0.0.1"}),
-output: []schedulingtypes.Pod{
-&schedulingtypes.PodMetrics{
-Pod: outputPod1,
-MetricsState: backendmetrics.NewMetricsState(),
+output: []backendmetrics.PodMetrics{
+&backendmetrics.FakePodMetrics{
+Pod: pod1,
},
},
},
{
name: "SubsetFilter, subset with multiple matching pods",
metadata: makeFilterMetadata([]any{"10.0.0.1", "10.0.0.2", "10.0.0.3"}),
-output: []schedulingtypes.Pod{
-&schedulingtypes.PodMetrics{
-Pod: outputPod1,
-MetricsState: backendmetrics.NewMetricsState(),
-},
-&schedulingtypes.PodMetrics{
-Pod: outputPod2,
-MetricsState: backendmetrics.NewMetricsState(),
-},
-},
+output: testInput,
},
{
name: "SubsetFilter, subset with no matching pods",
metadata: makeFilterMetadata([]any{"10.0.0.3"}),
-output: []schedulingtypes.Pod{},
+output: []backendmetrics.PodMetrics{},
},
}

-pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
-ds := datastore.NewDatastore(t.Context(), pmf)
-for _, testPod := range testInput {
-ds.PodUpdateOrAddIfNotExist(testPod)
-}
-
+ds := &mockDatastore{pods: testInput}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
director := NewDirectorWithConfig(ds, &mockScheduler{}, &mockSaturationDetector{}, NewConfig())

got := director.getCandidatePodsForScheduling(context.Background(), test.metadata)

-diff := cmp.Diff(test.output, got, cmpopts.SortSlices(func(a, b schedulingtypes.Pod) bool {
+diff := cmp.Diff(test.output, got, cmpopts.SortSlices(func(a, b backendmetrics.PodMetrics) bool {
return a.GetPod().NamespacedName.String() < b.GetPod().NamespacedName.String()
}))
if diff != "" {