Skip to content

Commit 8485ba6

Browse files
committed
implements model rewrite and traffic splitting.
1 parent a3b4528 commit 8485ba6

File tree

2 files changed

+401
-20
lines changed

2 files changed

+401
-20
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"fmt"
2424
"math/rand"
2525
"net"
26+
"sort"
2627
"strings"
2728
"time"
2829

@@ -52,6 +53,7 @@ type Datastore interface {
5253
PoolGet() (*v1.InferencePool, error)
5354
ObjectiveGet(objectiveName string) *v1alpha2.InferenceObjective
5455
PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
56+
RewriteGetAll() []*v1alpha2.InferenceModelRewrite
5557
}
5658

5759
// Scheduler defines the interface required by the Director for scheduling.
@@ -133,13 +135,23 @@ func (d *Director) resolveTargetModel(reqCtx *handlers.RequestContext) (*handler
133135
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
134136
logger := log.FromContext(ctx)
135137

136-
// Resolve target model and update req context.
137-
reqCtx, err := d.resolveTargetModel(reqCtx)
138-
if err != nil {
139-
return reqCtx, err
138+
// Parse Request, Resolve Target Models, and Determine Parameters
139+
requestBodyMap := reqCtx.Request.Body
140+
var ok bool
141+
reqCtx.IncomingModelName, ok = requestBodyMap["model"].(string)
142+
143+
if !ok {
144+
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request body"}
140145
}
146+
if reqCtx.TargetModelName == "" {
147+
// Default to incoming model name
148+
reqCtx.TargetModelName = reqCtx.IncomingModelName
149+
}
150+
151+
d.applyWeightedModelRewrite(reqCtx)
152+
153+
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
141154

142-
// Parse request body.
143155
requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
144156
if err != nil {
145157
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
@@ -200,6 +212,56 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
200212
return reqCtx, nil
201213
}
202214

215+
func (d *Director) applyWeightedModelRewrite(reqCtx *handlers.RequestContext) {
216+
rewrites := d.datastore.RewriteGetAll()
217+
if len(rewrites) == 0 {
218+
return
219+
}
220+
221+
sort.Slice(rewrites, func(i, j int) bool {
222+
return rewrites[i].CreationTimestamp.Before(&rewrites[j].CreationTimestamp)
223+
})
224+
225+
for _, rewrite := range rewrites {
226+
for _, rule := range rewrite.Spec.Rules {
227+
for _, match := range rule.Matches {
228+
if match.Model != nil && match.Model.Value == reqCtx.IncomingModelName {
229+
reqCtx.TargetModelName = d.selectWeightedModel(rule.Targets)
230+
return
231+
}
232+
}
233+
}
234+
}
235+
}
236+
237+
func (d *Director) selectWeightedModel(models []v1alpha2.TargetModel) string {
238+
if len(models) == 0 {
239+
return ""
240+
}
241+
242+
var totalWeight int32
243+
for _, model := range models {
244+
totalWeight += model.Weight
245+
}
246+
247+
if totalWeight == 0 {
248+
// If total weight is 0, distribute evenly
249+
return models[rand.Intn(len(models))].ModelRewrite
250+
}
251+
252+
randomNum := rand.Intn(int(totalWeight))
253+
var currentWeight int32
254+
for _, model := range models {
255+
currentWeight += model.Weight
256+
if randomNum < int(currentWeight) {
257+
return model.ModelRewrite
258+
}
259+
}
260+
261+
// Should not happen
262+
return models[len(models)-1].ModelRewrite
263+
}
264+
203265
// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
204266
// according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies
205267
// a subset of endpoints, only these endpoints will be considered as candidates for the scheduler.

0 commit comments

Comments
 (0)