Skip to content

Commit 3322cf8

Browse files
authored
convert subset filter from a plugin to logic in director (#1088)
* convert subset filter from a plugin to logic in director Signed-off-by: Nir Rozenbaum <[email protected]> * replace interface{} with any Signed-off-by: Nir Rozenbaum <[email protected]> * make linter happy Signed-off-by: Nir Rozenbaum <[email protected]> * address code review comments Signed-off-by: Nir Rozenbaum <[email protected]> --------- Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 1ca6875 commit 3322cf8

File tree

23 files changed

+302
-346
lines changed

23 files changed

+302
-346
lines changed

cmd/epp/runner/runner.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ import (
4747
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector"
4848
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
4949
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
50-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/filter"
5150
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
5251
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
5352
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile"
@@ -292,7 +291,6 @@ func (r *Runner) initializeScheduler() (*scheduling.Scheduler, error) {
292291
kvCacheScorerWeight := envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog)
293292

294293
schedulerProfile := framework.NewSchedulerProfile().
295-
WithFilters(filter.NewSubsetFilter()).
296294
WithScorers(framework.NewWeightedScorer(scorer.NewQueueScorer(), queueScorerWeight),
297295
framework.NewWeightedScorer(scorer.NewKVCacheScorer(), kvCacheScorerWeight)).
298296
WithPicker(picker.NewMaxScorePicker())

pkg/bbr/handlers/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ type streamedBody struct {
118118
func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBody, streamedBody *streamedBody, logger logr.Logger) ([]*extProcPb.ProcessingResponse, error) {
119119
loggerVerbose := logger.V(logutil.VERBOSE)
120120

121-
var requestBody map[string]interface{}
121+
var requestBody map[string]any
122122
if s.streaming {
123123
streamedBody.body = append(streamedBody.body, body.Body...)
124124
// In the stream case, we can receive multiple request bodies.

pkg/epp/backend/metrics/metrics_state.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import (
2121
"time"
2222
)
2323

24-
// newMetricsState initializes a new MetricsState and returns its pointer.
25-
func newMetricsState() *MetricsState {
24+
// NewMetricsState initializes a new MetricsState and returns its pointer.
25+
func NewMetricsState() *MetricsState {
2626
return &MetricsState{
2727
ActiveModels: make(map[string]int),
2828
WaitingModels: make(map[string]int),

pkg/epp/backend/metrics/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1.
5151
logger: log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName),
5252
}
5353
pm.pod.Store(pod)
54-
pm.metrics.Store(newMetricsState())
54+
pm.metrics.Store(NewMetricsState())
5555

5656
pm.startRefreshLoop(parentCtx)
5757
return pm

pkg/epp/handlers/response.go

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,15 @@ const (
3434
)
3535

3636
// HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling.
37-
func (s *StreamingServer) HandleResponseBody(
38-
ctx context.Context,
39-
reqCtx *RequestContext,
40-
response map[string]interface{},
41-
) (*RequestContext, error) {
37+
func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, response map[string]any) (*RequestContext, error) {
4238
logger := log.FromContext(ctx)
4339
responseBytes, err := json.Marshal(response)
4440
if err != nil {
4541
logger.V(logutil.DEFAULT).Error(err, "error marshalling responseBody")
4642
return reqCtx, err
4743
}
4844
if response["usage"] != nil {
49-
usg := response["usage"].(map[string]interface{})
45+
usg := response["usage"].(map[string]any)
5046
usage := Usage{
5147
PromptTokens: int(usg["prompt_tokens"].(float64)),
5248
CompletionTokens: int(usg["completion_tokens"].(float64)),
@@ -68,11 +64,7 @@ func (s *StreamingServer) HandleResponseBody(
6864
}
6965

7066
// The function is to handle streaming response if the modelServer is streaming.
71-
func (s *StreamingServer) HandleResponseBodyModelStreaming(
72-
ctx context.Context,
73-
reqCtx *RequestContext,
74-
responseText string,
75-
) {
67+
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
7668
if strings.Contains(responseText, streamingEndMsg) {
7769
resp := parseRespForUsage(ctx, responseText)
7870
reqCtx.Usage = resp.Usage
@@ -160,10 +152,7 @@ func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*con
160152
//
161153
// If include_usage is not included in the request, `data: [DONE]` is returned separately, which
162154
// indicates end of streaming.
163-
func parseRespForUsage(
164-
ctx context.Context,
165-
responseText string,
166-
) ResponseBody {
155+
func parseRespForUsage(ctx context.Context, responseText string) ResponseBody {
167156
response := ResponseBody{}
168157
logger := log.FromContext(ctx)
169158

pkg/epp/handlers/response_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func TestHandleResponseBody(t *testing.T) {
8686
if reqCtx == nil {
8787
reqCtx = &RequestContext{}
8888
}
89-
var responseMap map[string]interface{}
89+
var responseMap map[string]any
9090
marshalErr := json.Unmarshal(test.body, &responseMap)
9191
if marshalErr != nil {
9292
t.Error(marshalErr, "Error unmarshaling request body")

pkg/epp/handlers/server.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ type RequestContext struct {
112112

113113
type Request struct {
114114
Headers map[string]string
115-
Body map[string]interface{}
115+
Body map[string]any
116116
Metadata map[string]any
117117
}
118118
type Response struct {
@@ -143,7 +143,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
143143
RequestState: RequestReceived,
144144
Request: &Request{
145145
Headers: make(map[string]string),
146-
Body: make(map[string]interface{}),
146+
Body: make(map[string]any),
147147
Metadata: make(map[string]any),
148148
},
149149
Response: &Response{
@@ -152,7 +152,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
152152
}
153153

154154
var body []byte
155-
var responseBody map[string]interface{}
155+
var responseBody map[string]any
156156

157157
// Create error handling var as each request should only report once for
158158
// error metrics. This doesn't cover the error "Cannot receive stream request" because
@@ -308,7 +308,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
308308
// Handle the err and fire an immediate response.
309309
if err != nil {
310310
logger.V(logutil.DEFAULT).Error(err, "Failed to process request", "request", req)
311-
resp, err := BuildErrResponse(err)
311+
resp, err := buildErrResponse(err)
312312
if err != nil {
313313
return err
314314
}
@@ -389,7 +389,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
389389
return nil
390390
}
391391

392-
func BuildErrResponse(err error) (*extProcPb.ProcessingResponse, error) {
392+
func buildErrResponse(err error) (*extProcPb.ProcessingResponse, error) {
393393
var resp *extProcPb.ProcessingResponse
394394

395395
switch errutil.CanonicalCode(err) {
@@ -416,6 +416,17 @@ func BuildErrResponse(err error) (*extProcPb.ProcessingResponse, error) {
416416
},
417417
},
418418
}
419+
// This code can be returned by the director when there are no candidate pods for the request scheduling.
420+
case errutil.ServiceUnavailable:
421+
resp = &extProcPb.ProcessingResponse{
422+
Response: &extProcPb.ProcessingResponse_ImmediateResponse{
423+
ImmediateResponse: &extProcPb.ImmediateResponse{
424+
Status: &envoyTypePb.HttpStatus{
425+
Code: envoyTypePb.StatusCode_ServiceUnavailable,
426+
},
427+
},
428+
},
429+
}
419430
// This code can be returned when users provide invalid json request.
420431
case errutil.BadRequest:
421432
resp = &extProcPb.ProcessingResponse{

pkg/epp/requestcontrol/director.go

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@ import (
2424
"math/rand"
2525
"net"
2626
"strconv"
27+
"strings"
2728
"time"
2829

2930
"github.com/go-logr/logr"
3031
"sigs.k8s.io/controller-runtime/pkg/log"
3132
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
3233
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
34+
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
3335
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
3436
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
3537
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
@@ -39,6 +41,11 @@ import (
3941
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
4042
)
4143

44+
const (
45+
subsetHintNamespace = "envoy.lb.subset_hint"
46+
subsetHintKey = "x-gateway-destination-endpoint-subset"
47+
)
48+
4249
// Scheduler defines the interface required by the Director for scheduling.
4350
type Scheduler interface {
4451
Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) (result *schedulingtypes.SchedulingResult, err error)
@@ -118,12 +125,12 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
118125
}
119126

120127
// Prepare LLMRequest (needed for both saturation detection and Scheduler)
121-
reqCtx.SchedulingRequest = schedulingtypes.NewLLMRequest(
122-
reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
123-
reqCtx.ResolvedTargetModel,
124-
prompt,
125-
reqCtx.Request.Headers,
126-
reqCtx.Request.Metadata)
128+
reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
129+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
130+
TargetModel: reqCtx.ResolvedTargetModel,
131+
Prompt: prompt,
132+
Headers: reqCtx.Request.Headers,
133+
}
127134

128135
logger = logger.WithValues("model", reqCtx.Model, "resolvedTargetModel", reqCtx.ResolvedTargetModel, "criticality", requestCriticality)
129136

@@ -135,11 +142,11 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
135142
return reqCtx, err
136143
}
137144

138-
// --- 3. Call Scheduler ---
139-
// Snapshot pod metrics from the datastore to:
140-
// 1. Reduce concurrent access to the datastore.
141-
// 2. Ensure consistent data during the scheduling operation of a request between all scheduling cycles.
142-
candidatePods := schedulingtypes.ToSchedulerPodMetrics(d.datastore.PodGetAll())
145+
// --- 3. Call Scheduler (with the relevant candidate pods) ---
146+
candidatePods := d.getCandidatePodsForScheduling(ctx, reqCtx.Request.Metadata)
147+
if len(candidatePods) == 0 {
148+
return reqCtx, errutil.Error{Code: errutil.ServiceUnavailable, Msg: "failed to find candidate pods for serving the request"}
149+
}
143150
results, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, candidatePods)
144151
if err != nil {
145152
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
@@ -177,6 +184,52 @@ func (d *Director) admitRequest(ctx context.Context, requestCriticality v1alpha2
177184
return nil
178185
}
179186

187+
// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
188+
// according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies
189+
// a subset of endpoints, only these endpoints will be considered as candidates for the scheduler.
190+
// Snapshot pod metrics from the datastore to:
191+
// 1. Reduce concurrent access to the datastore.
192+
// 2. Ensure consistent data during the scheduling operation of a request between all scheduling cycles.
193+
func (d *Director) getCandidatePodsForScheduling(ctx context.Context, requestMetadata map[string]any) []schedulingtypes.Pod {
194+
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
195+
196+
subsetMap, found := requestMetadata[subsetHintNamespace].(map[string]any)
197+
if !found {
198+
return schedulingtypes.ToSchedulerPodMetrics(d.datastore.PodGetAll())
199+
}
200+
201+
// Check if endpoint key is present in the subset map and ensure there is at least one value
202+
endpointSubsetList, found := subsetMap[subsetHintKey].([]any)
203+
if !found {
204+
return schedulingtypes.ToSchedulerPodMetrics(d.datastore.PodGetAll())
205+
} else if len(endpointSubsetList) == 0 {
206+
loggerTrace.Info("found empty subset filter in request metadata, filtering all pods")
207+
return []schedulingtypes.Pod{}
208+
}
209+
210+
// Create a map of endpoint addresses for easy lookup
211+
endpoints := make(map[string]bool)
212+
for _, endpoint := range endpointSubsetList {
213+
// Extract address from endpoint
214+
// The endpoint is formatted as "<address>:<port>" (ex. "10.0.1.0:8080")
215+
epStr := strings.Split(endpoint.(string), ":")[0]
216+
endpoints[epStr] = true
217+
}
218+
219+
podTotalCount := 0
220+
podFitleredList := d.datastore.PodList(func(pm backendmetrics.PodMetrics) bool {
221+
podTotalCount++
222+
if _, found := endpoints[pm.GetPod().Address]; found {
223+
return true
224+
}
225+
return false
226+
})
227+
228+
loggerTrace.Info("filtered candidate pods by subset filtering", "podTotalCount", podTotalCount, "filteredCount", len(podFitleredList))
229+
230+
return schedulingtypes.ToSchedulerPodMetrics(podFitleredList)
231+
}
232+
180233
// prepareRequest populates the RequestContext and calls the registered PreRequest plugins
181234
// for allowing plugging customized logic based on the scheduling results.
182235
func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestContext, result *schedulingtypes.SchedulingResult) (*handlers.RequestContext, error) {

0 commit comments

Comments
 (0)