@@ -23,6 +23,7 @@ import (
2323 "fmt"
2424 "math/rand"
2525 "net"
26+ "sort"
2627 "strings"
2728 "time"
2829
@@ -50,6 +51,7 @@ type Datastore interface {
5051 PoolGet () (* datalayer.EndpointPool , error )
5152 ObjectiveGet (objectiveName string ) * v1alpha2.InferenceObjective
5253 PodList (predicate func (backendmetrics.PodMetrics ) bool ) []backendmetrics.PodMetrics
54+ RewriteGetAll () []* v1alpha2.InferenceModelRewrite
5355}
5456
5557// Scheduler defines the interface required by the Director for scheduling.
@@ -110,34 +112,28 @@ func (d *Director) getInferenceObjective(ctx context.Context, reqCtx *handlers.R
110112 return infObjective
111113}
112114
113- // resolveTargetModel is a helper to update reqCtx with target model based on request.
114- func (d * Director ) resolveTargetModel (reqCtx * handlers.RequestContext ) (* handlers.RequestContext , error ) {
115+ // HandleRequest orchestrates the request lifecycle.
116+ // It always returns the requestContext even in the error case, as the request context is used in error handling.
117+ func (d * Director ) HandleRequest (ctx context.Context , reqCtx * handlers.RequestContext ) (* handlers.RequestContext , error ) {
118+ logger := log .FromContext (ctx )
119+
120+ // Parse Request, Resolve Target Models, and Determine Parameters
115121 requestBodyMap := reqCtx .Request .Body
116122 var ok bool
117123 reqCtx .IncomingModelName , ok = requestBodyMap ["model" ].(string )
124+
118125 if ! ok {
119126 return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : "model not found in request body" }
120127 }
121128 if reqCtx .TargetModelName == "" {
122129 // Default to incoming model name
123130 reqCtx .TargetModelName = reqCtx .IncomingModelName
124131 }
125- reqCtx .Request .Body ["model" ] = reqCtx .TargetModelName
126- return reqCtx , nil
127- }
128132
129- // HandleRequest orchestrates the request lifecycle.
130- // It always returns the requestContext even in the error case, as the request context is used in error handling.
131- func (d * Director ) HandleRequest (ctx context.Context , reqCtx * handlers.RequestContext ) (* handlers.RequestContext , error ) {
132- logger := log .FromContext (ctx )
133+ d .applyWeightedModelRewrite (reqCtx )
133134
134- // Resolve target model and update req context.
135- reqCtx , err := d .resolveTargetModel (reqCtx )
136- if err != nil {
137- return reqCtx , err
138- }
135+ reqCtx .Request .Body ["model" ] = reqCtx .TargetModelName
139136
140- // Parse request body.
141137 requestBody , err := requtil .ExtractRequestBody (reqCtx .Request .Body )
142138 if err != nil {
143139 return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : fmt .Errorf ("failed to extract request data: %w" , err ).Error ()}
@@ -198,6 +194,56 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
198194 return reqCtx , nil
199195}
200196
197+ func (d * Director ) applyWeightedModelRewrite (reqCtx * handlers.RequestContext ) {
198+ rewrites := d .datastore .RewriteGetAll ()
199+ if len (rewrites ) == 0 {
200+ return
201+ }
202+
203+ sort .Slice (rewrites , func (i , j int ) bool {
204+ return rewrites [i ].CreationTimestamp .Before (& rewrites [j ].CreationTimestamp )
205+ })
206+
207+ for _ , rewrite := range rewrites {
208+ for _ , rule := range rewrite .Spec .Rules {
209+ for _ , match := range rule .Matches {
210+ if match .Model != nil && match .Model .Value == reqCtx .IncomingModelName {
211+ reqCtx .TargetModelName = d .selectWeightedModel (rule .Targets )
212+ return
213+ }
214+ }
215+ }
216+ }
217+ }
218+
219+ func (d * Director ) selectWeightedModel (models []v1alpha2.TargetModel ) string {
220+ if len (models ) == 0 {
221+ return ""
222+ }
223+
224+ var totalWeight int32
225+ for _ , model := range models {
226+ totalWeight += model .Weight
227+ }
228+
229+ if totalWeight == 0 {
230+ // If total weight is 0, distribute evenly
231+ return models [rand .Intn (len (models ))].ModelRewrite
232+ }
233+
234+ randomNum := rand .Intn (int (totalWeight ))
235+ var currentWeight int32
236+ for _ , model := range models {
237+ currentWeight += model .Weight
238+ if randomNum < int (currentWeight ) {
239+ return model .ModelRewrite
240+ }
241+ }
242+
243+ // Should not happen
244+ return models [len (models )- 1 ].ModelRewrite
245+ }
246+
201247// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
202248// according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies
203249// a subset of endpoints, only these endpoints will be considered as candidates for the scheduler.
0 commit comments