Skip to content

Commit 72df012

Browse files
authored
Add PrepareData and Admission control plugins (#1796)
* Refactor director to split into smaller functions
* Add AdmitRequest and PrepareData plugins
* Add unit tests and comments
* Add comments
* Address review comments
* Make PrepareData step time bound and execute all preparedata plugins in parallel
* Update interface names based on suggestions
* Update test and remove duplicate AttributeMap. Address other review comments.
* Execute prepare data plugins sequentially with retries and timeout. Also added more tests and some refactoring
* Update prefix match plugin to implement PrepareData plugin
* Add back stashed changes. Update outdated comments.
* Update function names and remove extra methods. Also don't fail request if prepare data call fails
* Make prepare data plugins execution have a total timeout
* Minor improvements
* Revert back plugin changes
1 parent 7fae938 commit 72df012

File tree

10 files changed

+518
-29
lines changed

10 files changed

+518
-29
lines changed

pkg/epp/backend/metrics/fake.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ import (
3232

3333
// FakePodMetrics is an implementation of PodMetrics that doesn't run the async refresh loop.
3434
type FakePodMetrics struct {
35-
Pod *backend.Pod
36-
Metrics *MetricsState
35+
Pod *backend.Pod
36+
Metrics *MetricsState
37+
Attributes *datalayer.Attributes
3738
}
3839

3940
func (fpm *FakePodMetrics) String() string {
@@ -51,6 +52,9 @@ func (fpm *FakePodMetrics) GetMetrics() *MetricsState {
5152
func (fpm *FakePodMetrics) UpdatePod(pod *datalayer.PodInfo) {
5253
fpm.Pod = pod
5354
}
55+
func (fpm *FakePodMetrics) GetAttributes() *datalayer.Attributes {
56+
return fpm.Attributes
57+
}
5458

5559
func (*FakePodMetrics) Put(string, datalayer.Cloneable) {}
5660
func (*FakePodMetrics) Get(string) (datalayer.Cloneable, bool) { return nil, false }

pkg/epp/backend/metrics/pod_metrics.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ func (pm *podMetrics) stopRefreshLoop() {
126126
func (*podMetrics) Put(string, datalayer.Cloneable) {}
127127
func (*podMetrics) Get(string) (datalayer.Cloneable, bool) { return nil, false }
128128
func (*podMetrics) Keys() []string { return nil }
129+
func (*podMetrics) GetAttributes() *datalayer.Attributes {
130+
return nil
131+
}
129132

130133
func (pm *podMetrics) UpdateMetrics(updated *MetricsState) {
131134
updated.UpdateTime = time.Now()

pkg/epp/datalayer/endpoint.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
type EndpointPodState interface {
2626
GetPod() *PodInfo
2727
UpdatePod(*PodInfo)
28+
GetAttributes() *Attributes
2829
}
2930

3031
// EndpointMetricsState allows management of the Metrics related attributes.
@@ -98,6 +99,10 @@ func (srv *ModelServer) Keys() []string {
9899
return srv.attributes.Keys()
99100
}
100101

102+
func (srv *ModelServer) GetAttributes() *Attributes {
103+
return srv.attributes
104+
}
105+
101106
func (srv *ModelServer) Clone() *ModelServer {
102107
clone := &ModelServer{
103108
attributes: srv.attributes.Clone(),

pkg/epp/requestcontrol/director.go

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
3333
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
3434
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
35+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3536
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
3637
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
3738
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
@@ -41,6 +42,11 @@ import (
4142
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
4243
)
4344

45+
const (
46+
// TODO: Make these configurable per plugin via config.
47+
prepareDataTimeout = 400 * time.Millisecond
48+
)
49+
4450
// Datastore defines the interface required by the Director.
4551
type Datastore interface {
4652
PoolGet() (*v1.InferencePool, error)
@@ -89,16 +95,28 @@ type Director struct {
8995
defaultPriority int
9096
}
9197

92-
// HandleRequest orchestrates the request lifecycle.
93-
// It always returns the requestContext even in the error case, as the request context is used in error handling.
94-
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
95-
logger := log.FromContext(ctx)
98+
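// getInferenceObjective fetches the InferenceObjective for reqCtx.ObjectiveKey from the datastore. If none is found, it returns a new objective carrying the director's default priority; if one is found without a priority, the default priority is filled in.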
99+
func (d *Director) getInferenceObjective(ctx context.Context, reqCtx *handlers.RequestContext) *v1alpha2.InferenceObjective {
100+
infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
101+
if infObjective == nil {
102+
log.FromContext(ctx).V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
103+
infObjective = &v1alpha2.InferenceObjective{
104+
Spec: v1alpha2.InferenceObjectiveSpec{
105+
Priority: &d.defaultPriority,
106+
},
107+
}
108+
} else if infObjective.Spec.Priority == nil {
109+
// Fall back to the director's default priority if the objective does not specify one.
110+
infObjective.Spec.Priority = &d.defaultPriority
111+
}
112+
return infObjective
113+
}
96114

97-
// Parse Request, Resolve Target Models, and Determine Parameters
115+
// resolveTargetModel is a helper to update reqCtx with target model based on request.
116+
func (d *Director) resolveTargetModel(reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
98117
requestBodyMap := reqCtx.Request.Body
99118
var ok bool
100119
reqCtx.IncomingModelName, ok = requestBodyMap["model"].(string)
101-
102120
if !ok {
103121
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request body"}
104122
}
@@ -107,24 +125,28 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
107125
reqCtx.TargetModelName = reqCtx.IncomingModelName
108126
}
109127
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
128+
return reqCtx, nil
129+
}
110130

131+
// HandleRequest orchestrates the request lifecycle.
132+
// It always returns the requestContext even in the error case, as the request context is used in error handling.
133+
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
134+
logger := log.FromContext(ctx)
135+
136+
// Resolve target model and update req context.
137+
reqCtx, err := d.resolveTargetModel(reqCtx)
138+
if err != nil {
139+
return reqCtx, err
140+
}
141+
142+
// Parse request body.
111143
requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
112144
if err != nil {
113145
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
114146
}
115147

116-
infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
117-
if infObjective == nil {
118-
logger.V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
119-
infObjective = &v1alpha2.InferenceObjective{
120-
Spec: v1alpha2.InferenceObjectiveSpec{
121-
Priority: &d.defaultPriority,
122-
},
123-
}
124-
} else if infObjective.Spec.Priority == nil {
125-
// Default to 0 if not specified.
126-
infObjective.Spec.Priority = &d.defaultPriority
127-
}
148+
// Parse inference objective.
149+
infObjective := d.getInferenceObjective(ctx, reqCtx)
128150

129151
// Prepare LLMRequest (needed for both saturation detection and Scheduler)
130152
reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
@@ -144,13 +166,25 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
144166
if len(candidatePods) == 0 {
145167
return reqCtx, errutil.Error{Code: errutil.ServiceUnavailable, Msg: "failed to find candidate pods for serving the request"}
146168
}
147-
148169
if err := d.admissionController.Admit(ctx, reqCtx, candidatePods, *infObjective.Spec.Priority); err != nil {
149170
logger.V(logutil.DEFAULT).Info("Request rejected by admission control", "error", err)
150171
return reqCtx, err
151172
}
173+
snapshotOfCandidatePods := d.toSchedulerPodMetrics(candidatePods)
152174

153-
result, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, d.toSchedulerPodMetrics(candidatePods))
175+
// Prepare per request data by running PrepareData plugins.
176+
if err := d.runPrepareDataPlugins(ctx, reqCtx.SchedulingRequest, snapshotOfCandidatePods); err != nil {
177+
// Don't fail the request if PrepareData plugins fail; log the error they returned and continue.
178+
logger.V(logutil.DEFAULT).Error(err, "failed to prepare per request data")
179+
}
180+
181+
// Run admit request plugins
182+
if !d.runAdmissionPlugins(ctx, reqCtx.SchedulingRequest, snapshotOfCandidatePods) {
183+
logger.V(logutil.DEFAULT).Info("Request cannot be admitted")
184+
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "request cannot be admitted"}
185+
}
186+
187+
result, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, snapshotOfCandidatePods)
154188
if err != nil {
155189
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
156190
}
@@ -244,7 +278,11 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
244278
func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []schedulingtypes.Pod {
245279
pm := make([]schedulingtypes.Pod, len(pods))
246280
for i, pod := range pods {
247-
pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetPod().Clone(), MetricsState: pod.GetMetrics().Clone()}
281+
if pod.GetAttributes() != nil {
282+
pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetPod().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: pod.GetAttributes().Clone()}
283+
} else {
284+
pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetPod().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: datalayer.NewAttributes()}
285+
}
248286
}
249287

250288
return pm
@@ -315,6 +353,29 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
315353
}
316354
}
317355

356+
// TODO: Execute plugins in parallel once DAG execution is supported.
357+
// runPrepareDataPlugins executes PrepareDataPlugins sequentially.
358+
func (d *Director) runPrepareDataPlugins(ctx context.Context,
359+
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
360+
return prepareDataPluginsWithTimeout(
361+
prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods)
362+
363+
}
364+
365+
func (d *Director) runAdmissionPlugins(ctx context.Context,
366+
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) bool {
367+
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
368+
for _, plugin := range d.requestControlPlugins.admissionPlugins {
369+
loggerDebug.Info("Running AdmitRequest plugin", "plugin", plugin.TypedName())
370+
if denyReason := plugin.AdmitRequest(ctx, request, pods); denyReason != nil {
371+
loggerDebug.Info("AdmitRequest plugin denied the request", "plugin", plugin.TypedName(), "reason", denyReason.Error())
372+
return false
373+
}
374+
loggerDebug.Info("Completed running AdmitRequest plugin successfully", "plugin", plugin.TypedName())
375+
}
376+
return true
377+
}
378+
318379
func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
319380
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
320381
for _, plugin := range d.requestControlPlugins.responseReceivedPlugins {

0 commit comments

Comments
 (0)