@@ -23,6 +23,7 @@ import (
2323 "fmt"
2424 "math/rand"
2525 "net"
26+ "sort"
2627 "strings"
2728 "time"
2829
@@ -52,6 +53,7 @@ type Datastore interface {
5253 PoolGet () (* v1.InferencePool , error )
5354 ObjectiveGet (objectiveName string ) * v1alpha2.InferenceObjective
5455 PodList (predicate func (backendmetrics.PodMetrics ) bool ) []backendmetrics.PodMetrics
56+ RewriteGetAll () []* v1alpha2.InferenceModelRewrite
5557}
5658
5759// Scheduler defines the interface required by the Director for scheduling.
@@ -133,13 +135,23 @@ func (d *Director) resolveTargetModel(reqCtx *handlers.RequestContext) (*handler
133135func (d * Director ) HandleRequest (ctx context.Context , reqCtx * handlers.RequestContext ) (* handlers.RequestContext , error ) {
134136 logger := log .FromContext (ctx )
135137
136- // Resolve target model and update req context.
137- reqCtx , err := d .resolveTargetModel (reqCtx )
138- if err != nil {
139- return reqCtx , err
138+ // Parse Request, Resolve Target Models, and Determine Parameters
139+ requestBodyMap := reqCtx .Request .Body
140+ var ok bool
141+ reqCtx .IncomingModelName , ok = requestBodyMap ["model" ].(string )
142+
143+ if ! ok {
144+ return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : "model not found in request body" }
140145 }
146+ if reqCtx .TargetModelName == "" {
147+ // Default to incoming model name
148+ reqCtx .TargetModelName = reqCtx .IncomingModelName
149+ }
150+
151+ d .applyWeightedModelRewrite (reqCtx )
152+
153+ reqCtx .Request .Body ["model" ] = reqCtx .TargetModelName
141154
142- // Parse request body.
143155 requestBody , err := requtil .ExtractRequestBody (reqCtx .Request .Body )
144156 if err != nil {
145157 return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : fmt .Errorf ("failed to extract request data: %w" , err ).Error ()}
@@ -200,6 +212,56 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
200212 return reqCtx , nil
201213}
202214
215+ func (d * Director ) applyWeightedModelRewrite (reqCtx * handlers.RequestContext ) {
216+ rewrites := d .datastore .RewriteGetAll ()
217+ if len (rewrites ) == 0 {
218+ return
219+ }
220+
221+ sort .Slice (rewrites , func (i , j int ) bool {
222+ return rewrites [i ].CreationTimestamp .Before (& rewrites [j ].CreationTimestamp )
223+ })
224+
225+ for _ , rewrite := range rewrites {
226+ for _ , rule := range rewrite .Spec .Rules {
227+ for _ , match := range rule .Matches {
228+ if match .Model != nil && match .Model .Value == reqCtx .IncomingModelName {
229+ reqCtx .TargetModelName = d .selectWeightedModel (rule .Targets )
230+ return
231+ }
232+ }
233+ }
234+ }
235+ }
236+
237+ func (d * Director ) selectWeightedModel (models []v1alpha2.TargetModel ) string {
238+ if len (models ) == 0 {
239+ return ""
240+ }
241+
242+ var totalWeight int32
243+ for _ , model := range models {
244+ totalWeight += model .Weight
245+ }
246+
247+ if totalWeight == 0 {
248+ // If total weight is 0, distribute evenly
249+ return models [rand .Intn (len (models ))].ModelRewrite
250+ }
251+
252+ randomNum := rand .Intn (int (totalWeight ))
253+ var currentWeight int32
254+ for _ , model := range models {
255+ currentWeight += model .Weight
256+ if randomNum < int (currentWeight ) {
257+ return model .ModelRewrite
258+ }
259+ }
260+
261+ // Should not happen
262+ return models [len (models )- 1 ].ModelRewrite
263+ }
264+
203265// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
204266// according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies
205267// a subset of endpoints, only these endpoints will be considered as candidates for the scheduler.
0 commit comments