diff --git a/apix/v1alpha2/inferencemodelrewrite_types.go b/apix/v1alpha2/inferencemodelrewrite_types.go index ef68d8366..262238c28 100644 --- a/apix/v1alpha2/inferencemodelrewrite_types.go +++ b/apix/v1alpha2/inferencemodelrewrite_types.go @@ -57,20 +57,25 @@ type InferenceModelRewriteSpec struct { // If multiple InferenceModelRewrite resources target the same // InferencePool, the controller will merge them based on precedence. // - // **Timestamp Wins:** If two rules from different rewrites all matches, - // the rule from the *oldest* - // InferenceModelRewrite resource (determined by - // metadata.creationTimestamp) will be used. + // Across all rules specified on applicable rewrites, precedence MUST be + // given to the match having an "Exact" model match over a generic match + // (a rule with an empty `matches` array). + // + // If ties still exist across multiple InferenceModelRewrite resources (e.g. + // two rewrites both have an exact match for the same model), matching + // precedence MUST be determined by the oldest resource based on + // creation timestamp. + // + // If ties still exist within a single InferenceModelRewrite resource, the + // FIRST matching rule (in list order) is used. // +required Rules []InferenceModelRewriteRule `json:"rules"` } // InferenceModelRewriteRule defines the match criteria and corresponding action. -// -// A specific model name can only be matched by one rule across all -// rules attached to the same InferencePool. If multiple rules attempt -// to match the same model name, the oldest rule (by creationTimestamp) -// will be the only one considered valid. +// For details on how precedence is determined across multiple rules and +// InferenceModelRewrite resources, see the "Precedence and Conflict Resolution" +// section in InferenceModelRewriteSpec. type InferenceModelRewriteRule struct { // Matches defines the criteria for matching a request. // If multiple match criteria are specified, a request matches if diff --git a/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml b/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml index bb9b3e6cf..2680ea091 100644 --- a/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml +++ b/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml @@ -74,11 +74,9 @@ spec: items: description: |- InferenceModelRewriteRule defines the match criteria and corresponding action. - - A specific model name can only be matched by one rule across all - rules attached to the same InferencePool. If multiple rules attempt - to match the same model name, the oldest rule (by creationTimestamp) - will be the only one considered valid. + For details on how precedence is determined across multiple rules and + InferenceModelRewrite resources, see the "Precedence and Conflict Resolution" + section in InferenceModelRewriteSpec. properties: matches: items: diff --git a/docs/proposals/1816-inferenceomodelrewrite/README.md b/docs/proposals/1816-inferenceomodelrewrite/README.md index 9f20b36fe..93cfd6d4c 100644 --- a/docs/proposals/1816-inferenceomodelrewrite/README.md +++ b/docs/proposals/1816-inferenceomodelrewrite/README.md @@ -64,20 +64,25 @@ type InferenceModelRewriteSpec struct { // If multiple InferenceModelRewrite resources target the same // InferencePool, the controller will merge them based on precedence. // - // **Timestamp Wins:** If two rules from different rewrite all matches, - // the rule from the *oldest* - // InferenceModelRewrite resource (determined by - // metadata.creationTimestamp) will be used. + // Across all rules specified on applicable rewrites, precedence MUST be + // given to the match having an "Exact" model match over a generic match + // (a rule with an empty `matches` array). + // + // If ties still exist across multiple InferenceModelRewrite resources (e.g. + // two rewrites both have an exact match for the same model), matching + // precedence MUST be determined by the oldest resource based on + // creation timestamp. + // + // If ties still exist within a single InferenceModelRewrite resource, the + // FIRST matching rule (in list order) is used. // +required Rules []InferenceModelRewriteRule `json:"rules"` } // InferenceModelRewriteRule defines the match criteria and corresponding action. -// -// A specific model name can only be matched by one rule across all -// rewrites attached to the same InferencePool. If multiple rules attempt -// to match the same model name, the oldest rule (by creationTimestamp) -// will be the only one considered valid. +// For details on how precedence is determined across multiple rules and +// InferenceModelRewrite resources, see the "Precedence and Conflict Resolution" +// section in InferenceModelRewriteSpec. type InferenceModelRewriteRule struct { // Matches defines the criteria for matching a request. // If multiple match criteria are specified, a request matches if diff --git a/pkg/epp/controller/inferencemodelrewrite_reconciler.go b/pkg/epp/controller/inferencemodelrewrite_reconciler.go new file mode 100644 index 000000000..611147a83 --- /dev/null +++ b/pkg/epp/controller/inferencemodelrewrite_reconciler.go @@ -0,0 +1,87 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package controller + +import ( + "context" + "fmt" + + "k8s.io/apimachinery/pkg/api/errors" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/predicate" + + "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/common" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +type InferenceModelRewriteReconciler struct { + client.Reader + Datastore datastore.Datastore + PoolGKNN common.GKNN +} + +func (c *InferenceModelRewriteReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + logger := log.FromContext(ctx).V(logutil.DEFAULT) + ctx = ctrl.LoggerInto(ctx, logger) + + logger.Info("Reconciling InferenceModelRewrite") + + infModelRewrite := &v1alpha2.InferenceModelRewrite{} + notFound := false + if err := c.Get(ctx, req.NamespacedName, infModelRewrite); err != nil { + if !errors.IsNotFound(err) { + return ctrl.Result{}, fmt.Errorf("unable to get InferenceModelRewrite - %w", err) + } + notFound = true + } + + if notFound || !infModelRewrite.DeletionTimestamp.IsZero() || infModelRewrite.Spec.PoolRef == nil || infModelRewrite.Spec.PoolRef.Name != v1alpha2.ObjectName(c.PoolGKNN.Name) || infModelRewrite.Spec.PoolRef.Group != v1alpha2.Group(c.PoolGKNN.Group) { + // InferenceModelRewrite object got deleted or changed the referenced pool. + c.Datastore.RewriteDelete(req.NamespacedName) + return ctrl.Result{}, nil + } + + // Add or update if the InferenceModelRewrite instance has a creation timestamp older than the existing entry of the model. + logger = logger.WithValues("poolRef", infModelRewrite.Spec.PoolRef) + c.Datastore.RewriteSet(infModelRewrite) + logger.Info("Added/Updated InferenceModelRewrite") + + return ctrl.Result{}, nil +} + +func (c *InferenceModelRewriteReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&v1alpha2.InferenceModelRewrite{}). + WithEventFilter(predicate.Funcs{ + CreateFunc: func(e event.CreateEvent) bool { return c.eventPredicate(e.Object.(*v1alpha2.InferenceModelRewrite)) }, + UpdateFunc: func(e event.UpdateEvent) bool { + return c.eventPredicate(e.ObjectOld.(*v1alpha2.InferenceModelRewrite)) || c.eventPredicate(e.ObjectNew.(*v1alpha2.InferenceModelRewrite)) + }, + DeleteFunc: func(e event.DeleteEvent) bool { return c.eventPredicate(e.Object.(*v1alpha2.InferenceModelRewrite)) }, + GenericFunc: func(e event.GenericEvent) bool { return c.eventPredicate(e.Object.(*v1alpha2.InferenceModelRewrite)) }, + }). + Complete(c) +} + +func (c *InferenceModelRewriteReconciler) eventPredicate(infModelRewrite *v1alpha2.InferenceModelRewrite) bool { + return infModelRewrite.Spec.PoolRef != nil && string(infModelRewrite.Spec.PoolRef.Name) == c.PoolGKNN.Name && string(infModelRewrite.Spec.PoolRef.Group) == c.PoolGKNN.Group +} diff --git a/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go new file mode 100644 index 000000000..8365aff0b --- /dev/null +++ b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go @@ -0,0 +1,216 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package controller + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + + v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" + "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/common" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" +) + +var ( + poolForRewrite = utiltest.MakeInferencePool("test-pool1").Namespace("ns1").ObjRef() + rewrite1 = &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite1", + Namespace: poolForRewrite.Namespace, + CreationTimestamp: metav1.Unix(1000, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, + }, + } + rewrite1Pool2 = &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: rewrite1.Name, + Namespace: rewrite1.Namespace, + CreationTimestamp: metav1.Unix(1001, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + PoolRef: &v1alpha2.PoolObjectReference{Name: "test-pool2"}, + }, + } + rewrite1Updated = &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: rewrite1.Name, + Namespace: rewrite1.Namespace, + CreationTimestamp: metav1.Unix(1003, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, + Rules: []v1alpha2.InferenceModelRewriteRule{{}}, + }, + } + rewrite1Deleted = &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: rewrite1.Name, + Namespace: rewrite1.Namespace, + CreationTimestamp: metav1.Unix(1004, 0), + DeletionTimestamp: &metav1.Time{Time: time.Now()}, + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, + }, + } + rewrite2 = &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite2", + Namespace: poolForRewrite.Namespace, + CreationTimestamp: metav1.Unix(1001, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, + }, + } +) + +func TestInferenceModelRewriteReconciler(t *testing.T) { + tests := []struct { + name string + rewritesInStore []*v1alpha2.InferenceModelRewrite + rewritesInAPIServer []*v1alpha2.InferenceModelRewrite + rewrite *v1alpha2.InferenceModelRewrite + incomingReq *types.NamespacedName + wantRewrites []*v1alpha2.InferenceModelRewrite + wantResult ctrl.Result + }{ + { + name: "Empty store, add new rewrite", + rewrite: rewrite1, + wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1}, + }, + { + name: "Existing rewrite changed pools", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + rewrite: rewrite1Pool2, + wantRewrites: []*v1alpha2.InferenceModelRewrite{}, + }, + { + name: "Not found, delete existing rewrite", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + incomingReq: &types.NamespacedName{Name: rewrite1.Name, Namespace: rewrite1.Namespace}, + wantRewrites: []*v1alpha2.InferenceModelRewrite{}, + }, + { + name: "Deletion timestamp set, delete existing rewrite", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + rewrite: rewrite1Deleted, + incomingReq: &types.NamespacedName{Name: rewrite1Deleted.Name, Namespace: rewrite1Deleted.Namespace}, + wantRewrites: []*v1alpha2.InferenceModelRewrite{}, + }, + { + name: "Rewrite updated", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + rewrite: rewrite1Updated, + wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1Updated}, + }, + { + name: "Rewrite not found, no matching existing rewrite to delete", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + incomingReq: &types.NamespacedName{Name: "non-existent-rewrite", Namespace: poolForRewrite.Namespace}, + wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1}, + }, + { + name: "Add to existing", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + rewrite: rewrite2, + wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1, rewrite2}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + _ = v1alpha2.Install(scheme) + _ = v1.Install(scheme) + initObjs := []client.Object{} + if test.rewrite != nil && test.rewrite.DeletionTimestamp.IsZero() { + initObjs = append(initObjs, test.rewrite) + } + for _, r := range test.rewritesInAPIServer { + initObjs = append(initObjs, r) + } + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(initObjs...). + Build() + pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) + ds := datastore.NewDatastore(t.Context(), pmf, 0) + for _, r := range test.rewritesInStore { + ds.RewriteSet(r) + } + _ = ds.PoolSet(context.Background(), fakeClient, poolForRewrite) + reconciler := &InferenceModelRewriteReconciler{ + Reader: fakeClient, + Datastore: ds, + PoolGKNN: common.GKNN{ + NamespacedName: types.NamespacedName{Name: poolForRewrite.Name, Namespace: poolForRewrite.Namespace}, + GroupKind: schema.GroupKind{Group: poolForRewrite.GroupVersionKind().Group, Kind: poolForRewrite.GroupVersionKind().Kind}, + }, + } + if test.incomingReq == nil { + test.incomingReq = &types.NamespacedName{Name: test.rewrite.Name, Namespace: test.rewrite.Namespace} + } + + result, err := reconciler.Reconcile(context.Background(), ctrl.Request{NamespacedName: *test.incomingReq}) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if diff := cmp.Diff(result, test.wantResult); diff != "" { + t.Errorf("Unexpected result diff (+got/-want): %s", diff) + } + + if len(test.wantRewrites) != len(ds.RewriteGetAll()) { + t.Errorf("Unexpected number of rewrites; want: %d, got:%d", len(test.wantRewrites), len(ds.RewriteGetAll())) + } + + if diff := diffStoreRewrites(ds, test.wantRewrites); diff != "" { + t.Errorf("Unexpected diff (+got/-want): %s", diff) + } + }) + } +} + +func diffStoreRewrites(ds datastore.Datastore, wantRewrites []*v1alpha2.InferenceModelRewrite) string { + if wantRewrites == nil { + wantRewrites = []*v1alpha2.InferenceModelRewrite{} + } + + gotRewrites := ds.RewriteGetAll() + if diff := cmp.Diff(wantRewrites, gotRewrites); diff != "" { + return "rewrites:" + diff + } + return "" +} diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index 2ab2e98cb..8b5642a6d 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -59,6 +59,12 @@ type Datastore interface { ObjectiveDelete(namespacedName types.NamespacedName) ObjectiveGetAll() []*v1alpha2.InferenceObjective + // InferenceModelRewrite operations + RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) + RewriteDelete(namespacedName types.NamespacedName) + RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule + RewriteGetAll() []*v1alpha2.InferenceModelRewrite + // PodList lists pods matching the given predicate. PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool @@ -72,9 +78,10 @@ func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory // Initialize with defaults store := &datastore{ parentCtx: parentCtx, - poolAndObjectivesMu: sync.RWMutex{}, pool: nil, + mu: sync.RWMutex{}, objectives: make(map[string]*v1alpha2.InferenceObjective), + rewrites: NewModelRewriteStore(), pods: &sync.Map{}, modelServerMetricsPort: modelServerMetricsPort, epf: epFactory, @@ -91,11 +98,13 @@ func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory type datastore struct { // parentCtx controls the lifecycle of the background metrics goroutines that spawn up by the datastore. parentCtx context.Context - // poolAndObjectivesMu is used to synchronize access to pool and the objectives map. - poolAndObjectivesMu sync.RWMutex - pool *datalayer.EndpointPool - // key: InferenceObjective.Spec.ModelName, value: *InferenceObjective + // mu is used to synchronize access to pool, objectives, and rewrites. + mu sync.RWMutex + pool *datalayer.EndpointPool + // key: InferenceObjective name, value: *InferenceObjective objectives map[string]*v1alpha2.InferenceObjective + // rewrites store for InferenceModelRewrite objects. + rewrites *ModelRewriteStore // key: types.NamespacedName, value: backendmetrics.PodMetrics pods *sync.Map // modelServerMetricsPort metrics port from EPP command line argument @@ -105,10 +114,11 @@ type datastore struct { } func (ds *datastore) Clear() { - ds.poolAndObjectivesMu.Lock() - defer ds.poolAndObjectivesMu.Unlock() + ds.mu.Lock() + defer ds.mu.Unlock() ds.pool = nil ds.objectives = make(map[string]*v1alpha2.InferenceObjective) + ds.rewrites = NewModelRewriteStore() // stop all pods go routines before clearing the pods map. ds.pods.Range(func(_, v any) bool { ds.epf.ReleaseEndpoint(v.(backendmetrics.PodMetrics)) @@ -124,8 +134,8 @@ func (ds *datastore) PoolSet(ctx context.Context, reader client.Reader, endpoint return nil } logger := log.FromContext(ctx) - ds.poolAndObjectivesMu.Lock() - defer ds.poolAndObjectivesMu.Unlock() + ds.mu.Lock() + defer ds.mu.Unlock() oldEndpointPool := ds.pool ds.pool = endpointPool @@ -146,8 +156,8 @@ func (ds *datastore) PoolSet(ctx context.Context, reader client.Reader, endpoint } func (ds *datastore) PoolGet() (*datalayer.EndpointPool, error) { - ds.poolAndObjectivesMu.RLock() - defer ds.poolAndObjectivesMu.RUnlock() + ds.mu.RLock() + defer ds.mu.RUnlock() if !ds.PoolHasSynced() { return nil, errPoolNotSynced } @@ -155,14 +165,14 @@ func (ds *datastore) PoolGet() (*datalayer.EndpointPool, error) { } func (ds *datastore) PoolHasSynced() bool { - ds.poolAndObjectivesMu.RLock() - defer ds.poolAndObjectivesMu.RUnlock() + ds.mu.RLock() + defer ds.mu.RUnlock() return ds.pool != nil } func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool { - ds.poolAndObjectivesMu.RLock() - defer ds.poolAndObjectivesMu.RUnlock() + ds.mu.RLock() + defer ds.mu.RUnlock() if ds.pool == nil { return false } @@ -171,39 +181,59 @@ func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool { return poolSelector.Matches(podSet) } +// /// InferenceObjective APIs /// func (ds *datastore) ObjectiveSet(infObjective *v1alpha2.InferenceObjective) { - ds.poolAndObjectivesMu.Lock() - defer ds.poolAndObjectivesMu.Unlock() - // Set the objective. + ds.mu.Lock() + defer ds.mu.Unlock() ds.objectives[infObjective.Name] = infObjective } func (ds *datastore) ObjectiveGet(objectiveName string) *v1alpha2.InferenceObjective { - ds.poolAndObjectivesMu.RLock() - defer ds.poolAndObjectivesMu.RUnlock() - iObj, ok := ds.objectives[objectiveName] - if !ok { - return nil - } - return iObj + ds.mu.RLock() + defer ds.mu.RUnlock() + return ds.objectives[objectiveName] } func (ds *datastore) ObjectiveDelete(namespacedName types.NamespacedName) { - ds.poolAndObjectivesMu.Lock() - defer ds.poolAndObjectivesMu.Unlock() + ds.mu.Lock() + defer ds.mu.Unlock() delete(ds.objectives, namespacedName.Name) } func (ds *datastore) ObjectiveGetAll() []*v1alpha2.InferenceObjective { - ds.poolAndObjectivesMu.RLock() - defer ds.poolAndObjectivesMu.RUnlock() - res := []*v1alpha2.InferenceObjective{} + ds.mu.RLock() + defer ds.mu.RUnlock() + res := make([]*v1alpha2.InferenceObjective, 0, len(ds.objectives)) for _, v := range ds.objectives { res = append(res, v) } return res } +func (ds *datastore) RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) { + ds.mu.Lock() + defer ds.mu.Unlock() + ds.rewrites.Set(infModelRewrite) +} + +func (ds *datastore) RewriteDelete(namespacedName types.NamespacedName) { + ds.mu.Lock() + defer ds.mu.Unlock() + ds.rewrites.Delete(namespacedName) +} + +func (ds *datastore) RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule { + ds.mu.RLock() + defer ds.mu.RUnlock() + return ds.rewrites.GetRule(modelName) +} + +func (ds *datastore) RewriteGetAll() []*v1alpha2.InferenceModelRewrite { + ds.mu.RLock() + defer ds.mu.RUnlock() + return ds.rewrites.GetAll() +} + // /// Pods/endpoints APIs /// // TODO: add a flag for callers to specify the staleness threshold for metrics. // ref: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/1046#discussion_r2246351694 diff --git a/pkg/epp/datastore/modelrewritestore.go b/pkg/epp/datastore/modelrewritestore.go new file mode 100644 index 000000000..2ad21d41f --- /dev/null +++ b/pkg/epp/datastore/modelrewritestore.go @@ -0,0 +1,187 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package datastore + +import ( + "sort" + "time" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" +) + +// ModelRewriteStore encapsulates the logic for storing and retrieving +// InferenceModelRewrite rules, handling precedence correctly. This struct is not +// thread-safe; concurrency must be managed by its consumer. +type ModelRewriteStore struct { + genericRules []*rewriteRuleWithMetadata + rulesByExactModelMatch map[string][]*rewriteRuleWithMetadata + allReWrites map[types.NamespacedName]*v1alpha2.InferenceModelRewrite +} + +func NewModelRewriteStore() *ModelRewriteStore { + return &ModelRewriteStore{ + genericRules: []*rewriteRuleWithMetadata{}, + rulesByExactModelMatch: map[string][]*rewriteRuleWithMetadata{}, + allReWrites: map[types.NamespacedName]*v1alpha2.InferenceModelRewrite{}, + } +} + +// Set adds or updates an InferenceModelRewrite in the store. It deconstructs the +// object into individual rules and stores them in the appropriate data structures, +// ensuring they remain sorted by precedence. +func (ms *ModelRewriteStore) Set(infModelRewrite *v1alpha2.InferenceModelRewrite) { + nn := getNN(infModelRewrite) + + // If the rewrite object already exists, remove its old rules before adding new ones. + if _, ok := ms.allReWrites[nn]; ok { + ms.deleteInternal(nn) + } + ms.allReWrites[nn] = infModelRewrite + + for i := range infModelRewrite.Spec.Rules { + ruleWithMetadata := newRuleWithMetadata(infModelRewrite, i) + if ruleWithMetadata == nil { + continue + } + + if ruleWithMetadata.isGeneric() { + ms.genericRules = append(ms.genericRules, ruleWithMetadata) + } else { + for model := range ruleWithMetadata.exactModels() { + ms.rulesByExactModelMatch[model] = append(ms.rulesByExactModelMatch[model], ruleWithMetadata) + } + } + } + + // Sort all rule lists by timestamp to maintain precedence. + sort.Slice(ms.genericRules, func(i, j int) bool { + return ms.genericRules[i].createTimestamp.Before(ms.genericRules[j].createTimestamp) + }) + + for model := range ms.rulesByExactModelMatch { + sort.Slice(ms.rulesByExactModelMatch[model], func(i, j int) bool { + return ms.rulesByExactModelMatch[model][i].createTimestamp.Before(ms.rulesByExactModelMatch[model][j].createTimestamp) + }) + } +} + +// Delete removes an InferenceModelRewrite and all its associated rules from the store. +func (ms *ModelRewriteStore) Delete(nn types.NamespacedName) { + ms.deleteInternal(nn) +} + +// deleteInternal is the non-locking implementation for deleting a rewrite. +func (ms *ModelRewriteStore) deleteInternal(nn types.NamespacedName) { + if _, ok := ms.allReWrites[nn]; !ok { + return + } + delete(ms.allReWrites, nn) + + // Filter out the generic rules associated with the deleted rewrite. + newGenericRules := make([]*rewriteRuleWithMetadata, 0, len(ms.genericRules)) + for _, ruleWithMd := range ms.genericRules { + if ruleWithMd.parentNN() != nn { + newGenericRules = append(newGenericRules, ruleWithMd) + } + } + ms.genericRules = newGenericRules + + // Filter out the exact-match rules associated with the deleted rewrite. + for modelName, rulesWithMd := range ms.rulesByExactModelMatch { + newRules := make([]*rewriteRuleWithMetadata, 0, len(rulesWithMd)) + for _, r := range rulesWithMd { + if r.parentNN() != nn { + newRules = append(newRules, r) + } + } + + if len(newRules) == 0 { + delete(ms.rulesByExactModelMatch, modelName) + } else { + ms.rulesByExactModelMatch[modelName] = newRules + } + } +} + +// GetRule returns the single, highest-precedence rule for a given model name. +// It prioritizes exact matches over generic ones, and among those, the oldest rule wins. +func (ms *ModelRewriteStore) GetRule(modelName string) *v1alpha2.InferenceModelRewriteRule { + // Exact matches have the highest precedence. + if rulesWithMd, ok := ms.rulesByExactModelMatch[modelName]; ok && len(rulesWithMd) > 0 { + return &rulesWithMd[0].rule // The list is pre-sorted, so the first element is the oldest. + } + + // If no exact match, fall back to the oldest generic rule. + if len(ms.genericRules) > 0 { + return &ms.genericRules[0].rule // The list is pre-sorted. + } + return nil +} + +// GetAll returns a slice of all InferenceModelRewrite objects currently in the store. +func (ms *ModelRewriteStore) GetAll() []*v1alpha2.InferenceModelRewrite { + rewrites := make([]*v1alpha2.InferenceModelRewrite, 0, len(ms.allReWrites)) + for _, rewrite := range ms.allReWrites { + rewrites = append(rewrites, rewrite) + } + return rewrites +} + +func getNN(infModelRewrite *v1alpha2.InferenceModelRewrite) types.NamespacedName { + return types.NamespacedName{ + Namespace: infModelRewrite.Namespace, + Name: infModelRewrite.Name, + } +} + +// rewriteRuleWithMetadata decorates a rule with metadata from its parent object +// to be used in precedence sorting. +type rewriteRuleWithMetadata struct { + rule v1alpha2.InferenceModelRewriteRule + createTimestamp time.Time + parentRewriteNN types.NamespacedName +} + +func newRuleWithMetadata(infModelRewrite *v1alpha2.InferenceModelRewrite, ruleIdx int) *rewriteRuleWithMetadata { + if ruleIdx >= len(infModelRewrite.Spec.Rules) { + return nil + } + return &rewriteRuleWithMetadata{ + rule: infModelRewrite.Spec.Rules[ruleIdx], + createTimestamp: infModelRewrite.CreationTimestamp.Time, + parentRewriteNN: getNN(infModelRewrite), + } +} + +func (rr rewriteRuleWithMetadata) isGeneric() bool { + return len(rr.rule.Matches) == 0 +} + +func (rr rewriteRuleWithMetadata) exactModels() map[string]bool { + modelSet := map[string]bool{} + for _, match := range rr.rule.Matches { + if match.Model != nil { + modelSet[match.Model.Value] = true + } + } + return modelSet +} + +func (rr rewriteRuleWithMetadata) parentNN() types.NamespacedName { + return rr.parentRewriteNN +} diff --git a/pkg/epp/datastore/modelrewritestore_test.go b/pkg/epp/datastore/modelrewritestore_test.go new file mode 100644 index 000000000..41a1dc088 --- /dev/null +++ b/pkg/epp/datastore/modelrewritestore_test.go @@ -0,0 +1,185 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package datastore + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" +) + +func TestModelRewriteStore(t *testing.T) { + now := time.Now() + oneMinuteAgo := now.Add(-1 * time.Minute) + + // Define common rules with generic names + ruleModel1V1 := v1alpha2.InferenceModelRewriteRule{ + Matches: []v1alpha2.Match{{Model: &v1alpha2.ModelMatch{Value: "model1"}}}, + Targets: []v1alpha2.TargetModel{{ModelRewrite: "model1-v1"}}, + } + ruleModel1V2 := v1alpha2.InferenceModelRewriteRule{ + Matches: []v1alpha2.Match{{Model: &v1alpha2.ModelMatch{Value: "model1"}}}, + Targets: []v1alpha2.TargetModel{{ModelRewrite: "model1-v2"}}, + } + ruleGeneric := v1alpha2.InferenceModelRewriteRule{ + Targets: []v1alpha2.TargetModel{{ModelRewrite: "generic-fallback"}}, + } + + // Define rewrite objects using plain structs + rewriteOld := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{Name: "rewrite-old", Namespace: "default", CreationTimestamp: metav1.NewTime(oneMinuteAgo)}, + Spec: v1alpha2.InferenceModelRewriteSpec{Rules: []v1alpha2.InferenceModelRewriteRule{ruleModel1V1}}, + } + rewriteNew := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{Name: "rewrite-new", Namespace: "default", CreationTimestamp: metav1.NewTime(now)}, + Spec: v1alpha2.InferenceModelRewriteSpec{Rules: []v1alpha2.InferenceModelRewriteRule{ruleModel1V2}}, + } + rewriteGenericOld := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{Name: "rewrite-generic-old", Namespace: "default", CreationTimestamp: metav1.NewTime(oneMinuteAgo)}, + Spec: v1alpha2.InferenceModelRewriteSpec{Rules: []v1alpha2.InferenceModelRewriteRule{ruleGeneric}}, + } + rewriteGenericNew := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{Name: "rewrite-generic-new", Namespace: "default", CreationTimestamp: metav1.NewTime(now)}, + Spec: v1alpha2.InferenceModelRewriteSpec{Rules: []v1alpha2.InferenceModelRewriteRule{{Targets: []v1alpha2.TargetModel{{ModelRewrite: "new-generic"}}}}}, + } + rewriteUpdated := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{Name: "rewrite-old", Namespace: "default", CreationTimestamp: metav1.NewTime(now)}, // Same name as rewriteOld + Spec: v1alpha2.InferenceModelRewriteSpec{Rules: []v1alpha2.InferenceModelRewriteRule{ruleModel1V2}}, + } + + tests := []struct { + name string + initialState []*v1alpha2.InferenceModelRewrite + op func(store *ModelRewriteStore) + modelToGet string + wantRule *v1alpha2.InferenceModelRewriteRule + wantGetAll []*v1alpha2.InferenceModelRewrite + }{ + { + name: "Simple exact match", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + modelToGet: "model1", + wantRule: &ruleModel1V1, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + }, + { + name: "Simple generic match", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteGenericOld}, + modelToGet: "model2", // A different model to test generic fallback + wantRule: &ruleGeneric, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteGenericOld}, + }, + { + name: "No match", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + modelToGet: "model2", + wantRule: nil, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + }, + { + name: "Precedence: Exact match wins over generic", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld, rewriteGenericOld}, + modelToGet: "model1", + wantRule: &ruleModel1V1, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteOld, rewriteGenericOld}, + }, + { + name: "Precedence: Fallback to generic when no exact match", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld, rewriteGenericOld}, + modelToGet: "model2", + wantRule: &ruleGeneric, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteOld, rewriteGenericOld}, + }, + { + name: "Precedence: Oldest exact match wins", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteNew, rewriteOld}, + modelToGet: "model1", + wantRule: &ruleModel1V1, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteNew, rewriteOld}, + }, + { + name: "Precedence: Oldest generic match wins", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteGenericNew, rewriteGenericOld}, + modelToGet: "any-model", + wantRule: &ruleGeneric, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteGenericNew, rewriteGenericOld}, + }, + { + name: "Delete: successfully deletes a rewrite", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld, rewriteGenericOld}, + op: func(store *ModelRewriteStore) { + store.Delete(getNN(rewriteOld)) + }, + modelToGet: "model1", + wantRule: &ruleGeneric, // Falls back to generic + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteGenericOld}, + }, + { + name: "Delete: non-existent rewrite does nothing", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + op: func(store *ModelRewriteStore) { + store.Delete(types.NamespacedName{Name: "non-existent", Namespace: "default"}) + }, + modelToGet: "model1", + wantRule: &ruleModel1V1, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + }, + { + name: "Update: Setting a rewrite with the same name replaces the old one", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + op: func(store *ModelRewriteStore) { + store.Set(rewriteUpdated) + }, + modelToGet: "model1", + wantRule: &ruleModel1V2, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteUpdated}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + store := NewModelRewriteStore() + for _, r := range tc.initialState { + store.Set(r) + } + + if tc.op != nil { + tc.op(store) + } + + gotRule := store.GetRule(tc.modelToGet) + if diff := cmp.Diff(tc.wantRule, gotRule); diff != "" { + t.Errorf("GetRule() mismatch (-want +got):\n%s", diff) + } + + if tc.wantGetAll != nil { + gotAll := store.GetAll() + if diff := cmp.Diff(tc.wantGetAll, gotAll, cmpopts.SortSlices(func(a, b *v1alpha2.InferenceModelRewrite) bool { + return getNN(a).String() < getNN(b).String() + })); diff != "" { + t.Errorf("GetAll() mismatch (-want +got):\n%s", diff) + } + } + }) + } +} diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index c4f4f1c1b..3d13c0f7e 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -50,6 +50,7 @@ type Datastore interface { PoolGet() (*datalayer.EndpointPool, error) ObjectiveGet(objectiveName string) *v1alpha2.InferenceObjective PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics + RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule } // Scheduler defines the interface required by the Director for scheduling. @@ -110,11 +111,16 @@ func (d *Director) getInferenceObjective(ctx context.Context, reqCtx *handlers.R return infObjective } -// resolveTargetModel is a helper to update reqCtx with target model based on request. -func (d *Director) resolveTargetModel(reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { +// HandleRequest orchestrates the request lifecycle. +// It always returns the requestContext even in the error case, as the request context is used in error handling. +func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { + logger := log.FromContext(ctx) + + // Parse Request, Resolve Target Models, and Determine Parameters requestBodyMap := reqCtx.Request.Body var ok bool reqCtx.IncomingModelName, ok = requestBodyMap["model"].(string) + if !ok { return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request body"} } @@ -122,22 +128,11 @@ func (d *Director) resolveTargetModel(reqCtx *handlers.RequestContext) (*handler // Default to incoming model name reqCtx.TargetModelName = reqCtx.IncomingModelName } - reqCtx.Request.Body["model"] = reqCtx.TargetModelName - return reqCtx, nil -} -// HandleRequest orchestrates the request lifecycle. -// It always returns the requestContext even in the error case, as the request context is used in error handling. -func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { - logger := log.FromContext(ctx) + d.applyWeightedModelRewrite(reqCtx) - // Resolve target model and update req context. - reqCtx, err := d.resolveTargetModel(reqCtx) - if err != nil { - return reqCtx, err - } + reqCtx.Request.Body["model"] = reqCtx.TargetModelName - // Parse request body. requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body) if err != nil { return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()} @@ -198,6 +193,42 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo return reqCtx, nil } +func (d *Director) applyWeightedModelRewrite(reqCtx *handlers.RequestContext) { + rewriteRule := d.datastore.RewriteGet(reqCtx.IncomingModelName) + if rewriteRule == nil { + return + } + reqCtx.TargetModelName = d.selectWeightedModel(rewriteRule.Targets) +} + +func (d *Director) selectWeightedModel(models []v1alpha2.TargetModel) string { + if len(models) == 0 { + return "" + } + + var totalWeight int32 + for _, model := range models { + totalWeight += model.Weight + } + + if totalWeight == 0 { + // If total weight is 0, distribute evenly + return models[rand.Intn(len(models))].ModelRewrite + } + + randomNum := rand.Intn(int(totalWeight)) + var currentWeight int32 + for _, model := range models { + currentWeight += model.Weight + if randomNum < int(currentWeight) { + return model.ModelRewrite + } + } + + // Should not happen + return models[len(models)-1].ModelRewrite +} + // getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore. // according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies // a subset of endpoints, only these endpoints will be considered as candidates for the scheduler. diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index f361303c8..9fde14bc8 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "maps" + "sort" "testing" "time" @@ -87,7 +88,8 @@ func (m *mockScheduler) Schedule(_ context.Context, _ *schedulingtypes.LLMReques } type mockDatastore struct { - pods []backendmetrics.PodMetrics + pods []backendmetrics.PodMetrics + rewrites []*v1alpha2.InferenceModelRewrite } func (ds *mockDatastore) PoolGet() (*datalayer.EndpointPool, error) { @@ -167,6 +169,34 @@ func (m mockProducedDataType) Clone() datalayer.Cloneable { return mockProducedDataType{value: m.value} } +func (ds *mockDatastore) RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule { + // This mock implementation simulates the precedence logic for simplicity. + // It finds the oldest rewrite that has a rule matching the modelName. + var matchingRewrites []*v1alpha2.InferenceModelRewrite + for _, r := range ds.rewrites { + for _, rule := range r.Spec.Rules { + for _, match := range rule.Matches { + if match.Model != nil && match.Model.Value == modelName { + matchingRewrites = append(matchingRewrites, r) + break // break inner loop + } + } + } + } + + if len(matchingRewrites) == 0 { + return nil + } + + // Sort by timestamp to find the oldest. + sort.Slice(matchingRewrites, func(i, j int) bool { + return matchingRewrites[i].CreationTimestamp.Before(&matchingRewrites[j].CreationTimestamp) + }) + + // Return the first rule from the oldest rewrite. + return &matchingRewrites[0].Spec.Rules[0] +} + func TestDirector_HandleRequest(t *testing.T) { ctx := logutil.NewTestLoggerIntoContext(context.Background()) @@ -174,6 +204,8 @@ func TestDirector_HandleRequest(t *testing.T) { model := "food-review" modelSheddable := "food-review-sheddable" modelWithResolvedTarget := "food-review-resolve" + modelToBeRewritten := "food-review-to-be-rewritten" + modelRewritten := "food-review-rewritten" objectiveName := "ioFoodReview" objectiveNameSheddable := "imFoodReviewSheddable" @@ -191,6 +223,33 @@ func TestDirector_HandleRequest(t *testing.T) { CreationTimestamp(metav1.Unix(1000, 0)). Priority(1). ObjRef() + + rewrite := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite-rule", + CreationTimestamp: metav1.Now(), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + Rules: []v1alpha2.InferenceModelRewriteRule{ + { + Matches: []v1alpha2.Match{ + { + Model: &v1alpha2.ModelMatch{ + Value: modelToBeRewritten, + }, + }, + }, + Targets: []v1alpha2.TargetModel{ + { + ModelRewrite: modelRewritten, + Weight: 100, + }, + }, + }, + }, + }, + } + pool := &v1.InferencePool{ ObjectMeta: metav1.ObjectMeta{Name: "test-pool", Namespace: "default"}, Spec: v1.InferencePoolSpec{ @@ -209,6 +268,7 @@ func TestDirector_HandleRequest(t *testing.T) { ds.ObjectiveSet(ioFoodReview) ds.ObjectiveSet(ioFoodReviewResolve) ds.ObjectiveSet(ioFoodReviewSheddable) + ds.RewriteSet(rewrite) scheme := runtime.NewScheme() _ = clientgoscheme.AddToScheme(scheme) @@ -284,6 +344,7 @@ func TestDirector_HandleRequest(t *testing.T) { mockAdmissionController *mockAdmissionController inferenceObjectiveName string schedulerMockSetup func(m *mockScheduler) + initialTargetModelName string // Initial target model in the reqCtx. wantErrCode string // Expected errutil code string wantReqCtx *handlers.RequestContext // Fields to check in the returned RequestContext wantMutatedBodyModel string // Expected model in reqCtx.Request.Body after PostDispatch @@ -301,6 +362,7 @@ func TestDirector_HandleRequest(t *testing.T) { schedulerMockSetup: func(m *mockScheduler) { m.scheduleResults = defaultSuccessfulScheduleResults }, + initialTargetModelName: model, wantReqCtx: &handlers.RequestContext{ ObjectiveKey: objectiveName, TargetModelName: model, @@ -314,9 +376,31 @@ func TestDirector_HandleRequest(t *testing.T) { }, wantMutatedBodyModel: model, inferenceObjectiveName: objectiveName, - targetModelName: model, - }, - { + }, { + name: "successful request with model rewrite", + reqBodyMap: map[string]any{ + "model": modelToBeRewritten, + "prompt": "some prompt", + }, + mockAdmissionController: &mockAdmissionController{admitErr: nil}, + schedulerMockSetup: func(m *mockScheduler) { + m.scheduleResults = defaultSuccessfulScheduleResults + }, + initialTargetModelName: model, + wantReqCtx: &handlers.RequestContext{ + ObjectiveKey: model, + TargetModelName: modelRewritten, + TargetPod: &backend.Pod{ + NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, + Address: "192.168.1.100", + Port: "8000", + MetricsHost: "192.168.1.100:8000", + }, + TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000", + }, + wantMutatedBodyModel: modelRewritten, + inferenceObjectiveName: model, + }, { name: "successful chat completions request", reqBodyMap: map[string]any{ "model": model, @@ -331,6 +415,7 @@ func TestDirector_HandleRequest(t *testing.T) { schedulerMockSetup: func(m *mockScheduler) { m.scheduleResults = defaultSuccessfulScheduleResults }, + initialTargetModelName: model, wantReqCtx: &handlers.RequestContext{ TargetModelName: model, TargetPod: &backend.Pod{ @@ -442,6 +527,7 @@ func TestDirector_HandleRequest(t *testing.T) { schedulerMockSetup: func(m *mockScheduler) { m.scheduleResults = defaultSuccessfulScheduleResults }, + initialTargetModelName: model, wantReqCtx: &handlers.RequestContext{ ObjectiveKey: objectiveName, TargetModelName: model, @@ -453,11 +539,8 @@ func TestDirector_HandleRequest(t *testing.T) { }, TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000", }, - wantMutatedBodyModel: model, inferenceObjectiveName: objectiveName, - targetModelName: model, - }, - { + }, { name: "successful request with target model resolution", reqBodyMap: map[string]any{ "model": modelWithResolvedTarget, @@ -467,6 +550,7 @@ func TestDirector_HandleRequest(t *testing.T) { schedulerMockSetup: func(m *mockScheduler) { m.scheduleResults = defaultSuccessfulScheduleResults }, + initialTargetModelName: "resolved-target-model-A", wantReqCtx: &handlers.RequestContext{ ObjectiveKey: objectiveNameResolve, TargetModelName: "resolved-target-model-A", @@ -480,13 +564,13 @@ func TestDirector_HandleRequest(t *testing.T) { }, wantMutatedBodyModel: "resolved-target-model-A", inferenceObjectiveName: objectiveNameResolve, - targetModelName: "resolved-target-model-A", }, { name: "nonexistent target defined, use default inference model", schedulerMockSetup: func(m *mockScheduler) { m.scheduleResults = defaultSuccessfulScheduleResults }, + initialTargetModelName: "food-review-1", wantReqCtx: &handlers.RequestContext{ ObjectiveKey: "food-review-1", TargetModelName: "food-review-1", @@ -505,10 +589,8 @@ func TestDirector_HandleRequest(t *testing.T) { }, mockAdmissionController: &mockAdmissionController{admitErr: nil}, inferenceObjectiveName: "food-review-1", - targetModelName: "food-review-1", }, { - name: "request rejected by admission controller", reqBodyMap: map[string]any{ "model": modelSheddable, @@ -578,6 +660,9 @@ func TestDirector_HandleRequest(t *testing.T) { } config = config.WithAdmissionPlugins(newMockAdmissionPlugin("test-admit-plugin", test.admitRequestDenialError)) director := NewDirectorWithConfig(ds, mockSched, test.mockAdmissionController, config) + if test.name == "successful request with model rewrite" { + director.datastore = &mockDatastore{pods: ds.PodList(backendmetrics.AllPodsPredicate), rewrites: []*v1alpha2.InferenceModelRewrite{rewrite}} + } reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ @@ -588,7 +673,7 @@ func TestDirector_HandleRequest(t *testing.T) { }, }, ObjectiveKey: test.inferenceObjectiveName, - TargetModelName: test.targetModelName, + TargetModelName: test.initialTargetModelName, } // Deep copy the body map. maps.Copy(reqCtx.Request.Body, test.reqBodyMap) @@ -777,6 +862,266 @@ func TestGetRandomPod(t *testing.T) { } } +func TestDirector_ApplyWeightedModelRewrite(t *testing.T) { + _ = logutil.NewTestLoggerIntoContext(context.Background()) + + // Mock InferenceModelRewrite objects + rewriteOld := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite-old", + CreationTimestamp: metav1.Unix(1000, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + Rules: []v1alpha2.InferenceModelRewriteRule{ + { + Matches: []v1alpha2.Match{ + { + Model: &v1alpha2.ModelMatch{ + Value: "model-a", + }, + }, + }, + Targets: []v1alpha2.TargetModel{ + { + ModelRewrite: "model-a-old-tuned", + Weight: 100, + }, + }, + }, + }, + }, + } + + rewriteNew := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite-new", + CreationTimestamp: metav1.Unix(2000, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + Rules: []v1alpha2.InferenceModelRewriteRule{ + { + Matches: []v1alpha2.Match{ + { + Model: &v1alpha2.ModelMatch{ + Value: "model-a", + }, + }, + }, + Targets: []v1alpha2.TargetModel{ + { + ModelRewrite: "model-a-new-tuned", + Weight: 100, + }, + }, + }, + }, + }, + } + + rewriteB := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite-b", + CreationTimestamp: metav1.Unix(1500, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + Rules: []v1alpha2.InferenceModelRewriteRule{ + { + Matches: []v1alpha2.Match{ + { + Model: &v1alpha2.ModelMatch{ + Value: "model-b", + }, + }, + }, + Targets: []v1alpha2.TargetModel{ + { + ModelRewrite: "model-b-tuned", + Weight: 100, + }, + }, + }, + }, + }, + } + + rewriteWeighted := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite-weighted", + CreationTimestamp: metav1.Unix(1200, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + Rules: []v1alpha2.InferenceModelRewriteRule{ + { + Matches: []v1alpha2.Match{ + { + Model: &v1alpha2.ModelMatch{ + Value: "model-c", + }, + }, + }, + Targets: []v1alpha2.TargetModel{ + { + ModelRewrite: "model-c-v1", + Weight: 70, + }, + { + ModelRewrite: "model-c-v2", + Weight: 30, + }, + }, + }, + }, + }, + } + + tests := []struct { + name string + rewrites []*v1alpha2.InferenceModelRewrite + incomingModel string + expectedTarget []string + initialTarget string // Initial value of reqCtx.TargetModelName + }{ + { + name: "no rewrites", + rewrites: []*v1alpha2.InferenceModelRewrite{}, + incomingModel: "model-x", + expectedTarget: []string{"model-x"}, + initialTarget: "model-x", + }, + { + name: "single matching rewrite", + rewrites: []*v1alpha2.InferenceModelRewrite{rewriteB}, + incomingModel: "model-b", + expectedTarget: []string{"model-b-tuned"}, + initialTarget: "model-b", + }, + { + name: "no matching rewrite", + rewrites: []*v1alpha2.InferenceModelRewrite{rewriteB}, + incomingModel: "model-x", + expectedTarget: []string{"model-x"}, + initialTarget: "model-x", + }, + { + name: "oldest rewrite wins for duplicate model", + rewrites: []*v1alpha2.InferenceModelRewrite{rewriteNew, rewriteOld}, // New is first, but Old has older timestamp + incomingModel: "model-a", + expectedTarget: []string{"model-a-old-tuned"}, + initialTarget: "model-a", + }, + { + name: "weighted rewrite applied (probabilistic check)", + rewrites: []*v1alpha2.InferenceModelRewrite{rewriteWeighted}, + incomingModel: "model-c", + initialTarget: "model-c", + expectedTarget: []string{"model-c-v1", "model-c-v2"}, + }, + { + name: "initial TargetModelName is respected if no rewrite matches", + rewrites: []*v1alpha2.InferenceModelRewrite{rewriteB}, + incomingModel: "model-x", + initialTarget: "pre-existing-target", + expectedTarget: []string{"pre-existing-target"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mockDs := &mockDatastore{rewrites: test.rewrites} + director := NewDirectorWithConfig(mockDs, &mockScheduler{}, &mockAdmissionController{}, NewConfig()) + + reqCtx := &handlers.RequestContext{ + IncomingModelName: test.incomingModel, + TargetModelName: test.initialTarget, + } + + director.applyWeightedModelRewrite(reqCtx) + assert.Contains(t, test.expectedTarget, reqCtx.TargetModelName, "TargetModelName mismatch") + }) + } +} + +func TestDirector_SelectWeightedModel(t *testing.T) { + tests := []struct { + name string + targets []v1alpha2.TargetModel + possibleModels map[string]bool // For probabilistic cases + }{ + { + name: "single target", + targets: []v1alpha2.TargetModel{ + {ModelRewrite: "model-a", Weight: 100}, + }, + possibleModels: map[string]bool{"model-a": true}, + }, + { + name: "multiple targets, equal weight", + targets: []v1alpha2.TargetModel{ + {ModelRewrite: "model-a", Weight: 50}, + {ModelRewrite: "model-b", Weight: 50}, + }, + possibleModels: map[string]bool{"model-a": true, "model-b": true}, + }, + { + name: "multiple targets, different weights", + targets: []v1alpha2.TargetModel{ + {ModelRewrite: "model-x", Weight: 70}, + {ModelRewrite: "model-y", Weight: 30}, + }, + possibleModels: map[string]bool{"model-x": true, "model-y": true}, + }, + { + name: "zero total weight, distribute evenly", + targets: []v1alpha2.TargetModel{ + {ModelRewrite: "model-z1", Weight: 0}, + {ModelRewrite: "model-z2", Weight: 0}, + }, + possibleModels: map[string]bool{"model-z1": true, "model-z2": true}, + }, + } + + director := &Director{} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Run multiple times to check distribution + counter := make(map[string]int) + numRuns := 1000 + for i := 0; i < numRuns; i++ { + selected := director.selectWeightedModel(test.targets) + counter[selected]++ + } + + // Assert that all selected models are within the possible models + for model := range counter { + if _, ok := test.possibleModels[model]; !ok { + t.Errorf("Selected model %s is not in possible models %v", model, test.possibleModels) + } + } + + // Basic check for distribution (e.g., if 70/30, expect roughly 700/300) + if len(test.targets) > 1 { + totalWeight := int32(0) + for _, target := range test.targets { + totalWeight += target.Weight + } + + if totalWeight == 0 { // Special case for zero total weight + for _, target := range test.targets { + expectedCount := numRuns / len(test.targets) + assert.InDelta(t, expectedCount, counter[target.ModelRewrite], float64(numRuns)/float64(len(test.targets))*0.2, "Distribution for %s is off", target.ModelRewrite) + } + } else { + for _, target := range test.targets { + expectedCount := float64(numRuns) * (float64(target.Weight) / float64(totalWeight)) + assert.InDelta(t, expectedCount, float64(counter[target.ModelRewrite]), expectedCount*0.2, "Distribution for %s is off", target.ModelRewrite) + } + } + } + }) + } +} + func TestDirector_HandleResponseReceived(t *testing.T) { pr1 := newTestResponseReceived("pr1") diff --git a/pkg/epp/server/controller_manager.go b/pkg/epp/server/controller_manager.go index c82b0bcb9..acc0bd51b 100644 --- a/pkg/epp/server/controller_manager.go +++ b/pkg/epp/server/controller_manager.go @@ -55,6 +55,16 @@ func defaultManagerOptions(disableK8sCrdReconcile bool, gknn common.GKNN, metric gknn.Namespace: {}, }, }, + &v1alpha2.InferenceObjective{}: { + Namespaces: map[string]cache.Config{ + gknn.Namespace: {}, + }, + }, + &v1alpha2.InferenceModelRewrite{}: { + Namespaces: map[string]cache.Config{ + gknn.Namespace: {}, + }, + }, }, }, Metrics: metricsServerOptions, diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index e43d84923..ae506a78d 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -131,6 +131,14 @@ func (r *ExtProcServerRunner) SetupWithManager(ctx context.Context, mgr ctrl.Man } } + if err := (&controller.InferenceModelRewriteReconciler{ + Datastore: r.Datastore, + Reader: mgr.GetClient(), + PoolGKNN: r.PoolGKNN, + }).SetupWithManager(ctx, mgr); err != nil { + return fmt.Errorf("failed setting up InferenceModelRewriteReconciler: %w", err) + } + if err := (&controller.PodReconciler{ Datastore: r.Datastore, Reader: mgr.GetClient(),