Skip to content

Commit f66be2d

Browse files
authored
adding pre-request plugin to requestcontrol layer (#1004)
* adding pre-request plugin to requestcontrol layer Signed-off-by: Nir Rozenbaum <[email protected]> * updated function names and documentation to address code review comments Signed-off-by: Nir Rozenbaum <[email protected]> * remove unused function arg Signed-off-by: Nir Rozenbaum <[email protected]> --------- Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 7bd2053 commit f66be2d

File tree

3 files changed

+60
-40
lines changed

3 files changed

+60
-40
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"context"
2323
"fmt"
2424
"math/rand"
25+
"net"
2526
"strconv"
2627
"time"
2728

@@ -54,6 +55,7 @@ func NewDirectorWithConfig(datastore datastore.Datastore, scheduler Scheduler, s
5455
datastore: datastore,
5556
scheduler: scheduler,
5657
saturationDetector: saturationDetector,
58+
preRequestPlugins: config.preRequestPlugins,
5759
postResponsePlugins: config.postResponsePlugins,
5860
}
5961
}
@@ -63,14 +65,15 @@ type Director struct {
6365
datastore datastore.Datastore
6466
scheduler Scheduler
6567
saturationDetector SaturationDetector
68+
preRequestPlugins []PreRequest
6669
postResponsePlugins []PostResponse
6770
}
6871

6972
// HandleRequest orchestrates the request lifecycle:
7073
// 1. Parses request details.
71-
// 2. Calls PreDispatch for admission control.
72-
// 3. Calls Dispatch (which calls Scheduler) if request is approved.
73-
// 4. Calls PostDispatch to populate RequestContext with results.
74+
// 2. Calls admitRequest for admission control.
75+
// 3. Calls Scheduler.Schedule if request is approved.
76+
// 4. Calls prepareRequest to populate RequestContext with results and call PreRequest plugins.
7477
//
7578
// It always returns the requestContext even in the error case, as the request context is used in error handling.
7679
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
@@ -117,42 +120,39 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
117120
Prompt: prompt,
118121
Headers: reqCtx.Request.Headers,
119122
}
120-
logger = logger.WithValues(
121-
"model", reqCtx.Model,
122-
"resolvedTargetModel", reqCtx.ResolvedTargetModel,
123-
"criticality", requestCriticality,
124-
)
123+
124+
logger = logger.WithValues("model", reqCtx.Model, "resolvedTargetModel", reqCtx.ResolvedTargetModel, "criticality", requestCriticality)
125125
ctx = log.IntoContext(ctx, logger)
126126
logger.V(logutil.DEBUG).Info("LLM request assembled")
127127

