diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index 2f5c1e388..24a59ebf7 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -25,7 +25,10 @@ import ( "net/http" "net/http/pprof" "os" + "regexp" "runtime" + "strconv" + "strings" "sync/atomic" "github.com/go-logr/logr" @@ -34,16 +37,18 @@ import ( "go.uber.org/zap/zapcore" "google.golang.org/grpc" healthPb "google.golang.org/grpc/health/grpc_health_v1" + "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/rest" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/metrics/filters" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" - configapi "sigs.k8s.io/gateway-api-inference-extension/apix/config/v1alpha1" "sigs.k8s.io/gateway-api-inference-extension/internal/runnable" "sigs.k8s.io/gateway-api-inference-extension/pkg/common" @@ -111,6 +116,8 @@ var ( poolName = flag.String("pool-name", runserver.DefaultPoolName, "Name of the InferencePool this Endpoint Picker is associated with.") poolGroup = flag.String("pool-group", runserver.DefaultPoolGroup, "group of the InferencePool this Endpoint Picker is associated with.") poolNamespace = flag.String("pool-namespace", "", "Namespace of the InferencePool this Endpoint Picker is associated with.") + endpointSelector = flag.String("endpoint-selector", "", "Label selector used to filter model server pods; only key=value pairs are supported. Format: a comma-separated list of key=value pairs, e.g., 'app=vllm-llama3-8b-instruct,env=prod'.") + endpointTargetPorts = flag.String("endpoint-target-ports", "", "Target ports of model server pods. Format: a comma-separated list of numbers, e.g., '3000,3001,3002'.") logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity") secureServing = flag.Bool("secure-serving", runserver.DefaultSecureServing, "Enables secure serving. Defaults to true.") healthChecking = flag.Bool("health-checking", runserver.DefaultHealthChecking, "Enables health checking") @@ -231,16 +238,26 @@ func (r *Runner) Run(ctx context.Context) error { if err != nil { return err } - datastore := datastore.NewDatastore(ctx, epf, int32(*modelServerMetricsPort)) - eppConfig, err := r.parseConfigurationPhaseTwo(ctx, rawConfig, datastore) + gknn, err := extractGKNN(*poolName, *poolGroup, *poolNamespace, *endpointSelector) + if err != nil { + setupLog.Error(err, "Failed to extract GKNN") + return err + } + disableK8sCrdReconcile := *endpointSelector != "" + ds, err := setupDatastore(setupLog, ctx, epf, int32(*modelServerMetricsPort), disableK8sCrdReconcile, *poolNamespace, *poolName, *endpointSelector, *endpointTargetPorts) + if err != nil { + setupLog.Error(err, "Failed to setup datastore") + return err + } + eppConfig, err := r.parseConfigurationPhaseTwo(ctx, rawConfig, ds) if err != nil { setupLog.Error(err, "Failed to parse configuration") return err } // --- Setup Metrics Server --- - r.customCollectors = append(r.customCollectors, collectors.NewInferencePoolMetricsCollector(datastore)) + r.customCollectors = append(r.customCollectors, collectors.NewInferencePoolMetricsCollector(ds)) metrics.Register(r.customCollectors...) metrics.RecordInferenceExtensionInfo(version.CommitSHA, version.BuildRef) // Register metrics handler.
@@ -259,34 +276,10 @@ func (r *Runner) Run(ctx context.Context) error { }(), } - // Determine pool namespace: if --pool-namespace is non-empty, use it; else NAMESPACE env var; else default - resolvePoolNamespace := func() string { - if *poolNamespace != "" { - return *poolNamespace - } - if nsEnv := os.Getenv("NAMESPACE"); nsEnv != "" { - return nsEnv - } - return runserver.DefaultPoolNamespace - } - resolvedPoolNamespace := resolvePoolNamespace() - poolNamespacedName := types.NamespacedName{ - Name: *poolName, - Namespace: resolvedPoolNamespace, - } - poolGroupKind := schema.GroupKind{ - Group: *poolGroup, - Kind: "InferencePool", - } - poolGKNN := common.GKNN{ - NamespacedName: poolNamespacedName, - GroupKind: poolGroupKind, - } - isLeader := &atomic.Bool{} isLeader.Store(false) - mgr, err := runserver.NewDefaultManager(poolGKNN, cfg, metricsServerOptions, *haEnableLeaderElection) + mgr, err := runserver.NewDefaultManager(disableK8sCrdReconcile, *gknn, cfg, metricsServerOptions, *haEnableLeaderElection) if err != nil { setupLog.Error(err, "Failed to create controller manager") return err @@ -353,14 +346,18 @@ func (r *Runner) Run(ctx context.Context) error { admissionController = requestcontrol.NewLegacyAdmissionController(saturationDetector) } - director := requestcontrol.NewDirectorWithConfig(datastore, scheduler, admissionController, r.requestControlConfig) + director := requestcontrol.NewDirectorWithConfig( + ds, + scheduler, + admissionController, + r.requestControlConfig) // --- Setup ExtProc Server Runner --- serverRunner := &runserver.ExtProcServerRunner{ GrpcPort: *grpcPort, - PoolNamespacedName: poolNamespacedName, - PoolGKNN: poolGKNN, - Datastore: datastore, + GKNN: *gknn, + Datastore: ds, + DisableK8sCrdReconcile: disableK8sCrdReconcile, SecureServing: *secureServing, HealthChecking: *healthChecking, CertPath: *certPath, @@ -377,7 +374,7 @@ func (r *Runner) Run(ctx context.Context) error { // --- Add Runnables to Manager --- // Register health server. 
- if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), datastore, *grpcHealthPort, isLeader, *haEnableLeaderElection); err != nil { + if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), ds, *grpcHealthPort, isLeader, *haEnableLeaderElection); err != nil { return err } @@ -397,6 +394,28 @@ func (r *Runner) Run(ctx context.Context) error { return nil } +func setupDatastore(setupLog logr.Logger, ctx context.Context, epFactory datalayer.EndpointFactory, modelServerMetricsPort int32, disableK8sCrdReconcile bool, namespace, name, endpointSelector, endpointTargetPorts string) (datastore.Datastore, error) { + if !disableK8sCrdReconcile { + return datastore.NewDatastore(ctx, epFactory, modelServerMetricsPort), nil + } else { + endpointPool := datalayer.NewEndpointPool(namespace, name) + labelsMap, err := labels.ConvertSelectorToLabelsMap(endpointSelector) + if err != nil { + setupLog.Error(err, "Failed to parse flag", "flag", "endpoint-selector") + return nil, err + } + endpointPool.Selector = labelsMap + endpointPool.TargetPorts, err = strToUniqueIntSlice(endpointTargetPorts) + if err != nil { + setupLog.Error(err, "Failed to parse flag", "flag", "endpoint-target-ports") + return nil, err + } + + endpointPoolOption := datastore.WithEndpointPool(endpointPool) + return datastore.NewDatastore(ctx, epFactory, modelServerMetricsPort, endpointPoolOption), nil + } +} + // registerInTreePlugins registers the factory functions of all known plugins func (r *Runner) registerInTreePlugins() { plugins.Register(prefix.PrefixCachePluginType, prefix.PrefixCachePluginFactory) @@ -635,9 +654,19 @@ func registerHealthServer(mgr manager.Manager, logger logr.Logger, ds datastore. } func validateFlags() error { - if *poolName == "" { - return fmt.Errorf("required %q flag not set", "poolName") + if (*poolName != "" && *endpointSelector != "") || (*poolName == "" && *endpointSelector == "") { + return errors.New("exactly one of the pool-name and endpoint-selector flags must be set") } + if *endpointSelector != "" { + targetPortsList, err := strToUniqueIntSlice(*endpointTargetPorts) + if err != nil { + return fmt.Errorf("invalid value for %q flag: %w", "endpoint-target-ports", err) + } + if len(targetPortsList) == 0 || len(targetPortsList) > 8 { + return fmt.Errorf("flag %q must contain between 1 and 8 ports", "endpoint-target-ports") + } + } + if *configText != "" && *configFile != "" { return fmt.Errorf("both the %q and %q flags can not be set at the same time", "configText", "configFile") } @@ -648,6 +677,34 @@ func validateFlags() error { return nil } +func strToUniqueIntSlice(s string) ([]int, error) { + seen := sets.NewInt() + var intList []int + + if s == "" { + return intList, nil + } + + strList := strings.Split(s, ",") + + for _, str := range strList { + trimmedStr := strings.TrimSpace(str) + if trimmedStr == "" { + continue + } + portInt, err := strconv.Atoi(trimmedStr) + if err != nil { + return nil, fmt.Errorf("invalid number: '%s' is not an integer", trimmedStr) + } + + if _, ok := seen[portInt]; !ok { + seen[portInt] = struct{}{} + intList = append(intList, portInt) + } + } + return intList, nil +} + func verifyMetricMapping(mapping backendmetrics.MetricMapping, logger logr.Logger) { if mapping.TotalQueuedRequests == nil { logger.Info("Not scraping metric: TotalQueuedRequests") @@ -683,3 +740,62 @@ func setupPprofHandlers(mgr ctrl.Manager) error { } return nil } + +func extractDeploymentName(podName string) (string, error) { + regex =
regexp.MustCompile(`^(.+)-[a-z0-9]+-[a-z0-9]+$`) + + matches := regex.FindStringSubmatch(podName) + if len(matches) == 2 { + return matches[1], nil + } + return "", fmt.Errorf("failed to parse deployment name from pod name %s", podName) +} + +func extractGKNN(poolName, poolGroup, poolNamespace, endpointSelector string) (*common.GKNN, error) { + if poolName != "" { + // Determine pool namespace: if --pool-namespace is non-empty, use it; else NAMESPACE env var; else default + resolvedPoolNamespace := resolvePoolNamespace(poolNamespace) + poolNamespacedName := types.NamespacedName{ + Name: poolName, + Namespace: resolvedPoolNamespace, + } + poolGroupKind := schema.GroupKind{ + Group: poolGroup, + Kind: "InferencePool", + } + return &common.GKNN{ + NamespacedName: poolNamespacedName, + GroupKind: poolGroupKind, + }, nil + } + + if endpointSelector != "" { + // Determine EPP namespace: NAMESPACE env var; else default + resolvedPoolNamespace := resolvePoolNamespace(poolNamespace) + // Determine EPP name: POD_NAME env var + eppPodNameEnv := os.Getenv("POD_NAME") + if eppPodNameEnv == "" { + return nil, errors.New("failed to get environment variable POD_NAME") + + } + eppName, err := extractDeploymentName(eppPodNameEnv) + if err != nil { + return nil, err + } + return &common.GKNN{ + NamespacedName: types.NamespacedName{Namespace: resolvedPoolNamespace, Name: eppName}, + GroupKind: schema.GroupKind{Kind: "Deployment", Group: "apps"}, + }, nil + } + return nil, errors.New("can't construct gknn as both pool-name and endpoint-selector are missing") +} + +func resolvePoolNamespace(poolNamespace string) string { + if poolNamespace != "" { + return poolNamespace + } + if nsEnv := os.Getenv("NAMESPACE"); nsEnv != "" { + return nsEnv + } + return runserver.DefaultPoolNamespace +} diff --git a/config/charts/inferencepool/templates/epp-deployment.yaml b/config/charts/inferencepool/templates/epp-deployment.yaml index 5b3634c2a..be6a39ead 100644 --- a/config/charts/inferencepool/templates/epp-deployment.yaml +++ b/config/charts/inferencepool/templates/epp-deployment.yaml @@ -147,6 +147,10 @@ spec: valueFrom: fieldRef: fieldPath: metadata.namespace + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name {{- if .Values.inferenceExtension.tracing.enabled }} - name: OTEL_SERVICE_NAME value: "gateway-api-inference-extension" diff --git a/pkg/epp/backend/metrics/pod_metrics_test.go b/pkg/epp/backend/metrics/pod_metrics_test.go index b0297cd1e..8a5561c0e 100644 --- a/pkg/epp/backend/metrics/pod_metrics_test.go +++ b/pkg/epp/backend/metrics/pod_metrics_test.go @@ -25,7 +25,6 @@ import ( "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/types" - v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" ) @@ -86,8 +85,8 @@ func TestMetricsRefresh(t *testing.T) { type fakeDataStore struct{} -func (f *fakeDataStore) PoolGet() (*v1.InferencePool, error) { - return &v1.InferencePool{Spec: v1.InferencePoolSpec{TargetPorts: []v1.Port{{Number: 8000}}}}, nil +func (f *fakeDataStore) PoolGet() (*datalayer.EndpointPool, error) { + return &datalayer.EndpointPool{}, nil } func (f *fakeDataStore) PodList(func(PodMetrics) bool) []PodMetrics { diff --git a/pkg/epp/controller/inferenceobjective_reconciler.go b/pkg/epp/controller/inferenceobjective_reconciler.go index c8ac5a6c3..49ddf70b0 100644 --- a/pkg/epp/controller/inferenceobjective_reconciler.go +++ b/pkg/epp/controller/inferenceobjective_reconciler.go @@ -21,12 +21,12 @@ import ( "fmt" 
"k8s.io/apimachinery/pkg/api/errors" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/predicate" - "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/common" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" @@ -55,7 +55,7 @@ func (c *InferenceObjectiveReconciler) Reconcile(ctx context.Context, req ctrl.R } if notFound || !infObjective.DeletionTimestamp.IsZero() || infObjective.Spec.PoolRef.Name != v1alpha2.ObjectName(c.PoolGKNN.Name) || infObjective.Spec.PoolRef.Group != v1alpha2.Group(c.PoolGKNN.Group) { - // InferenceObjective object got deleted or changed the referenced pool. + // InferenceObjective object got deleted or changed the referenced inferencePool. c.Datastore.ObjectiveDelete(req.NamespacedName) return ctrl.Result{}, nil } diff --git a/pkg/epp/controller/inferenceobjective_reconciler_test.go b/pkg/epp/controller/inferenceobjective_reconciler_test.go index 4ceff5d07..f8d48eca9 100644 --- a/pkg/epp/controller/inferenceobjective_reconciler_test.go +++ b/pkg/epp/controller/inferenceobjective_reconciler_test.go @@ -27,25 +27,26 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" - v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/common" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + poolutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool" utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) var ( - pool = utiltest.MakeInferencePool("test-pool1").Namespace("ns1").ObjRef() + inferencePool = utiltest.MakeInferencePool("test-pool1").Namespace("ns1").ObjRef() infObjective1 = utiltest.MakeInferenceObjective("model1"). - Namespace(pool.Namespace). + Namespace(inferencePool.Namespace). Priority(1). CreationTimestamp(metav1.Unix(1000, 0)). - PoolName(pool.Name). + PoolName(inferencePool.Name). PoolGroup("inference.networking.k8s.io").ObjRef() infObjective1Pool2 = utiltest.MakeInferenceObjective(infObjective1.Name). Namespace(infObjective1.Namespace). @@ -57,24 +58,24 @@ var ( Namespace(infObjective1.Namespace). Priority(2). CreationTimestamp(metav1.Unix(1003, 0)). - PoolName(pool.Name). + PoolName(inferencePool.Name). PoolGroup("inference.networking.k8s.io").ObjRef() infObjective1Deleted = utiltest.MakeInferenceObjective(infObjective1.Name). Namespace(infObjective1.Namespace). CreationTimestamp(metav1.Unix(1004, 0)). DeletionTimestamp(). - PoolName(pool.Name). + PoolName(inferencePool.Name). PoolGroup("inference.networking.k8s.io").ObjRef() infObjective1DiffGroup = utiltest.MakeInferenceObjective(infObjective1.Name). - Namespace(pool.Namespace). + Namespace(inferencePool.Namespace). Priority(1). CreationTimestamp(metav1.Unix(1005, 0)). - PoolName(pool.Name). + PoolName(inferencePool.Name). PoolGroup("inference.networking.x-k8s.io").ObjRef() infObjective2 = utiltest.MakeInferenceObjective("model2"). - Namespace(pool.Namespace). + Namespace(inferencePool.Namespace). 
CreationTimestamp(metav1.Unix(1000, 0)). - PoolName(pool.Name). + PoolName(inferencePool.Name). PoolGroup("inference.networking.k8s.io").ObjRef() ) @@ -120,7 +121,7 @@ func TestInferenceObjectiveReconciler(t *testing.T) { { name: "Objective not found, no matching existing objective to delete", objectivessInStore: []*v1alpha2.InferenceObjective{infObjective1}, - incomingReq: &types.NamespacedName{Name: "non-existent-objective", Namespace: pool.Namespace}, + incomingReq: &types.NamespacedName{Name: "non-existent-objective", Namespace: inferencePool.Namespace}, wantObjectives: []*v1alpha2.InferenceObjective{infObjective1}, }, { @@ -130,13 +131,13 @@ func TestInferenceObjectiveReconciler(t *testing.T) { wantObjectives: []*v1alpha2.InferenceObjective{infObjective1, infObjective2}, }, { - name: "Objective deleted due to group mismatch for the inference pool", + name: "Objective deleted due to group mismatch for the InferencePool", objectivessInStore: []*v1alpha2.InferenceObjective{infObjective1}, objective: infObjective1DiffGroup, wantObjectives: []*v1alpha2.InferenceObjective{}, }, { - name: "Objective ignored due to group mismatch for the inference pool", + name: "Objective ignored due to group mismatch for the InferencePool", objective: infObjective1DiffGroup, wantObjectives: []*v1alpha2.InferenceObjective{}, }, @@ -164,13 +165,14 @@ func TestInferenceObjectiveReconciler(t *testing.T) { for _, m := range test.objectivessInStore { ds.ObjectiveSet(m) } - _ = ds.PoolSet(context.Background(), fakeClient, pool) + endpointPool := poolutil.InferencePoolToEndpointPool(inferencePool) + _ = ds.PoolSet(context.Background(), fakeClient, endpointPool) reconciler := &InferenceObjectiveReconciler{ Reader: fakeClient, Datastore: ds, PoolGKNN: common.GKNN{ - NamespacedName: types.NamespacedName{Name: pool.Name, Namespace: pool.Namespace}, - GroupKind: schema.GroupKind{Group: pool.GroupVersionKind().Group, Kind: pool.GroupVersionKind().Kind}, + NamespacedName: types.NamespacedName{Name: inferencePool.Name, Namespace: inferencePool.Namespace}, + GroupKind: schema.GroupKind{Group: inferencePool.GroupVersionKind().Group, Kind: inferencePool.GroupVersionKind().Kind}, }, } if test.incomingReq == nil { @@ -190,8 +192,7 @@ func TestInferenceObjectiveReconciler(t *testing.T) { if len(test.wantObjectives) != len(ds.ObjectiveGetAll()) { t.Errorf("Unexpected; want: %d, got:%d", len(test.wantObjectives), len(ds.ObjectiveGetAll())) } - - if diff := diffStore(ds, diffStoreParams{wantPool: endpointPool, wantObjectives: test.wantObjectives}); diff != "" { + if diff := diffStore(ds, diffStoreParams{wantPool: endpointPool, wantObjectives: test.wantObjectives}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } diff --git a/pkg/epp/controller/inferencepool_reconciler.go b/pkg/epp/controller/inferencepool_reconciler.go index 3b52de0ae..400ce4392 100644 --- a/pkg/epp/controller/inferencepool_reconciler.go +++ b/pkg/epp/controller/inferencepool_reconciler.go @@ -21,20 +21,22 @@ import ( "fmt" "k8s.io/apimachinery/pkg/api/errors" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" - v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/common" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" logutil
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + pooltuil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool" ) // InferencePoolReconciler utilizes the controller runtime to reconcile Instance Gateway resources // This implementation is just used for reading & maintaining data sync. The Gateway implementation -// will have the proper controller that will create/manage objects on behalf of the server pool. +// will have the proper controller that will create/manage objects on behalf of the server inferencePool. type InferencePoolReconciler struct { client.Reader Datastore datastore.Datastore @@ -75,25 +77,17 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques c.Datastore.Clear() return ctrl.Result{}, nil } - - // 4. Convert the fetched object to the canonical v1.InferencePool. - v1infPool := &v1.InferencePool{} - + var endpointPool *datalayer.EndpointPool switch pool := obj.(type) { case *v1.InferencePool: - // If it's already a v1 object, just use it. - v1infPool = pool + endpointPool = pooltuil.InferencePoolToEndpointPool(pool) case *v1alpha2.InferencePool: - var err error - err = pool.ConvertTo(v1infPool) - if err != nil { - return ctrl.Result{}, fmt.Errorf("failed to convert XInferencePool to InferencePool - %w", err) - } + endpointPool = pooltuil.AlphaInferencePoolToEndpointPool(pool) default: return ctrl.Result{}, fmt.Errorf("unsupported API group: %s", c.PoolGKNN.Group) } - if err := c.Datastore.PoolSet(ctx, c.Reader, v1infPool); err != nil { + if err := c.Datastore.PoolSet(ctx, c.Reader, endpointPool); err != nil { return ctrl.Result{}, fmt.Errorf("failed to update datastore - %w", err) } diff --git a/pkg/epp/controller/inferencepool_reconciler_test.go b/pkg/epp/controller/inferencepool_reconciler_test.go index a2bce1256..0fdecb674 100644 --- a/pkg/epp/controller/inferencepool_reconciler_test.go +++ b/pkg/epp/controller/inferencepool_reconciler_test.go @@ -24,20 +24,21 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" - v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/common" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool" utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) @@ -121,49 +122,51 @@ func TestInferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := diffStore(ds, diffStoreParams{wantPool: pool1, wantPods: []string{"pod1-rank-0", "pod2-rank-0"}}); diff != "" { + endpointPool1 := pool.InferencePoolToEndpointPool(pool1) + if diff := diffStore(ds, diffStoreParams{wantPool: endpointPool1, wantPods: []string{"pod1-rank-0", "pod2-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } newPool1 := 
&v1.InferencePool{} if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil { - t.Errorf("Unexpected pool get error: %v", err) + t.Errorf("Unexpected inferencePool get error: %v", err) } newPool1.Spec.Selector = v1.LabelSelector{ MatchLabels: map[v1.LabelKey]v1.LabelValue{"app": "vllm_v2"}, } if err := fakeClient.Update(ctx, newPool1, &client.UpdateOptions{}); err != nil { - t.Errorf("Unexpected pool update error: %v", err) + t.Errorf("Unexpected inferencePool update error: %v", err) } - if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := diffStore(ds, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { + newEndpointPool1 := pool.InferencePoolToEndpointPool(newPool1) + if diff := diffStore(ds, diffStoreParams{wantPool: newEndpointPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } - // Step 3: update the pool port + // Step 3: update the inferencePool port if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil { - t.Errorf("Unexpected pool get error: %v", err) + t.Errorf("Unexpected inferencePool get error: %v", err) } newPool1.Spec.TargetPorts = []v1.Port{{Number: 9090}} if err := fakeClient.Update(ctx, newPool1, &client.UpdateOptions{}); err != nil { - t.Errorf("Unexpected pool update error: %v", err) + t.Errorf("Unexpected inferencePool update error: %v", err) } if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := diffStore(ds, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { + newEndpointPool1 = pool.InferencePoolToEndpointPool(newPool1) + if diff := diffStore(ds, diffStoreParams{wantPool: newEndpointPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } - // Step 4: delete the pool to trigger a datastore clear + // Step 4: delete the inferencePool to trigger a datastore clear if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil { - t.Errorf("Unexpected pool get error: %v", err) + t.Errorf("Unexpected inferencePool get error: %v", err) } if err := fakeClient.Delete(ctx, newPool1, &client.DeleteOptions{}); err != nil { - t.Errorf("Unexpected pool delete error: %v", err) + t.Errorf("Unexpected inferencePool delete error: %v", err) } if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) @@ -174,17 +177,15 @@ func TestInferencePoolReconciler(t *testing.T) { } type diffStoreParams struct { - wantPool *v1.InferencePool + wantPool *datalayer.EndpointPool wantPods []string wantObjectives []*v1alpha2.InferenceObjective } func diffStore(datastore datastore.Datastore, params diffStoreParams) string { gotPool, _ := datastore.PoolGet() - // controller-runtime fake client may not populate TypeMeta (APIVersion/Kind). - // Ignore it when comparing pools. - if diff := cmp.Diff(params.wantPool, gotPool, cmpopts.IgnoreTypes(metav1.TypeMeta{})); diff != "" { - return "pool:" + diff + if diff := cmp.Diff(params.wantPool, gotPool); diff != "" { + return "inferencePool:" + diff } // Default wantPods if not set because PodGetAll returns an empty slice when empty. 
@@ -268,79 +269,73 @@ func TestXInferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: pool1, wantPods: []string{"pod1-rank-0", "pod2-rank-0"}}); diff != "" { + endpointPool1 := pool.AlphaInferencePoolToEndpointPool(pool1) + if diff := xDiffStore(ds, xDiffStoreParams{wantPool: endpointPool1, wantPods: []string{"pod1-rank-0", "pod2-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } newPool1 := &v1alpha2.InferencePool{} if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil { - t.Errorf("Unexpected pool get error: %v", err) + t.Errorf("Unexpected inferencePool get error: %v", err) } newPool1.Spec.Selector = map[v1alpha2.LabelKey]v1alpha2.LabelValue{"app": "vllm_v2"} if err := fakeClient.Update(ctx, newPool1, &client.UpdateOptions{}); err != nil { - t.Errorf("Unexpected pool update error: %v", err) + t.Errorf("Unexpected inferencePool update error: %v", err) } if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { + newEndpointPool1 := pool.AlphaInferencePoolToEndpointPool(newPool1) + if diff := xDiffStore(ds, xDiffStoreParams{wantPool: newEndpointPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } - // Step 3: update the pool port + // Step 3: update the inferencePool port if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil { - t.Errorf("Unexpected pool get error: %v", err) + t.Errorf("Unexpected inferencePool get error: %v", err) } newPool1.Spec.TargetPortNumber = 9090 if err := fakeClient.Update(ctx, newPool1, &client.UpdateOptions{}); err != nil { - t.Errorf("Unexpected pool update error: %v", err) + t.Errorf("Unexpected inferencePool update error: %v", err) } if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { + newEndpointPool1 = pool.AlphaInferencePoolToEndpointPool(newPool1) + if diff := xDiffStore(ds, xDiffStoreParams{wantPool: newEndpointPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } - // Step 4: delete the pool to trigger a datastore clear + // Step 4: delete the inferencePool to trigger a datastore clear if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil { - t.Errorf("Unexpected pool get error: %v", err) + t.Errorf("Unexpected inferencePool get error: %v", err) } if err := fakeClient.Delete(ctx, newPool1, &client.DeleteOptions{}); err != nil { - t.Errorf("Unexpected pool delete error: %v", err) + t.Errorf("Unexpected inferencePool delete error: %v", err) } if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := xDiffStore(t, ds, xDiffStoreParams{wantPods: []string{}}); diff != "" { + if diff := xDiffStore(ds, xDiffStoreParams{wantPods: []string{}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } } type xDiffStoreParams struct { - wantPool *v1alpha2.InferencePool + wantPool *datalayer.EndpointPool 
wantPods []string wantObjectives []*v1alpha2.InferenceObjective } -func xDiffStore(t *testing.T, datastore datastore.Datastore, params xDiffStoreParams) string { +func xDiffStore(datastore datastore.Datastore, params xDiffStoreParams) string { gotPool, _ := datastore.PoolGet() if gotPool == nil && params.wantPool == nil { return "" } - gotXPool := &v1alpha2.InferencePool{} - - err := gotXPool.ConvertFrom(gotPool) - if err != nil { - t.Fatalf("failed to convert InferencePool to XInferencePool: %v", err) - } - - // controller-runtime fake client may not populate TypeMeta (APIVersion/Kind). - // Ignore it when comparing pools. - if diff := cmp.Diff(params.wantPool, gotXPool, cmpopts.IgnoreTypes(metav1.TypeMeta{})); diff != "" { - return "pool:" + diff + if diff := cmp.Diff(params.wantPool, gotPool); diff != "" { + return "inferencePool:" + diff } // Default wantPods if not set because PodGetAll returns an empty slice when empty. diff --git a/pkg/epp/controller/pod_reconciler.go b/pkg/epp/controller/pod_reconciler.go index b3a78ef92..06be5d785 100644 --- a/pkg/epp/controller/pod_reconciler.go +++ b/pkg/epp/controller/pod_reconciler.go @@ -23,12 +23,12 @@ import ( "github.com/go-logr/logr" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/predicate" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" podutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pod" diff --git a/pkg/epp/controller/pod_reconciler_test.go b/pkg/epp/controller/pod_reconciler_test.go index 28f817310..efdb36b25 100644 --- a/pkg/epp/controller/pod_reconciler_test.go +++ b/pkg/epp/controller/pod_reconciler_test.go @@ -28,13 +28,14 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" - v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool" utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) @@ -197,7 +198,7 @@ func TestPodReconciler(t *testing.T) { // Configure the initial state of the datastore. store := datastore.NewDatastore(t.Context(), pmf, 0) - _ = store.PoolSet(t.Context(), fakeClient, test.pool) + _ = store.PoolSet(t.Context(), fakeClient, pool.InferencePoolToEndpointPool(test.pool)) for _, pod := range test.existingPods { store.PodUpdateOrAddIfNotExist(pod) } diff --git a/pkg/epp/datalayer/endpoint_pool.go b/pkg/epp/datalayer/endpoint_pool.go new file mode 100644 index 000000000..05d91d451 --- /dev/null +++ b/pkg/epp/datalayer/endpoint_pool.go @@ -0,0 +1,34 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package datalayer + +type EndpointPool struct { + Selector map[string]string + TargetPorts []int + Namespace string + Name string +} + +// NewEndpointPool creates and returns a new empty instance of EndpointPool. +func NewEndpointPool(namespace string, name string) *EndpointPool { + return &EndpointPool{ + Selector: make(map[string]string), + TargetPorts: []int{}, + Namespace: namespace, + Name: name, + } +} diff --git a/pkg/epp/datalayer/factory.go b/pkg/epp/datalayer/factory.go index 58da604a0..3a81763d5 100644 --- a/pkg/epp/datalayer/factory.go +++ b/pkg/epp/datalayer/factory.go @@ -23,8 +23,6 @@ import ( "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" - - v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" ) const ( @@ -40,7 +38,7 @@ const ( // - Global metrics logging uses PoolGet solely for error return and PodList to enumerate // all endpoints for metrics summarization. type PoolInfo interface { - PoolGet() (*v1.InferencePool, error) + PoolGet() (*EndpointPool, error) PodList(func(Endpoint) bool) []Endpoint } diff --git a/pkg/epp/datalayer/metrics/logger_test.go b/pkg/epp/datalayer/metrics/logger_test.go index 3ba2e7e84..4bf68cf0a 100644 --- a/pkg/epp/datalayer/metrics/logger_test.go +++ b/pkg/epp/datalayer/metrics/logger_test.go @@ -31,6 +31,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log/zap" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" + poolutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool" ) // Buffer to write the logs to @@ -95,8 +96,9 @@ var pod2 = &datalayer.PodInfo{ type fakeDataStore struct{} -func (f *fakeDataStore) PoolGet() (*v1.InferencePool, error) { - return &v1.InferencePool{Spec: v1.InferencePoolSpec{TargetPorts: []v1.Port{{Number: 8000}}}}, nil +func (f *fakeDataStore) PoolGet() (*datalayer.EndpointPool, error) { + pool := &v1.InferencePool{Spec: v1.InferencePoolSpec{TargetPorts: []v1.Port{{Number: 8000}}}} + return poolutil.InferencePoolToEndpointPool(pool), nil } func (f *fakeDataStore) PodList(predicate func(datalayer.Endpoint) bool) []datalayer.Endpoint { diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index 5d6ff751b..2ab2e98cb 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -28,10 +28,9 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" - - v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" @@ -49,8 +48,8 @@ type Datastore interface { // PoolSet sets the given pool in datastore. If the given pool has different label selector than the previous pool // that was stored, the function triggers a resync of the pods to keep the datastore updated. If the given pool // is nil, this call triggers the datastore.Clear() function. 
- PoolSet(ctx context.Context, reader client.Reader, pool *v1.InferencePool) error - PoolGet() (*v1.InferencePool, error) + PoolSet(ctx context.Context, reader client.Reader, endpointPool *datalayer.EndpointPool) error + PoolGet() (*datalayer.EndpointPool, error) PoolHasSynced() bool PoolLabelsMatch(podLabels map[string]string) bool @@ -69,15 +68,23 @@ type Datastore interface { Clear() } -func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory, modelServerMetricsPort int32) Datastore { +func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory, modelServerMetricsPort int32, opts ...DatastoreOption) Datastore { + // Initialize with defaults store := &datastore{ parentCtx: parentCtx, poolAndObjectivesMu: sync.RWMutex{}, + pool: nil, objectives: make(map[string]*v1alpha2.InferenceObjective), pods: &sync.Map{}, modelServerMetricsPort: modelServerMetricsPort, epf: epFactory, } + + // Apply options + for _, opt := range opts { + opt(store) + } + return store } @@ -86,8 +93,8 @@ type datastore struct { parentCtx context.Context // poolAndObjectivesMu is used to synchronize access to pool and the objectives map. poolAndObjectivesMu sync.RWMutex - pool *v1.InferencePool - // key: InferenceObjective name, value: *InferenceObjective + pool *datalayer.EndpointPool + // key: InferenceObjective.Spec.ModelName, value: *InferenceObjective objectives map[string]*v1alpha2.InferenceObjective // key: types.NamespacedName, value: backendmetrics.PodMetrics pods *sync.Map @@ -110,9 +117,9 @@ func (ds *datastore) Clear() { ds.pods.Clear() } -// /// InferencePool APIs /// -func (ds *datastore) PoolSet(ctx context.Context, reader client.Reader, pool *v1.InferencePool) error { - if pool == nil { +// /// Pool APIs /// +func (ds *datastore) PoolSet(ctx context.Context, reader client.Reader, endpointPool *datalayer.EndpointPool) error { + if endpointPool == nil { ds.Clear() return nil } @@ -120,10 +127,10 @@ func (ds *datastore) PoolSet(ctx context.Context, reader client.Reader, pool *v1 ds.poolAndObjectivesMu.Lock() defer ds.poolAndObjectivesMu.Unlock() - oldPool := ds.pool - ds.pool = pool - if oldPool == nil || !reflect.DeepEqual(pool.Spec.Selector, oldPool.Spec.Selector) { - logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "selector", pool.Spec.Selector) + oldEndpointPool := ds.pool + ds.pool = endpointPool + if oldEndpointPool == nil || !reflect.DeepEqual(oldEndpointPool.Selector, endpointPool.Selector) { + logger.V(logutil.DEFAULT).Info("Updating endpoints", "selector", endpointPool.Selector) // A full resync is required to address two cases: // 1) At startup, the pod events may get processed before the pool is synced with the datastore, // and hence they will not be added to the store since pool selector is not known yet @@ -138,7 +145,7 @@ func (ds *datastore) PoolSet(ctx context.Context, reader client.Reader, pool *v1 return nil } -func (ds *datastore) PoolGet() (*v1.InferencePool, error) { +func (ds *datastore) PoolGet() (*datalayer.EndpointPool, error) { ds.poolAndObjectivesMu.RLock() defer ds.poolAndObjectivesMu.RUnlock() if !ds.PoolHasSynced() { @@ -159,7 +166,7 @@ func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool { if ds.pool == nil { return false } - poolSelector := selectorFromInferencePoolSelector(ds.pool.Spec.Selector.MatchLabels) + poolSelector := labels.SelectorFromSet(ds.pool.Selector) podSet := labels.Set(podLabels) return poolSelector.Matches(podSet) } @@ -225,14 +232,14 @@ func (ds *datastore) 
PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool { } modelServerMetricsPort := 0 - if len(ds.pool.Spec.TargetPorts) == 1 { + if len(ds.pool.TargetPorts) == 1 { modelServerMetricsPort = int(ds.modelServerMetricsPort) } pods := []*datalayer.PodInfo{} - for idx, port := range ds.pool.Spec.TargetPorts { + for idx, port := range ds.pool.TargetPorts { metricsPort := modelServerMetricsPort if metricsPort == 0 { - metricsPort = int(port.Number) + metricsPort = port } pods = append(pods, &datalayer.PodInfo{ @@ -242,7 +249,7 @@ func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool { }, PodName: pod.Name, Address: pod.Status.PodIP, - Port: strconv.Itoa(int(port.Number)), + Port: strconv.Itoa(port), MetricsHost: net.JoinHostPort(pod.Status.PodIP, strconv.Itoa(metricsPort)), Labels: labels, }) @@ -280,7 +287,7 @@ func (ds *datastore) podResyncAll(ctx context.Context, reader client.Reader) err logger := log.FromContext(ctx) podList := &corev1.PodList{} if err := reader.List(ctx, podList, &client.ListOptions{ - LabelSelector: selectorFromInferencePoolSelector(ds.pool.Spec.Selector.MatchLabels), + LabelSelector: labels.SelectorFromSet(ds.pool.Selector), Namespace: ds.pool.Namespace, }); err != nil { return fmt.Errorf("failed to list pods - %w", err) @@ -313,14 +320,10 @@ func (ds *datastore) podResyncAll(ctx context.Context, reader client.Reader) err return nil } -func selectorFromInferencePoolSelector(selector map[v1.LabelKey]v1.LabelValue) labels.Selector { - return labels.SelectorFromSet(stripLabelKeyAliasFromLabelMap(selector)) -} +type DatastoreOption func(*datastore) -func stripLabelKeyAliasFromLabelMap(labels map[v1.LabelKey]v1.LabelValue) map[string]string { - outMap := make(map[string]string) - for k, v := range labels { - outMap[string(k)] = string(v) +func WithEndpointPool(pool *datalayer.EndpointPool) DatastoreOption { + return func(d *datastore) { + d.pool = pool } - return outMap } diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index ee59071e6..73beb1f24 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -32,12 +32,13 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" - "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/client/fake" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" + pooltuil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool" testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) @@ -87,12 +88,12 @@ func TestPool(t *testing.T) { Build() pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) ds := NewDatastore(context.Background(), pmf, 0) - _ = ds.PoolSet(context.Background(), fakeClient, tt.inferencePool) + _ = ds.PoolSet(context.Background(), fakeClient, pooltuil.InferencePoolToEndpointPool(tt.inferencePool)) gotPool, gotErr := ds.PoolGet() if diff := cmp.Diff(tt.wantErr, gotErr, cmpopts.EquateErrors()); diff != "" { t.Errorf("Unexpected error diff (+got/-want): %s", diff) } - if diff := cmp.Diff(tt.wantPool, gotPool); diff != "" { + if diff := cmp.Diff(pooltuil.InferencePoolToEndpointPool(tt.wantPool), gotPool); diff != "" { t.Errorf("Unexpected pool diff (+got/-want): %s", diff) 
} gotSynced := ds.PoolHasSynced() @@ -328,7 +329,7 @@ func TestMetrics(t *testing.T) { Build() pmf := backendmetrics.NewPodMetricsFactory(test.pmc, time.Millisecond) ds := NewDatastore(ctx, pmf, 0) - _ = ds.PoolSet(ctx, fakeClient, inferencePool) + _ = ds.PoolSet(ctx, fakeClient, pooltuil.InferencePoolToEndpointPool(inferencePool)) for _, pod := range test.storePods { ds.PodUpdateOrAddIfNotExist(pod) } @@ -397,7 +398,7 @@ func TestPods(t *testing.T) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) ds := NewDatastore(t.Context(), pmf, 0) fakeClient := fake.NewFakeClient() - if err := ds.PoolSet(ctx, fakeClient, inferencePool); err != nil { + if err := ds.PoolSet(ctx, fakeClient, pooltuil.InferencePoolToEndpointPool(inferencePool)); err != nil { t.Error(err) } for _, pod := range test.existingPods { @@ -581,7 +582,7 @@ func TestPodInfo(t *testing.T) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) ds := NewDatastore(t.Context(), pmf, 0) fakeClient := fake.NewFakeClient() - if err := ds.PoolSet(ctx, fakeClient, test.pool); err != nil { + if err := ds.PoolSet(ctx, fakeClient, pooltuil.InferencePoolToEndpointPool(test.pool)); err != nil { t.Error(err) } for _, pod := range test.existingPods { diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 5296d49a4..a67a1ea6f 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -29,10 +29,10 @@ import ( "github.com/google/uuid" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "sigs.k8s.io/controller-runtime/pkg/log" - v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" + "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" @@ -61,7 +61,7 @@ type Director interface { } type Datastore interface { - PoolGet() (*v1.InferencePool, error) + PoolGet() (*datalayer.EndpointPool, error) } // Server implements the Envoy external processing server. diff --git a/pkg/epp/metrics/collectors/inference_pool_test.go b/pkg/epp/metrics/collectors/inference_pool_test.go index af2923e50..20ab69fd7 100644 --- a/pkg/epp/metrics/collectors/inference_pool_test.go +++ b/pkg/epp/metrics/collectors/inference_pool_test.go @@ -27,11 +27,12 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "k8s.io/component-base/metrics/testutil" - "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/client/fake" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + poolutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool" ) var ( @@ -68,13 +69,6 @@ func TestMetricsCollected(t *testing.T) { }, } pmf := backendmetrics.NewPodMetricsFactory(pmc, time.Millisecond) - ds := datastore.NewDatastore(context.Background(), pmf, 0) - - scheme := runtime.NewScheme() - fakeClient := fake.NewClientBuilder(). - WithScheme(scheme). 
- Build() - inferencePool := &v1.InferencePool{ ObjectMeta: metav1.ObjectMeta{ Name: "test-pool", @@ -83,7 +77,14 @@ func TestMetricsCollected(t *testing.T) { TargetPorts: []v1.Port{{Number: v1.PortNumber(int32(8000))}}, }, } - _ = ds.PoolSet(context.Background(), fakeClient, inferencePool) + ds := datastore.NewDatastore(context.Background(), pmf, 0) + + scheme := runtime.NewScheme() + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + Build() + + _ = ds.PoolSet(context.Background(), fakeClient, poolutil.InferencePoolToEndpointPool(inferencePool)) _ = ds.PodUpdateOrAddIfNotExist(pod1) time.Sleep(1 * time.Second) diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index 59c8976cd..e005e1d4c 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -23,9 +23,9 @@ import ( "github.com/prometheus/client_golang/prometheus" compbasemetrics "k8s.io/component-base/metrics" + "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/metrics" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" metricsutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/metrics" ) diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index ecd52d90c..c4f4f1c1b 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -27,8 +27,6 @@ import ( "time" "sigs.k8s.io/controller-runtime/pkg/log" - - v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -49,7 +47,7 @@ const ( // Datastore defines the interface required by the Director. 
type Datastore interface { - PoolGet() (*v1.InferencePool, error) + PoolGet() (*datalayer.EndpointPool, error) ObjectiveGet(objectiveName string) *v1alpha2.InferenceObjective PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics } diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index e705beefd..f361303c8 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -32,8 +32,8 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" - "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/client/fake" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" @@ -46,6 +46,7 @@ import ( schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + poolutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool" requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) @@ -89,7 +90,7 @@ type mockDatastore struct { pods []backendmetrics.PodMetrics } -func (ds *mockDatastore) PoolGet() (*v1.InferencePool, error) { +func (ds *mockDatastore) PoolGet() (*datalayer.EndpointPool, error) { return nil, nil } func (ds *mockDatastore) ObjectiveGet(_ string) *v1alpha2.InferenceObjective { @@ -190,14 +191,6 @@ func TestDirector_HandleRequest(t *testing.T) { CreationTimestamp(metav1.Unix(1000, 0)). Priority(1). 
ObjRef() - - // Datastore setup - pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) - ds := datastore.NewDatastore(t.Context(), pmf, 0) - ds.ObjectiveSet(ioFoodReview) - ds.ObjectiveSet(ioFoodReviewResolve) - ds.ObjectiveSet(ioFoodReviewSheddable) - pool := &v1.InferencePool{ ObjectMeta: metav1.ObjectMeta{Name: "test-pool", Namespace: "default"}, Spec: v1.InferencePoolSpec{ @@ -210,10 +203,18 @@ func TestDirector_HandleRequest(t *testing.T) { }, } + // Datastore setup + pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) + ds := datastore.NewDatastore(t.Context(), pmf, 0) + ds.ObjectiveSet(ioFoodReview) + ds.ObjectiveSet(ioFoodReviewResolve) + ds.ObjectiveSet(ioFoodReviewSheddable) + scheme := runtime.NewScheme() _ = clientgoscheme.AddToScheme(scheme) fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() - if err := ds.PoolSet(ctx, fakeClient, pool); err != nil { + + if err := ds.PoolSet(ctx, fakeClient, poolutil.InferencePoolToEndpointPool(pool)); err != nil { t.Fatalf("Error while setting inference pool: %v", err) } @@ -754,8 +755,9 @@ func TestGetRandomPod(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Millisecond) + endpointPool := poolutil.InferencePoolToEndpointPool(pool) ds := datastore.NewDatastore(t.Context(), pmf, 0) - err := ds.PoolSet(t.Context(), fakeClient, pool) + err := ds.PoolSet(t.Context(), fakeClient, endpointPool) if err != nil { t.Errorf("unexpected error setting pool: %s", err) } diff --git a/pkg/epp/server/controller_manager.go b/pkg/epp/server/controller_manager.go index 47e4f12d4..c82b0bcb9 100644 --- a/pkg/epp/server/controller_manager.go +++ b/pkg/epp/server/controller_manager.go @@ -25,12 +25,12 @@ import ( utilruntime "k8s.io/apimachinery/pkg/util/runtime" clientgoscheme "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/cache" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/manager" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" - v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/common" @@ -45,7 +45,7 @@ func init() { } // defaultManagerOptions returns the default options used to create the manager. -func defaultManagerOptions(gknn common.GKNN, metricsServerOptions metricsserver.Options) (ctrl.Options, error) { +func defaultManagerOptions(disableK8sCrdReconcile bool, gknn common.GKNN, metricsServerOptions metricsserver.Options) (ctrl.Options, error) { opt := ctrl.Options{ Scheme: scheme, Cache: cache.Options{ @@ -55,38 +55,38 @@ func defaultManagerOptions(gknn common.GKNN, metricsServerOptions metricsserver. 
gknn.Namespace: {}, }, }, - &v1alpha2.InferenceObjective{}: { - Namespaces: map[string]cache.Config{ - gknn.Namespace: {}, - }, - }, }, }, Metrics: metricsServerOptions, } - switch gknn.Group { - case v1alpha2.GroupName: - opt.Cache.ByObject[&v1alpha2.InferencePool{}] = cache.ByObject{ - Namespaces: map[string]cache.Config{gknn.Namespace: {FieldSelector: fields.SelectorFromSet(fields.Set{ - "metadata.name": gknn.Name, - })}}, - } - case v1.GroupName: - opt.Cache.ByObject[&v1.InferencePool{}] = cache.ByObject{ - Namespaces: map[string]cache.Config{gknn.Namespace: {FieldSelector: fields.SelectorFromSet(fields.Set{ - "metadata.name": gknn.Name, - })}}, + if !disableK8sCrdReconcile { + opt.Cache.ByObject[&v1alpha2.InferenceObjective{}] = cache.ByObject{Namespaces: map[string]cache.Config{ + gknn.Namespace: {}, + }} + switch gknn.Group { + case v1alpha2.GroupName: + opt.Cache.ByObject[&v1alpha2.InferencePool{}] = cache.ByObject{ + Namespaces: map[string]cache.Config{gknn.Namespace: {FieldSelector: fields.SelectorFromSet(fields.Set{ + "metadata.name": gknn.Name, + })}}, + } + case v1.GroupName: + opt.Cache.ByObject[&v1.InferencePool{}] = cache.ByObject{ + Namespaces: map[string]cache.Config{gknn.Namespace: {FieldSelector: fields.SelectorFromSet(fields.Set{ + "metadata.name": gknn.Name, + })}}, + } + default: + return ctrl.Options{}, fmt.Errorf("unknown group: %s", gknn.Group) } - default: - return ctrl.Options{}, fmt.Errorf("unknown group: %s", gknn.Group) } return opt, nil } // NewDefaultManager creates a new controller manager with default configuration. -func NewDefaultManager(gknn common.GKNN, restConfig *rest.Config, metricsServerOptions metricsserver.Options, leaderElectionEnabled bool) (ctrl.Manager, error) { - opt, err := defaultManagerOptions(gknn, metricsServerOptions) +func NewDefaultManager(disableK8sCrdReconcile bool, gknn common.GKNN, restConfig *rest.Config, metricsServerOptions metricsserver.Options, leaderElectionEnabled bool) (ctrl.Manager, error) { + opt, err := defaultManagerOptions(disableK8sCrdReconcile, gknn, metricsServerOptions) if err != nil { return nil, fmt.Errorf("failed to create controller manager options: %v", err) } diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index c3037175e..e43d84923 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -30,9 +30,9 @@ import ( healthgrpc "google.golang.org/grpc/health/grpc_health_v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/manager" - "sigs.k8s.io/gateway-api-inference-extension/internal/runnable" tlsutil "sigs.k8s.io/gateway-api-inference-extension/internal/tls" "sigs.k8s.io/gateway-api-inference-extension/pkg/common" @@ -48,8 +48,8 @@ import ( // ExtProcServerRunner provides methods to manage an external process server. type ExtProcServerRunner struct { GrpcPort int - PoolNamespacedName types.NamespacedName - PoolGKNN common.GKNN + GKNN common.GKNN + DisableK8sCrdReconcile bool Datastore datastore.Datastore SecureServing bool HealthChecking bool @@ -91,7 +91,7 @@ const ( // NewDefaultExtProcServerRunner creates a runner with default values. // Note: Dependencies like Datastore, Scheduler, SD need to be set separately. 
 func NewDefaultExtProcServerRunner() *ExtProcServerRunner {
-	poolGKNN := common.GKNN{
+	gknn := common.GKNN{
 		NamespacedName: types.NamespacedName{Name: DefaultPoolName, Namespace: DefaultPoolNamespace},
 		GroupKind: schema.GroupKind{
 			Group: DefaultPoolGroup,
@@ -100,8 +100,8 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner {
 	}
 	return &ExtProcServerRunner{
 		GrpcPort:                         DefaultGrpcPort,
-		PoolNamespacedName:               types.NamespacedName{Name: DefaultPoolName, Namespace: DefaultPoolNamespace},
-		PoolGKNN:                         poolGKNN,
+		GKNN:                             gknn,
+		DisableK8sCrdReconcile:           false,
 		SecureServing:                    DefaultSecureServing,
 		HealthChecking:                   DefaultHealthChecking,
 		RefreshPrometheusMetricsInterval: DefaultRefreshPrometheusMetricsInterval,
@@ -113,20 +113,22 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner {
 // SetupWithManager sets up the runner with the given manager.
 func (r *ExtProcServerRunner) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error {
 	// Create the controllers and register them with the manager
-	if err := (&controller.InferencePoolReconciler{
-		Datastore: r.Datastore,
-		Reader:    mgr.GetClient(),
-		PoolGKNN:  r.PoolGKNN,
-	}).SetupWithManager(mgr); err != nil {
-		return fmt.Errorf("failed setting up InferencePoolReconciler: %w", err)
-	}
+	if !r.DisableK8sCrdReconcile {
+		if err := (&controller.InferencePoolReconciler{
+			Datastore: r.Datastore,
+			Reader:    mgr.GetClient(),
+			PoolGKNN:  r.GKNN,
+		}).SetupWithManager(mgr); err != nil {
+			return fmt.Errorf("failed setting up InferencePoolReconciler: %w", err)
+		}
 
-	if err := (&controller.InferenceObjectiveReconciler{
-		Datastore: r.Datastore,
-		Reader:    mgr.GetClient(),
-		PoolGKNN:  r.PoolGKNN,
-	}).SetupWithManager(ctx, mgr); err != nil {
-		return fmt.Errorf("failed setting up InferenceObjectiveReconciler: %w", err)
+		if err := (&controller.InferenceObjectiveReconciler{
+			Datastore: r.Datastore,
+			Reader:    mgr.GetClient(),
+			PoolGKNN:  r.GKNN,
+		}).SetupWithManager(ctx, mgr); err != nil {
+			return fmt.Errorf("failed setting up InferenceObjectiveReconciler: %w", err)
+		}
 	}
 
 	if err := (&controller.PodReconciler{
diff --git a/pkg/epp/server/runserver_test.go b/pkg/epp/server/runserver_test.go
index b02688c58..172928865 100644
--- a/pkg/epp/server/runserver_test.go
+++ b/pkg/epp/server/runserver_test.go
@@ -20,7 +20,6 @@ import (
 	"testing"
 
 	"sigs.k8s.io/controller-runtime/pkg/manager"
-
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
diff --git a/pkg/epp/util/pool/pool.go b/pkg/epp/util/pool/pool.go
new file mode 100644
index 000000000..67fd9fef8
--- /dev/null
+++ b/pkg/epp/util/pool/pool.go
@@ -0,0 +1,89 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package pool
+
+import (
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
+	v1alpha2 "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
+)
+
+func InferencePoolToEndpointPool(inferencePool *v1.InferencePool) *datalayer.EndpointPool {
+	if inferencePool == nil {
+		return nil
+	}
+	targetPorts := make([]int, 0, len(inferencePool.Spec.TargetPorts))
+	for _, p := range inferencePool.Spec.TargetPorts {
+		targetPorts = append(targetPorts, int(p.Number))
+
+	}
+	selector := make(map[string]string, len(inferencePool.Spec.Selector.MatchLabels))
+	for k, v := range inferencePool.Spec.Selector.MatchLabels {
+		selector[string(k)] = string(v)
+	}
+	endpointPool := &datalayer.EndpointPool{
+		Selector:    selector,
+		TargetPorts: targetPorts,
+		Namespace:   inferencePool.Namespace,
+		Name:        inferencePool.Name,
+	}
+	return endpointPool
+}
+
+func AlphaInferencePoolToEndpointPool(inferencePool *v1alpha2.InferencePool) *datalayer.EndpointPool {
+	targetPorts := []int{int(inferencePool.Spec.TargetPortNumber)}
+	selector := make(map[string]string, len(inferencePool.Spec.Selector))
+	for k, v := range inferencePool.Spec.Selector {
+		selector[string(k)] = string(v)
+	}
+
+	endpointPool := &datalayer.EndpointPool{
+		TargetPorts: targetPorts,
+		Selector:    selector,
+		Namespace:   inferencePool.Namespace,
+		Name:        inferencePool.Name,
+	}
+	return endpointPool
+}
+
+func EndpointPoolToInferencePool(endpointPool *datalayer.EndpointPool) *v1.InferencePool {
+	targetPorts := make([]v1.Port, 0, len(endpointPool.TargetPorts))
+	for _, p := range endpointPool.TargetPorts {
+		targetPorts = append(targetPorts, v1.Port{Number: v1.PortNumber(p)})
+	}
+	labels := make(map[v1.LabelKey]v1.LabelValue, len(endpointPool.Selector))
+	for k, v := range endpointPool.Selector {
+		labels[v1.LabelKey(k)] = v1.LabelValue(v)
+	}
+
+	inferencePool := &v1.InferencePool{
+		TypeMeta: metav1.TypeMeta{
+			APIVersion: "inference.networking.k8s.io/v1",
+			Kind:       "InferencePool",
+		},
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      endpointPool.Name,
+			Namespace: endpointPool.Namespace,
+		},
+		Spec: v1.InferencePoolSpec{
+			Selector:    v1.LabelSelector{MatchLabels: labels},
+			TargetPorts: targetPorts,
+		},
+	}
+	return inferencePool
+}
diff --git a/pkg/epp/util/testing/wrappers.go b/pkg/epp/util/testing/wrappers.go
index 7621bff96..3e3380519 100644
--- a/pkg/epp/util/testing/wrappers.go
+++ b/pkg/epp/util/testing/wrappers.go
@@ -19,8 +19,12 @@ package testing
 import (
 	corev1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/runtime/schema"
+	"k8s.io/apimachinery/pkg/types"
+
 	v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
 	"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
 )
 
 // PodWrapper wraps a Pod.
@@ -219,6 +223,19 @@ func (m *InferencePoolWrapper) ObjRef() *v1.InferencePool {
 	return &m.InferencePool
 }
 
+func (m *InferencePoolWrapper) ToGKNN() common.GKNN {
+	return common.GKNN{
+		NamespacedName: types.NamespacedName{
+			Name:      m.Name,
+			Namespace: m.ObjectMeta.Namespace,
+		},
+		GroupKind: schema.GroupKind{
+			Group: "inference.networking.k8s.io",
+			Kind:  "InferencePool",
+		},
+	}
+}
+
 // AlphaInferencePoolWrapper wraps an group "inference.networking.x-k8s.io" InferencePool.
 type AlphaInferencePoolWrapper struct {
 	v1alpha2.InferencePool
diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go
index 9ce4fec64..f1f32d58b 100644
--- a/test/integration/epp/hermetic_test.go
+++ b/test/integration/epp/hermetic_test.go
@@ -48,6 +48,7 @@ import (
 	k8syaml "k8s.io/apimachinery/pkg/util/yaml"
 	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
 	metricsutils "k8s.io/component-base/metrics/testutil"
 
+	ctrl "sigs.k8s.io/controller-runtime"
 	"sigs.k8s.io/controller-runtime/pkg/cache"
 	k8sclient "sigs.k8s.io/controller-runtime/pkg/client"
@@ -56,8 +57,6 @@ import (
 	crmetrics "sigs.k8s.io/controller-runtime/pkg/metrics"
 	"sigs.k8s.io/controller-runtime/pkg/metrics/filters"
 	metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"
-	"sigs.k8s.io/yaml"
-
 	v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
 	"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
@@ -79,6 +78,7 @@ import (
 	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
 	epptestutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
 	integrationutils "sigs.k8s.io/gateway-api-inference-extension/test/integration"
+	"sigs.k8s.io/yaml"
 )
 
 const (
@@ -1170,10 +1170,11 @@ func BeforeSuite() func() {
 
 	serverRunner.TestPodMetricsClient = &backendmetrics.FakePodMetricsClient{}
 	pmf := backendmetrics.NewPodMetricsFactory(serverRunner.TestPodMetricsClient, 10*time.Millisecond) // Adjust from defaults
-	serverRunner.PoolGKNN = common.GKNN{
+	serverRunner.GKNN = common.GKNN{
 		NamespacedName: types.NamespacedName{Namespace: testNamespace, Name: testPoolName},
 		GroupKind:      schema.GroupKind{Group: v1.GroupVersion.Group, Kind: "InferencePool"},
 	}
+	serverRunner.Datastore = datastore.NewDatastore(context.Background(), pmf, 0)
 
 	kvCacheUtilizationScorer := scorer.NewKVCacheUtilizationScorer()
 
diff --git a/test/utils/server.go b/test/utils/server.go
index 9cf907d29..f46ed6f79 100644
--- a/test/utils/server.go
+++ b/test/utils/server.go
@@ -30,13 +30,14 @@ import (
 	corev1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/runtime"
 	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
 
+	"sigs.k8s.io/controller-runtime/pkg/client"
 	"sigs.k8s.io/controller-runtime/pkg/client/fake"
-
 	v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
 	"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
+	poolutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool"
 	testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
 )
 
@@ -72,7 +73,7 @@ func PrepareForTestStreamingServer(objectives []*v1alpha2.InferenceObjective, po
 		Build()
 	pool := testutil.MakeInferencePool(poolName).Namespace(namespace).ObjRef()
 	pool.Spec.TargetPorts = []v1.Port{{Number: v1.PortNumber(poolPort)}}
-	_ = ds.PoolSet(context.Background(), fakeClient, pool)
+	_ = ds.PoolSet(context.Background(), fakeClient, poolutil.InferencePoolToEndpointPool(pool))
 
 	return ctx, cancel, ds, pmc
 }