Skip to content

Commit 5441be2

Browse files
committed
Implements model rewrite and traffic splitting.
1 parent 86de75e commit 5441be2

File tree

4 files changed

+393
-28
lines changed

4 files changed

+393
-28
lines changed

pkg/epp/controller/inferencemodelrewrite_reconciler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,5 @@ func (c *InferenceModelRewriteReconciler) SetupWithManager(ctx context.Context,
8383
}
8484

8585
// eventPredicate reports whether infModelRewrite targets the pool this reconciler is
// responsible for. After this change it returns true only when Spec.PoolRef is non-nil
// and its Name and Group both equal c.PoolGKNN.Name and c.PoolGKNN.Group; the nil check
// comes first so a rewrite without a PoolRef is rejected instead of being dereferenced.
func (c *InferenceModelRewriteReconciler) eventPredicate(infModelRewrite *v1alpha2.InferenceModelRewrite) bool {
86-
return string(infModelRewrite.Spec.PoolRef.Name) == c.PoolGKNN.Name && string(infModelRewrite.Spec.PoolRef.Group) == c.PoolGKNN.Group
86+
return infModelRewrite.Spec.PoolRef != nil && string(infModelRewrite.Spec.PoolRef.Name) == c.PoolGKNN.Name && string(infModelRewrite.Spec.PoolRef.Group) == c.PoolGKNN.Group
8787
}

pkg/epp/controller/inferencemodelrewrite_reconciler_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ func (b *inferenceModelRewriteBuilder) Namespace(ns string) *inferenceModelRewri
9090
}
9191

9292
// PoolName points the rewrite under construction at a pool with the given name.
// It installs a fresh PoolObjectReference carrying only that name (any PoolRef set
// earlier on the builder is replaced) and returns the builder for chaining.
func (b *inferenceModelRewriteBuilder) PoolName(name string) *inferenceModelRewriteBuilder {
	b.Spec.PoolRef = &v1alpha2.PoolObjectReference{
		Name: v1alpha2.ObjectName(name),
	}
	return b
}

pkg/epp/requestcontrol/director.go

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"fmt"
2424
"math/rand"
2525
"net"
26+
"sort"
2627
"strings"
2728
"time"
2829

@@ -50,6 +51,7 @@ type Datastore interface {
5051
PoolGet() (*datalayer.EndpointPool, error)
5152
ObjectiveGet(objectiveName string) *v1alpha2.InferenceObjective
5253
PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
54+
RewriteGetAll() []*v1alpha2.InferenceModelRewrite
5355
}
5456

5557
// Scheduler defines the interface required by the Director for scheduling.
@@ -110,34 +112,28 @@ func (d *Director) getInferenceObjective(ctx context.Context, reqCtx *handlers.R
110112
return infObjective
111113
}
112114

113-
// resolveTargetModel is a helper to update reqCtx with target model based on request.
114-
func (d *Director) resolveTargetModel(reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
115+
// HandleRequest orchestrates the request lifecycle.
116+
// It always returns the requestContext even in the error case, as the request context is used in error handling.
117+
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
118+
logger := log.FromContext(ctx)
119+
120+
// Parse Request, Resolve Target Models, and Determine Parameters
115121
requestBodyMap := reqCtx.Request.Body
116122
var ok bool
117123
reqCtx.IncomingModelName, ok = requestBodyMap["model"].(string)
124+
118125
if !ok {
119126
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request body"}
120127
}
121128
if reqCtx.TargetModelName == "" {
122129
// Default to incoming model name
123130
reqCtx.TargetModelName = reqCtx.IncomingModelName
124131
}
125-
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
126-
return reqCtx, nil
127-
}
128132

129-
// HandleRequest orchestrates the request lifecycle.
130-
// It always returns the requestContext even in the error case, as the request context is used in error handling.
131-
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
132-
logger := log.FromContext(ctx)
133+
d.applyWeightedModelRewrite(reqCtx)
133134

134-
// Resolve target model and update req context.
135-
reqCtx, err := d.resolveTargetModel(reqCtx)
136-
if err != nil {
137-
return reqCtx, err
138-
}
135+
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
139136

140-
// Parse request body.
141137
requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
142138
if err != nil {
143139
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
@@ -198,6 +194,56 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
198194
return reqCtx, nil
199195
}
200196

197+
func (d *Director) applyWeightedModelRewrite(reqCtx *handlers.RequestContext) {
198+
rewrites := d.datastore.RewriteGetAll()
199+
if len(rewrites) == 0 {
200+
return
201+
}
202+
203+
sort.Slice(rewrites, func(i, j int) bool {
204+
return rewrites[i].CreationTimestamp.Before(&rewrites[j].CreationTimestamp)
205+
})
206+
207+
for _, rewrite := range rewrites {
208+
for _, rule := range rewrite.Spec.Rules {
209+
for _, match := range rule.Matches {
210+
if match.Model != nil && match.Model.Value == reqCtx.IncomingModelName {
211+
reqCtx.TargetModelName = d.selectWeightedModel(rule.Targets)
212+
return
213+
}
214+
}
215+
}
216+
}
217+
}
218+
219+
func (d *Director) selectWeightedModel(models []v1alpha2.TargetModel) string {
220+
if len(models) == 0 {
221+
return ""
222+
}
223+
224+
var totalWeight int32
225+
for _, model := range models {
226+
totalWeight += model.Weight
227+
}
228+
229+
if totalWeight == 0 {
230+
// If total weight is 0, distribute evenly
231+
return models[rand.Intn(len(models))].ModelRewrite
232+
}
233+
234+
randomNum := rand.Intn(int(totalWeight))
235+
var currentWeight int32
236+
for _, model := range models {
237+
currentWeight += model.Weight
238+
if randomNum < int(currentWeight) {
239+
return model.ModelRewrite
240+
}
241+
}
242+
243+
// Should not happen
244+
return models[len(models)-1].ModelRewrite
245+
}
246+
201247
// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
202248
// according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies
203249
// a subset of endpoints, only these endpoints will be considered as candidates for the scheduler.

0 commit comments

Comments
 (0)