128-
// --- 2. Saturation Check ---
129-
preDispatchErr := d.PreDispatch(ctx, reqCtx, requestCriticality)
130-
if preDispatchErr != nil {
131-
return reqCtx, preDispatchErr
128+
// --- 2. Admission Control check --
129+
if err := d.admitRequest(ctx, requestCriticality); err != nil {
130+
return reqCtx, err
132131
}
133132

134-
// --- 3. Dispatch (Calls Scheduler) ---
135-
results, dispatchErr := d.Dispatch(ctx, reqCtx.SchedulingRequest)
136-
if dispatchErr != nil {
137-
return reqCtx, dispatchErr
133+
// --- 3. Call Scheduler ---
134+
results, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest)
135+
if err != nil {
136+
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
138137
}
139138

140-
// --- 4. PostDispatch (Populates RequestContext) ---
141-
// Insert target endpoint to instruct Envoy to route requests to the specified target pod.
142-
// Attach the port number.
143-
reqCtx, postDispatchErr := d.PostDispatch(ctx, reqCtx, results)
144-
if postDispatchErr != nil {
145-
return reqCtx, postDispatchErr
139+
// --- 4. Prepare Request (Populates RequestContext and call PreRequest plugins) ---
140+
// Insert target endpoint to instruct Envoy to route requests to the specified target pod and attach the port number.
141+
// Invoke PreRequest registered plugins.
142+
reqCtx, err = d.prepareRequest(ctx, reqCtx, results)
143+
if err != nil {
144+
return reqCtx, err
146145
}
147146

148147
return reqCtx, nil
149148
}
150149

151-
// PreDispatch handles admission control before dispatch.
152-
func (d *Director) PreDispatch(ctx context.Context, reqCtx *handlers.RequestContext, reqCriticality v1alpha2.Criticality) error {
150+
// admitRequest handles admission control to decide whether or not to accept the request
151+
// based on the request criticality and system saturation state.
152+
func (d *Director) admitRequest(ctx context.Context, requestCriticality v1alpha2.Criticality) error {
153153
logger := log.FromContext(ctx)
154154

155-
if reqCriticality == v1alpha2.Critical {
155+
if requestCriticality == v1alpha2.Critical {
156156
logger.V(logutil.DEBUG).Info("Critical request bypassing saturation check.")
157157
return nil
158158
}
@@ -164,24 +164,14 @@ func (d *Director) PreDispatch(ctx context.Context, reqCtx *handlers.RequestCont
164164
Msg: "system saturated, non-critical request dropped",
165165
}
166166
}
167-
return nil
168-
}
169-
170-
// Dispatch runs one or many scheduling cycles.
171-
func (d *Director) Dispatch(ctx context.Context, llmReq *schedulingtypes.LLMRequest) (*schedulingtypes.SchedulingResult, error) {
172-
var err error
173-
res, err := d.scheduler.Schedule(ctx, llmReq)
174-
if err != nil {
175-
return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
176-
}
177167

178-
return res, nil // TODO handle multi cycle result after defining the PostDispatch extension point
168+
return nil
179169
}
180170

181-
// PostDispatch populates the RequestContext based on scheduling results.
182-
func (d *Director) PostDispatch(ctx context.Context, reqCtx *handlers.RequestContext, result *schedulingtypes.SchedulingResult) (*handlers.RequestContext, error) {
171+
// prepareRequest populates the RequestContext and calls the registered PreRequest plugins
172+
// for allowing plugging customized logic based on the scheduling results.
173+
func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestContext, result *schedulingtypes.SchedulingResult) (*handlers.RequestContext, error) {
183174
logger := log.FromContext(ctx)
184-
// currently only get a single result. Will refactor to pluggably implement the PostSchedule
185175
if result == nil || len(result.ProfileResults) == 0 {
186176
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"}
187177
}
@@ -192,13 +182,16 @@ func (d *Director) PostDispatch(ctx context.Context, reqCtx *handlers.RequestCon
192182
if err != nil {
193183
return reqCtx, err
194184
}
185+
targetPort := int(pool.Spec.TargetPortNumber)
195186

196-
endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
187+
endpoint := net.JoinHostPort(targetPod.Address, strconv.Itoa(targetPort))
197188
logger.V(logutil.DEFAULT).Info("Request handled", "model", reqCtx.Model, "targetModel", reqCtx.ResolvedTargetModel, "endpoint", targetPod)
198189

199190
reqCtx.TargetPod = targetPod
200191
reqCtx.TargetEndpoint = endpoint
201192

193+
d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result, targetPort)
194+
202195
return reqCtx, nil
203196
}
204197

@@ -254,6 +247,16 @@ func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed
254247
return ""
255248
}
256249

250+
func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult,
251+
targetPort int) {
252+
for _, plugin := range d.preRequestPlugins {
253+
log.FromContext(ctx).V(logutil.DEBUG).Info("Running pre-request plugin", "plugin", plugin.Name())
254+
before := time.Now()
255+
plugin.PreRequest(ctx, request, schedulingResult, targetPort)
256+
metrics.RecordRequestControlPluginProcessingLatency(PreRequestPluginType, plugin.Name(), time.Since(before))
257+
}
258+
}
259+
257260
func (d *Director) runPostResponsePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
258261
for _, plugin := range d.postResponsePlugins {
259262
log.FromContext(ctx).V(logutil.DEBUG).Info("Running post-response plugin", "plugin", plugin.Name())

pkg/epp/requestcontrol/plugins.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,17 @@ import (
2525
)
2626

2727
const (
28+
PreRequestPluginType = "PreRequest"
2829
PostResponsePluginType = "PostResponse"
2930
)
3031

32+
// PreRequest is called by the director after a getting result from scheduling layer and
33+
// before a request is sent to the selected model server.
34+
type PreRequest interface {
35+
plugins.Plugin
36+
PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, targetPort int)
37+
}
38+
3139
// PostResponse is called by the director after a successful response was sent.
3240
// The given pod argument is the pod that served the request.
3341
type PostResponse interface {

pkg/epp/requestcontrol/request_control_config.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,24 @@ import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
2121
// NewConfig creates a new Config object and returns its pointer.
2222
func NewConfig() *Config {
2323
return &Config{
24+
preRequestPlugins: []PreRequest{},
2425
postResponsePlugins: []PostResponse{},
2526
}
2627
}
2728

2829
// Config provides a configuration for the requestcontrol plugins.
2930
type Config struct {
31+
preRequestPlugins []PreRequest
3032
postResponsePlugins []PostResponse
3133
}
3234

35+
// WithPreRequestPlugins sets the given plugins as the PreRequest plugins.
36+
// If the Config has PreRequest plugins already, this call replaces the existing plugins with the given ones.
37+
func (c *Config) WithPreRequestPlugins(plugins ...PreRequest) *Config {
38+
c.preRequestPlugins = plugins
39+
return c
40+
}
41+
3342
// WithPostResponsePlugins sets the given plugins as the PostResponse plugins.
3443
// If the Config has PostResponse plugins already, this call replaces the existing plugins with the given ones.
3544
func (c *Config) WithPostResponsePlugins(plugins ...PostResponse) *Config {

0 commit comments

Comments
 (0)