diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index 2842704bb..7163346c1 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -172,7 +172,7 @@ func (r *Runner) Run(ctx context.Context) error { if err != nil { return err } - datastore := datastore.NewDatastore(ctx, epf) + datastore := datastore.NewDatastore(ctx, epf, int32(*modelServerMetricsPort)) // --- Setup Metrics Server --- customCollectors := []prometheus.Collector{collectors.NewInferencePoolMetricsCollector(datastore)} @@ -393,7 +393,6 @@ func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.PodMetricsClientImpl{ MetricMapping: mapping, - ModelServerMetricsPort: int32(*modelServerMetricsPort), ModelServerMetricsPath: *modelServerMetricsPath, ModelServerMetricsScheme: *modelServerMetricsScheme, Client: metricsHttpClient, @@ -408,7 +407,6 @@ func setupDatalayer() (datalayer.EndpointFactory, error) { // this (and registering the sources with the endpoint factory) should // be moved accordingly. source := dlmetrics.NewDataSource(*modelServerMetricsScheme, - int32(*modelServerMetricsPort), // start with (optional) command line port value *modelServerMetricsPath, *modelServerMetricsHttpsInsecureSkipVerify, nil) diff --git a/pkg/epp/backend/metrics/fake.go b/pkg/epp/backend/metrics/fake.go index 4e0687ff0..613ebf5ec 100644 --- a/pkg/epp/backend/metrics/fake.go +++ b/pkg/epp/backend/metrics/fake.go @@ -22,7 +22,6 @@ import ( "sync" "time" - corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" @@ -49,8 +48,8 @@ func (fpm *FakePodMetrics) GetMetrics() *MetricsState { return fpm.Metrics } -func (fpm *FakePodMetrics) UpdatePod(pod *corev1.Pod) { - fpm.Pod = toInternalPod(pod) +func (fpm *FakePodMetrics) UpdatePod(pod *datalayer.PodInfo) { + fpm.Pod = pod } func (*FakePodMetrics) Put(string, datalayer.Cloneable) {} @@ -69,7 +68,7 @@ type FakePodMetricsClient struct { Res map[types.NamespacedName]*MetricsState } -func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, _ int32) (*MetricsState, error) { +func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) { f.errMu.RLock() err, ok := f.Err[pod.NamespacedName] f.errMu.RUnlock() diff --git a/pkg/epp/backend/metrics/metrics.go b/pkg/epp/backend/metrics/metrics.go index e10098fa4..305229ef8 100644 --- a/pkg/epp/backend/metrics/metrics.go +++ b/pkg/epp/backend/metrics/metrics.go @@ -40,7 +40,6 @@ const ( type PodMetricsClientImpl struct { MetricMapping *MetricMapping - ModelServerMetricsPort int32 ModelServerMetricsPath string ModelServerMetricsScheme string @@ -48,8 +47,8 @@ type PodMetricsClientImpl struct { } // FetchMetrics fetches metrics from a given pod, clones the existing metrics object and returns an updated one. -func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, port int32) (*MetricsState, error) { - url := p.getMetricEndpoint(pod, port) +func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) { + url := p.getMetricEndpoint(pod) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %v", err) @@ -74,11 +73,8 @@ func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Po return p.promToPodMetrics(metricFamilies, existing) } -func (p *PodMetricsClientImpl) getMetricEndpoint(pod *backend.Pod, targetPortNumber int32) string { - if p.ModelServerMetricsPort == 0 { - p.ModelServerMetricsPort = targetPortNumber - } - return fmt.Sprintf("%s://%s:%d%s", p.ModelServerMetricsScheme, pod.Address, p.ModelServerMetricsPort, p.ModelServerMetricsPath) +func (p *PodMetricsClientImpl) getMetricEndpoint(pod *backend.Pod) string { + return fmt.Sprintf("%s://%s:%d%s", p.ModelServerMetricsScheme, pod.GetIPAddress(), pod.GetMetricsPort(), p.ModelServerMetricsPath) } // promToPodMetrics updates internal pod metrics with scraped Prometheus metrics. diff --git a/pkg/epp/backend/metrics/metrics_test.go b/pkg/epp/backend/metrics/metrics_test.go index 2dd8ca5dd..81382312f 100644 --- a/pkg/epp/backend/metrics/metrics_test.go +++ b/pkg/epp/backend/metrics/metrics_test.go @@ -489,7 +489,9 @@ func TestPromToPodMetrics(t *testing.T) { func TestFetchMetrics(t *testing.T) { ctx := logutil.NewTestLoggerIntoContext(context.Background()) pod := &backend.Pod{ - Address: "127.0.0.1", + Address: "127.0.0.1", + Port: 9999, + MetricsPort: 9999, NamespacedName: types.NamespacedName{ Namespace: "test", Name: "pod", @@ -499,12 +501,11 @@ func TestFetchMetrics(t *testing.T) { // No MetricMapping needed for this basic test p := &PodMetricsClientImpl{ ModelServerMetricsScheme: "http", - ModelServerMetricsPort: 9999, ModelServerMetricsPath: "/metrics", Client: http.DefaultClient, } - _, err := p.FetchMetrics(ctx, pod, existing, 9999) // Use a port that's unlikely to be in use + _, err := p.FetchMetrics(ctx, pod, existing) // Use a port that's unlikely to be in use if err == nil { t.Errorf("FetchMetrics() expected error, got nil") } diff --git a/pkg/epp/backend/metrics/pod_metrics.go b/pkg/epp/backend/metrics/pod_metrics.go index da66a97ed..a1114aecf 100644 --- a/pkg/epp/backend/metrics/pod_metrics.go +++ b/pkg/epp/backend/metrics/pod_metrics.go @@ -24,8 +24,6 @@ import ( "time" "github.com/go-logr/logr" - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" @@ -51,7 +49,7 @@ type podMetrics struct { } type PodMetricsClient interface { - FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, port int32) (*MetricsState, error) + FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) } func (pm *podMetrics) String() string { @@ -66,23 +64,8 @@ func (pm *podMetrics) GetMetrics() *MetricsState { return pm.metrics.Load() } -func (pm *podMetrics) UpdatePod(pod *corev1.Pod) { - pm.pod.Store(toInternalPod(pod)) -} - -func toInternalPod(pod *corev1.Pod) *backend.Pod { - labels := make(map[string]string, len(pod.GetLabels())) - for key, value := range pod.GetLabels() { - labels[key] = value - } - return &backend.Pod{ - NamespacedName: types.NamespacedName{ - Name: pod.Name, - Namespace: pod.Namespace, - }, - Address: pod.Status.PodIP, - Labels: labels, - } +func (pm *podMetrics) UpdatePod(pod *datalayer.PodInfo) { + pm.pod.Store(pod) } // start starts a goroutine exactly once to periodically update metrics. The goroutine will be @@ -110,17 +93,9 @@ func (pm *podMetrics) startRefreshLoop(ctx context.Context) { } func (pm *podMetrics) refreshMetrics() error { - pool, err := pm.ds.PoolGet() - if err != nil { - // No inference pool or not initialize. - return err - } ctx, cancel := context.WithTimeout(context.Background(), fetchMetricsTimeout) defer cancel() - if len(pool.Spec.TargetPorts) != 1 { - return fmt.Errorf("expected 1 target port, got %d", len(pool.Spec.TargetPorts)) - } - updated, err := pm.pmc.FetchMetrics(ctx, pm.GetPod(), pm.GetMetrics(), int32(pool.Spec.TargetPorts[0].Number)) + updated, err := pm.pmc.FetchMetrics(ctx, pm.GetPod(), pm.GetMetrics()) if err != nil { pm.logger.V(logutil.TRACE).Info("Failed to refreshed metrics:", "err", err) } diff --git a/pkg/epp/backend/metrics/pod_metrics_test.go b/pkg/epp/backend/metrics/pod_metrics_test.go index 9a0e1a6fc..b0297cd1e 100644 --- a/pkg/epp/backend/metrics/pod_metrics_test.go +++ b/pkg/epp/backend/metrics/pod_metrics_test.go @@ -23,19 +23,19 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" ) var ( - pod1 = &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pod1", + pod1Info = &datalayer.PodInfo{ + NamespacedName: types.NamespacedName{ + Name: "pod1-rank-0", Namespace: "default", }, + PodName: "pod1", } initial = &MetricsState{ WaitingQueueSize: 0, @@ -65,12 +65,11 @@ func TestMetricsRefresh(t *testing.T) { pmf := NewPodMetricsFactory(pmc, time.Millisecond) // The refresher is initialized with empty metrics. - pm := pmf.NewEndpoint(ctx, pod1, &fakeDataStore{}) + pm := pmf.NewEndpoint(ctx, pod1Info, &fakeDataStore{}) - namespacedName := types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace} // Use SetRes to simulate an update of metrics from the pod. // Verify that the metrics are updated. - pmc.SetRes(map[types.NamespacedName]*MetricsState{namespacedName: initial}) + pmc.SetRes(map[types.NamespacedName]*MetricsState{pod1Info.NamespacedName: initial}) condition := func(collect *assert.CollectT) { assert.True(collect, cmp.Equal(pm.GetMetrics(), initial, cmpopts.IgnoreFields(MetricsState{}, "UpdateTime"))) } @@ -80,7 +79,7 @@ func TestMetricsRefresh(t *testing.T) { // new update. pmf.ReleaseEndpoint(pm) time.Sleep(pmf.refreshMetricsInterval * 2 /* small buffer for robustness */) - pmc.SetRes(map[types.NamespacedName]*MetricsState{namespacedName: updated}) + pmc.SetRes(map[types.NamespacedName]*MetricsState{pod1Info.NamespacedName: updated}) // Still expect the same condition (no metrics update). assert.EventuallyWithT(t, condition, time.Second, time.Millisecond) } diff --git a/pkg/epp/backend/metrics/types.go b/pkg/epp/backend/metrics/types.go index aadeb85cb..99f15a20f 100644 --- a/pkg/epp/backend/metrics/types.go +++ b/pkg/epp/backend/metrics/types.go @@ -22,7 +22,6 @@ import ( "sync" "time" - corev1 "k8s.io/api/core/v1" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" @@ -53,8 +52,7 @@ type PodMetricsFactory struct { refreshMetricsInterval time.Duration } -func (f *PodMetricsFactory) NewEndpoint(parentCtx context.Context, in *corev1.Pod, ds datalayer.PoolInfo) PodMetrics { - pod := toInternalPod(in) +func (f *PodMetricsFactory) NewEndpoint(parentCtx context.Context, pod *datalayer.PodInfo, ds datalayer.PoolInfo) PodMetrics { pm := &podMetrics{ pmc: f.pmc, ds: ds, diff --git a/pkg/epp/controller/inferenceobjective_reconciler_test.go b/pkg/epp/controller/inferenceobjective_reconciler_test.go index de43d6e63..4ceff5d07 100644 --- a/pkg/epp/controller/inferenceobjective_reconciler_test.go +++ b/pkg/epp/controller/inferenceobjective_reconciler_test.go @@ -160,7 +160,7 @@ func TestInferenceObjectiveReconciler(t *testing.T) { WithObjects(initObjs...). Build() pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) - ds := datastore.NewDatastore(t.Context(), pmf) + ds := datastore.NewDatastore(t.Context(), pmf, 0) for _, m := range test.objectivessInStore { ds.ObjectiveSet(m) } diff --git a/pkg/epp/controller/inferencepool_reconciler_test.go b/pkg/epp/controller/inferencepool_reconciler_test.go index 7f6938533..e4d8b5d42 100644 --- a/pkg/epp/controller/inferencepool_reconciler_test.go +++ b/pkg/epp/controller/inferencepool_reconciler_test.go @@ -113,14 +113,14 @@ func TestInferencePoolReconciler(t *testing.T) { ctx := context.Background() pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) - datastore := datastore.NewDatastore(ctx, pmf) - inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: datastore, PoolGKNN: gknn} + ds := datastore.NewDatastore(ctx, pmf, 0) + inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: ds, PoolGKNN: gknn} // Step 1: Inception, only ready pods matching pool1 are added to the store. if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := diffStore(datastore, diffStoreParams{wantPool: pool1, wantPods: []string{"pod1", "pod2"}}); diff != "" { + if diff := diffStore(ds, diffStoreParams{wantPool: pool1, wantPods: []string{"pod1-rank-0", "pod2-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } @@ -138,7 +138,7 @@ func TestInferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := diffStore(datastore, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" { + if diff := diffStore(ds, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } @@ -153,7 +153,7 @@ func TestInferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := diffStore(datastore, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" { + if diff := diffStore(ds, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } @@ -167,7 +167,7 @@ func TestInferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := diffStore(datastore, diffStoreParams{wantPods: []string{}}); diff != "" { + if diff := diffStore(ds, diffStoreParams{wantPods: []string{}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } } @@ -258,14 +258,14 @@ func TestXInferencePoolReconciler(t *testing.T) { ctx := context.Background() pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) - datastore := datastore.NewDatastore(ctx, pmf) - inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: datastore, PoolGKNN: gknn} + ds := datastore.NewDatastore(ctx, pmf, 0) + inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: ds, PoolGKNN: gknn} // Step 1: Inception, only ready pods matching pool1 are added to the store. if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPool: pool1, wantPods: []string{"pod1", "pod2"}}); diff != "" { + if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: pool1, wantPods: []string{"pod1-rank-0", "pod2-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } @@ -281,7 +281,7 @@ func TestXInferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" { + if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } @@ -296,7 +296,7 @@ func TestXInferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" { + if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } @@ -310,7 +310,7 @@ func TestXInferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPods: []string{}}); diff != "" { + if diff := xDiffStore(t, ds, xDiffStoreParams{wantPods: []string{}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } } diff --git a/pkg/epp/controller/pod_reconciler.go b/pkg/epp/controller/pod_reconciler.go index ce77b2cfd..0d255a057 100644 --- a/pkg/epp/controller/pod_reconciler.go +++ b/pkg/epp/controller/pod_reconciler.go @@ -23,7 +23,6 @@ import ( "github.com/go-logr/logr" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/types" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/event" @@ -53,7 +52,7 @@ func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R pod := &corev1.Pod{} if err := c.Get(ctx, req.NamespacedName, pod); err != nil { if apierrors.IsNotFound(err) { - c.Datastore.PodDelete(req.NamespacedName) + c.Datastore.PodRemove(req.Name) return ctrl.Result{}, nil } return ctrl.Result{}, fmt.Errorf("unable to get pod - %w", err) @@ -90,10 +89,9 @@ func (c *PodReconciler) SetupWithManager(mgr ctrl.Manager) error { } func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod) { - namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} if !podutil.IsPodReady(pod) || !c.Datastore.PoolLabelsMatch(pod.Labels) { logger.V(logutil.DEBUG).Info("Pod removed or not added") - c.Datastore.PodDelete(namespacedName) + c.Datastore.PodRemove(pod.Name) } else { if c.Datastore.PodUpdateOrAddIfNotExist(pod) { logger.V(logutil.DEFAULT).Info("Pod added") diff --git a/pkg/epp/controller/pod_reconciler_test.go b/pkg/epp/controller/pod_reconciler_test.go index 5ceb3efdb..28f817310 100644 --- a/pkg/epp/controller/pod_reconciler_test.go +++ b/pkg/epp/controller/pod_reconciler_test.go @@ -196,7 +196,7 @@ func TestPodReconciler(t *testing.T) { Build() // Configure the initial state of the datastore. - store := datastore.NewDatastore(t.Context(), pmf) + store := datastore.NewDatastore(t.Context(), pmf, 0) _ = store.PoolSet(t.Context(), fakeClient, test.pool) for _, pod := range test.existingPods { store.PodUpdateOrAddIfNotExist(pod) @@ -213,7 +213,7 @@ func TestPodReconciler(t *testing.T) { var gotPods []*corev1.Pod for _, pm := range store.PodList(backendmetrics.AllPodsPredicate) { - pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().NamespacedName.Name, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().Address}} + pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().PodName, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().GetIPAddress()}} gotPods = append(gotPods, pod) } if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b *corev1.Pod) bool { return a.Name < b.Name })) { diff --git a/pkg/epp/datalayer/collector_test.go b/pkg/epp/datalayer/collector_test.go index 2d47de30a..0e3b9151b 100644 --- a/pkg/epp/datalayer/collector_test.go +++ b/pkg/epp/datalayer/collector_test.go @@ -24,8 +24,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/mocks" ) @@ -45,14 +44,12 @@ func (d *DummySource) Collect(ctx context.Context, ep Endpoint) error { func defaultEndpoint() Endpoint { ms := NewEndpoint() - pod := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ + pod := &PodInfo{ + NamespacedName: types.NamespacedName{ Name: "pod-name", Namespace: "default", }, - Status: corev1.PodStatus{ - PodIP: "1.2.3.4", - }, + Address: "1.2.3.4:5678", } ms.UpdatePod(pod) return ms diff --git a/pkg/epp/datalayer/endpoint.go b/pkg/epp/datalayer/endpoint.go index 2a728864c..74c11905e 100644 --- a/pkg/epp/datalayer/endpoint.go +++ b/pkg/epp/datalayer/endpoint.go @@ -19,14 +19,12 @@ package datalayer import ( "fmt" "sync/atomic" - - corev1 "k8s.io/api/core/v1" ) // EndpointPodState allows management of the Pod related attributes. type EndpointPodState interface { GetPod() *PodInfo - UpdatePod(*corev1.Pod) + UpdatePod(*PodInfo) } // EndpointMetricsState allows management of the Metrics related attributes. @@ -67,8 +65,8 @@ func (srv *ModelServer) GetPod() *PodInfo { return srv.pod.Load() } -func (srv *ModelServer) UpdatePod(pod *corev1.Pod) { - srv.pod.Store(ToPodInfo(pod)) +func (srv *ModelServer) UpdatePod(pod *PodInfo) { + srv.pod.Store(pod) } func (srv *ModelServer) GetMetrics() *Metrics { diff --git a/pkg/epp/datalayer/factory.go b/pkg/epp/datalayer/factory.go index eca7697e5..989527c6c 100644 --- a/pkg/epp/datalayer/factory.go +++ b/pkg/epp/datalayer/factory.go @@ -21,7 +21,6 @@ import ( "sync" "time" - corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" @@ -45,7 +44,7 @@ type PoolInfo interface { // providing methods to allocate and retire endpoints. This can potentially be used for // pooled memory or other management chores in the implementation. type EndpointFactory interface { - NewEndpoint(parent context.Context, inpod *corev1.Pod, poolinfo PoolInfo) Endpoint + NewEndpoint(parent context.Context, inpod *PodInfo, poolinfo PoolInfo) Endpoint ReleaseEndpoint(ep Endpoint) } @@ -70,8 +69,8 @@ func NewEndpointFactory(sources []DataSource, refreshMetricsInterval time.Durati // NewEndpoint implements EndpointFactory.NewEndpoint. // Creates a new endpoint and starts its associated collector with its own ticker. // Guards against multiple concurrent calls for the same endpoint. -func (lc *EndpointLifecycle) NewEndpoint(parent context.Context, inpod *corev1.Pod, _ PoolInfo) Endpoint { - key := types.NamespacedName{Namespace: inpod.Namespace, Name: inpod.Name} +func (lc *EndpointLifecycle) NewEndpoint(parent context.Context, inpod *PodInfo, _ PoolInfo) Endpoint { + key := types.NamespacedName{Namespace: inpod.GetNamespacedName().Namespace, Name: inpod.GetNamespacedName().Name} logger := log.FromContext(parent).WithValues("pod", key) if _, ok := lc.collectors.Load(key); ok { diff --git a/pkg/epp/datalayer/metrics/datasource.go b/pkg/epp/datalayer/metrics/datasource.go index 7dcdc97ba..5f5b2a10e 100644 --- a/pkg/epp/datalayer/metrics/datasource.go +++ b/pkg/epp/datalayer/metrics/datasource.go @@ -25,7 +25,6 @@ import ( "net/url" "strconv" "sync" - "sync/atomic" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" ) @@ -37,9 +36,8 @@ const ( // DataSource is a Model Server Protocol (MSP) compliant metrics data source, // returning Prometheus formatted metrics for an endpoint. type DataSource struct { - metricsScheme string // scheme to use in metrics URL - metricsPort atomic.Pointer[string] // target port to use in metrics URL - metricsPath string // path to use in metrics URL + metricsScheme string // scheme to use in metrics URL + metricsPath string // path to use in metrics URL client Client // client (e.g. a wrapped http.Client) used to get metrics extractors sync.Map // key: name, value: extractor @@ -49,7 +47,7 @@ type DataSource struct { // the provided client factory. If ClientFactory is nil, a default factory is used. // The Scheme, port and path are command line options. It should be noted that // a port value of zero is set if the command line is unspecified. -func NewDataSource(metricsScheme string, metricsPort int32, metricsPath string, skipCertVerification bool, cl Client) *DataSource { +func NewDataSource(metricsScheme string, metricsPath string, skipCertVerification bool, cl Client) *DataSource { if metricsScheme == "https" { httpsTransport := baseTransport.Clone() httpsTransport.TLSClientConfig = &tls.Config{ @@ -67,25 +65,9 @@ func NewDataSource(metricsScheme string, metricsPort int32, metricsPath string, metricsPath: metricsPath, client: cl, } - dataSrc.SetPort(metricsPort) return dataSrc } -// SetPort updates the port used for metrics scraping. -// The port value can only be set once (i.e., if set by command line, -// do not overwrite from Pool.Spec). A port value of 0 (i.e., unspecified -// command line value) is ignored. -// TODO: https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1398 -func (dataSrc *DataSource) SetPort(metricsPort int32) { - if dataSrc.metricsPort.Load() != nil { // do not overwrite - return - } - if metricsPort != 0 { // ignore zero value for port - port := strconv.Itoa(int(metricsPort)) - dataSrc.metricsPort.Store(&port) - } -} - // Name returns the metrics data source name. func (dataSrc *DataSource) Name() string { return DataSourceName @@ -132,7 +114,7 @@ func (dataSrc *DataSource) Collect(ctx context.Context, ep datalayer.Endpoint) e func (dataSrc *DataSource) getMetricsEndpoint(ep datalayer.Addressable) *url.URL { return &url.URL{ Scheme: dataSrc.metricsScheme, - Host: net.JoinHostPort(ep.GetIPAddress(), *dataSrc.metricsPort.Load()), + Host: net.JoinHostPort(ep.GetIPAddress(), strconv.Itoa(int(ep.GetMetricsPort()))), Path: dataSrc.metricsPath, } } diff --git a/pkg/epp/datalayer/podinfo.go b/pkg/epp/datalayer/podinfo.go index afd107bf9..f44435965 100644 --- a/pkg/epp/datalayer/podinfo.go +++ b/pkg/epp/datalayer/podinfo.go @@ -19,39 +19,27 @@ package datalayer import ( "fmt" - corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" ) // Addressable supports getting an IP address and a namespaced name. type Addressable interface { GetIPAddress() string + GetPort() int32 + GetMetricsPort() int32 GetNamespacedName() types.NamespacedName } // PodInfo represents the relevant Kubernetes Pod state of an inference server. type PodInfo struct { NamespacedName types.NamespacedName + PodName string Address string + Port int32 + MetricsPort int32 Labels map[string]string } -// ToPodInfo converts a Kubernetes API Pod to its internal representation. -func ToPodInfo(pod *corev1.Pod) *PodInfo { - labels := make(map[string]string, len(pod.GetLabels())) - for key, value := range pod.GetLabels() { - labels[key] = value - } - return &PodInfo{ - NamespacedName: types.NamespacedName{ - Name: pod.Name, - Namespace: pod.Namespace, - }, - Address: pod.Status.PodIP, - Labels: labels, - } -} - // String returns a string representation of the pod. func (p *PodInfo) String() string { if p == nil { @@ -75,8 +63,11 @@ func (p *PodInfo) Clone() *PodInfo { Name: p.NamespacedName.Name, Namespace: p.NamespacedName.Namespace, }, - Address: p.Address, - Labels: clonedLabels, + PodName: p.PodName, + Address: p.Address, + Port: p.Port, + MetricsPort: p.MetricsPort, + Labels: clonedLabels, } } @@ -89,3 +80,13 @@ func (p *PodInfo) GetNamespacedName() types.NamespacedName { func (p *PodInfo) GetIPAddress() string { return p.Address } + +// GetPort returns the Pod's inference port. +func (p *PodInfo) GetPort() int32 { + return p.Port +} + +// GetMetricsPort returns the pod's metrics port +func (p *PodInfo) GetMetricsPort() int32 { + return p.MetricsPort +} diff --git a/pkg/epp/datalayer/podinfo_test.go b/pkg/epp/datalayer/podinfo_test.go index 3a713e7a3..9bfd0ea38 100644 --- a/pkg/epp/datalayer/podinfo_test.go +++ b/pkg/epp/datalayer/podinfo_test.go @@ -55,13 +55,6 @@ var ( } ) -func TestToPodInfo(t *testing.T) { - podinfo := ToPodInfo(pod) - if diff := cmp.Diff(expected, podinfo); diff != "" { - t.Errorf("Unexpected output (-want +got): %v", diff) - } -} - func TestPodInfoClone(t *testing.T) { clone := expected.Clone() assert.NotSame(t, expected, clone) @@ -74,7 +67,17 @@ func TestPodInfoClone(t *testing.T) { } func TestPodInfoString(t *testing.T) { - podinfo := ToPodInfo(pod) + podinfo := PodInfo{ + NamespacedName: types.NamespacedName{ + Name: pod.Name, + Namespace: pod.Namespace, + }, + PodName: pod.Name, + Address: pod.Status.PodIP, + Port: 0, + MetricsPort: 0, + Labels: labels, + } s := podinfo.String() assert.Contains(t, s, name) diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index 86204be26..8c65d24de 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "reflect" + "strconv" "sync" corev1 "k8s.io/api/core/v1" @@ -33,7 +34,6 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" - dlmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" podutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pod" ) @@ -63,18 +63,20 @@ type Datastore interface { PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool PodDelete(namespacedName types.NamespacedName) + PodRemove(podName string) // Clears the store state, happens when the pool gets deleted. Clear() } -func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory) Datastore { +func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory, modelServerMetricsPort int32) Datastore { store := &datastore{ - parentCtx: parentCtx, - poolAndObjectivesMu: sync.RWMutex{}, - objectives: make(map[string]*v1alpha2.InferenceObjective), - pods: &sync.Map{}, - epf: epFactory, + parentCtx: parentCtx, + poolAndObjectivesMu: sync.RWMutex{}, + objectives: make(map[string]*v1alpha2.InferenceObjective), + pods: &sync.Map{}, + modelServerMetricsPort: modelServerMetricsPort, + epf: epFactory, } return store } @@ -89,7 +91,10 @@ type datastore struct { objectives map[string]*v1alpha2.InferenceObjective // key: types.NamespacedName, value: backendmetrics.PodMetrics pods *sync.Map - epf datalayer.EndpointFactory + // modelServerMetricsPort metrics port from EPP command line argument + // used only if there is only one inference engine per pod + modelServerMetricsPort int32 + epf datalayer.EndpointFactory } func (ds *datastore) Clear() { @@ -117,11 +122,6 @@ func (ds *datastore) PoolSet(ctx context.Context, reader client.Reader, pool *v1 oldPool := ds.pool ds.pool = pool - if oldPool == nil || pool.Spec.TargetPorts[0] != oldPool.Spec.TargetPorts[0] { - if source, found := datalayer.GetNamedSource[*dlmetrics.DataSource](dlmetrics.DataSourceName); found { - source.SetPort(int32(pool.Spec.TargetPorts[0].Number)) - } - } if oldPool == nil || !reflect.DeepEqual(pool.Spec.Selector, oldPool.Spec.Selector) { logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "selector", pool.Spec.Selector) // A full resync is required to address two cases: @@ -215,21 +215,49 @@ func (ds *datastore) PodList(predicate func(backendmetrics.PodMetrics) bool) []b } func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool { - namespacedName := types.NamespacedName{ - Name: pod.Name, - Namespace: pod.Namespace, + if ds.pool == nil { + return true } - var pm backendmetrics.PodMetrics - existing, ok := ds.pods.Load(namespacedName) - if !ok { - pm = ds.epf.NewEndpoint(ds.parentCtx, pod, ds) - ds.pods.Store(namespacedName, pm) - } else { - pm = existing.(backendmetrics.PodMetrics) + + labels := make(map[string]string, len(pod.GetLabels())) + for key, value := range pod.GetLabels() { + labels[key] = value } - // Update pod properties if anything changed. - pm.UpdatePod(pod) - return ok + + pods := []*datalayer.PodInfo{} + for idx, port := range ds.pool.Spec.TargetPorts { + pods = append(pods, + &datalayer.PodInfo{ + NamespacedName: types.NamespacedName{ + Name: pod.Name + "-rank-" + strconv.Itoa(idx), + Namespace: pod.Namespace, + }, + PodName: pod.Name, + Address: pod.Status.PodIP, + Port: int32(port.Number), + MetricsPort: int32(port.Number), + Labels: labels, + }) + } + if len(pods) == 1 && ds.modelServerMetricsPort != 0 { + pods[0].MetricsPort = ds.modelServerMetricsPort + } + + result := true + for _, podInfo := range pods { + var pm backendmetrics.PodMetrics + existing, ok := ds.pods.Load(podInfo.NamespacedName) + if !ok { + pm = ds.epf.NewEndpoint(ds.parentCtx, podInfo, ds) + ds.pods.Store(podInfo.NamespacedName, pm) + result = false + } else { + pm = existing.(backendmetrics.PodMetrics) + } + // Update pod properties if anything changed. + pm.UpdatePod(podInfo) + } + return result } func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { @@ -239,6 +267,16 @@ func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { } } +func (ds *datastore) PodRemove(podName string) { + ds.pods.Range(func(k, v any) bool { + pm := v.(backendmetrics.PodMetrics) + if pm.GetPod().PodName == podName { + ds.PodDelete(pm.GetPod().NamespacedName) + } + return true + }) +} + func (ds *datastore) podResyncAll(ctx context.Context, reader client.Reader) error { logger := log.FromContext(ctx) podList := &corev1.PodList{} @@ -266,7 +304,7 @@ func (ds *datastore) podResyncAll(ctx context.Context, reader client.Reader) err // Remove pods that don't belong to the pool or not ready any more. ds.pods.Range(func(k, v any) bool { pm := v.(backendmetrics.PodMetrics) - if exist := activePods[pm.GetPod().NamespacedName.Name]; !exist { + if exist := activePods[pm.GetPod().PodName]; !exist { logger.V(logutil.VERBOSE).Info("Removing pod", "pod", pm.GetPod()) ds.PodDelete(pm.GetPod().NamespacedName) } diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 271c31ee7..f7b7ab106 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -35,6 +35,7 @@ import ( v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) @@ -83,21 +84,21 @@ func TestPool(t *testing.T) { WithScheme(scheme). Build() pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) - datastore := NewDatastore(context.Background(), pmf) - _ = datastore.PoolSet(context.Background(), fakeClient, tt.inferencePool) - gotPool, gotErr := datastore.PoolGet() + ds := NewDatastore(context.Background(), pmf, 0) + _ = ds.PoolSet(context.Background(), fakeClient, tt.inferencePool) + gotPool, gotErr := ds.PoolGet() if diff := cmp.Diff(tt.wantErr, gotErr, cmpopts.EquateErrors()); diff != "" { t.Errorf("Unexpected error diff (+got/-want): %s", diff) } if diff := cmp.Diff(tt.wantPool, gotPool); diff != "" { t.Errorf("Unexpected pool diff (+got/-want): %s", diff) } - gotSynced := datastore.PoolHasSynced() + gotSynced := ds.PoolHasSynced() if diff := cmp.Diff(tt.wantSynced, gotSynced); diff != "" { t.Errorf("Unexpected synced diff (+got/-want): %s", diff) } if tt.labels != nil { - gotLabelsMatch := datastore.PoolLabelsMatch(tt.labels) + gotLabelsMatch := ds.PoolLabelsMatch(tt.labels) if diff := cmp.Diff(tt.wantLabelsMatch, gotLabelsMatch); diff != "" { t.Errorf("Unexpected labels match diff (+got/-want): %s", diff) } @@ -190,7 +191,7 @@ func TestObjective(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) - ds := NewDatastore(t.Context(), pmf) + ds := NewDatastore(t.Context(), pmf, 0) for _, m := range test.existingModels { ds.ObjectiveSet(m) } @@ -241,13 +242,18 @@ var ( WaitingModels: map[string]int{}, } - pod1NamespacedName = types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace} - pod2NamespacedName = types.NamespacedName{Name: pod2.Name, Namespace: pod2.Namespace} + pod1NamespacedName = types.NamespacedName{Name: pod1.Name + "-rank-0", Namespace: pod1.Namespace} + pod2NamespacedName = types.NamespacedName{Name: pod2.Name + "-rank-0", Namespace: pod2.Namespace} inferencePool = &v1.InferencePool{ Spec: v1.InferencePoolSpec{ TargetPorts: []v1.Port{{Number: v1.PortNumber(int32(8000))}}, }, } + inferencePoolMultiTarget = &v1.InferencePool{ + Spec: v1.InferencePoolSpec{ + TargetPorts: []v1.Port{{Number: v1.PortNumber(int32(8000))}, {Number: v1.PortNumber(int32(8001))}}, + }, + } ) func TestMetrics(t *testing.T) { @@ -315,7 +321,7 @@ func TestMetrics(t *testing.T) { WithScheme(scheme). Build() pmf := backendmetrics.NewPodMetricsFactory(test.pmc, time.Millisecond) - ds := NewDatastore(ctx, pmf) + ds := NewDatastore(ctx, pmf, 0) _ = ds.PoolSet(ctx, fakeClient, inferencePool) for _, pod := range test.storePods { ds.PodUpdateOrAddIfNotExist(pod) @@ -340,14 +346,6 @@ func TestMetrics(t *testing.T) { } func TestPods(t *testing.T) { - updatedPod := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pod1", - }, - Spec: corev1.PodSpec{ - NodeName: "node-1", - }, - } tests := []struct { name string op func(ctx context.Context, ds Datastore) @@ -371,60 +369,226 @@ func TestPods(t *testing.T) { }, }, { - name: "Update existing pod, new field, should update", - existingPods: []*corev1.Pod{pod1}, - wantPods: []*corev1.Pod{updatedPod}, + name: "Delete the pod", + existingPods: []*corev1.Pod{pod1, pod2}, + wantPods: []*corev1.Pod{pod1}, op: func(ctx context.Context, ds Datastore) { - ds.PodUpdateOrAddIfNotExist(updatedPod) + ds.PodRemove(pod2.Name) }, }, { - name: "Update existing pod, no new fields, should not update", + name: "Delete the pod that doesn't exist", existingPods: []*corev1.Pod{pod1}, wantPods: []*corev1.Pod{pod1}, op: func(ctx context.Context, ds Datastore) { - incoming := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pod1", - Namespace: "default", + ds.PodRemove(pod2.Name) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) + ds := NewDatastore(t.Context(), pmf, 0) + fakeClient := fake.NewFakeClient() + if err := ds.PoolSet(ctx, fakeClient, inferencePool); err != nil { + t.Error(err) + } + for _, pod := range test.existingPods { + ds.PodUpdateOrAddIfNotExist(pod) + } + + test.op(ctx, ds) + var gotPods []*corev1.Pod + for _, pm := range ds.PodList(backendmetrics.AllPodsPredicate) { + pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().PodName, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().GetIPAddress()}} + gotPods = append(gotPods, pod) + } + if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b *corev1.Pod) bool { return a.Name < b.Name })) { + t.Errorf("got (%v) != want (%v);", gotPods, test.wantPods) + } + }) + } +} + +func TestPodInfo(t *testing.T) { + tests := []struct { + name string + op func(ctx context.Context, ds Datastore) + pool *v1.InferencePool + existingPods []*corev1.Pod + wantPodInfos []*datalayer.PodInfo + }{ + { + name: "Add new pod, no existing pods, should add", + existingPods: []*corev1.Pod{}, + wantPodInfos: []*datalayer.PodInfo{ + { + NamespacedName: types.NamespacedName{ + Name: pod1.Name + "-rank-0", + Namespace: pod1.Namespace, }, - } - ds.PodUpdateOrAddIfNotExist(incoming) + + PodName: pod1.Name, + Address: pod1.Status.PodIP, + Port: int32(inferencePool.Spec.TargetPorts[0].Number), + MetricsPort: int32(inferencePool.Spec.TargetPorts[0].Number), + Labels: map[string]string{}, + }, + }, + op: func(ctx context.Context, ds Datastore) { + ds.PodUpdateOrAddIfNotExist(pod1) }, + pool: inferencePool, }, { - name: "Delete the pod", - wantPods: []*corev1.Pod{pod1}, + name: "Add new pod, no existing pods, should add, multiple target ports", + existingPods: []*corev1.Pod{}, + wantPodInfos: []*datalayer.PodInfo{ + { + NamespacedName: types.NamespacedName{ + Name: pod1.Name + "-rank-0", + Namespace: pod1.Namespace, + }, + + PodName: pod1.Name, + Address: pod1.Status.PodIP, + Port: int32(inferencePoolMultiTarget.Spec.TargetPorts[0].Number), + MetricsPort: int32(inferencePoolMultiTarget.Spec.TargetPorts[0].Number), + Labels: map[string]string{}, + }, + { + NamespacedName: types.NamespacedName{ + Name: pod1.Name + "-rank-1", + Namespace: pod1.Namespace, + }, + + PodName: pod1.Name, + Address: pod1.Status.PodIP, + Port: int32(inferencePoolMultiTarget.Spec.TargetPorts[1].Number), + MetricsPort: int32(inferencePoolMultiTarget.Spec.TargetPorts[1].Number), + Labels: map[string]string{}, + }, + }, op: func(ctx context.Context, ds Datastore) { - ds.PodDelete(pod2NamespacedName) + ds.PodUpdateOrAddIfNotExist(pod1) }, + pool: inferencePoolMultiTarget, }, { - name: "Delete the pod that doesn't exist", + name: "Add new pod, with existing pods, should add, multiple target ports", existingPods: []*corev1.Pod{pod1}, - wantPods: []*corev1.Pod{pod1}, + wantPodInfos: []*datalayer.PodInfo{ + { + NamespacedName: types.NamespacedName{ + Name: pod1.Name + "-rank-0", + Namespace: pod1.Namespace, + }, + + PodName: pod1.Name, + Address: pod1.Status.PodIP, + Port: int32(inferencePoolMultiTarget.Spec.TargetPorts[0].Number), + MetricsPort: int32(inferencePoolMultiTarget.Spec.TargetPorts[0].Number), + Labels: map[string]string{}, + }, + { + NamespacedName: types.NamespacedName{ + Name: pod1.Name + "-rank-1", + Namespace: pod1.Namespace, + }, + + PodName: pod1.Name, + Address: pod1.Status.PodIP, + Port: int32(inferencePoolMultiTarget.Spec.TargetPorts[1].Number), + MetricsPort: int32(inferencePoolMultiTarget.Spec.TargetPorts[1].Number), + Labels: map[string]string{}, + }, + { + NamespacedName: types.NamespacedName{ + Name: pod2.Name + "-rank-0", + Namespace: pod2.Namespace, + }, + + PodName: pod2.Name, + Address: pod2.Status.PodIP, + Port: int32(inferencePoolMultiTarget.Spec.TargetPorts[0].Number), + MetricsPort: int32(inferencePoolMultiTarget.Spec.TargetPorts[0].Number), + Labels: map[string]string{}, + }, + { + NamespacedName: types.NamespacedName{ + Name: pod2.Name + "-rank-1", + Namespace: pod2.Namespace, + }, + + PodName: pod2.Name, + Address: pod2.Status.PodIP, + Port: int32(inferencePoolMultiTarget.Spec.TargetPorts[1].Number), + MetricsPort: int32(inferencePoolMultiTarget.Spec.TargetPorts[1].Number), + Labels: map[string]string{}, + }, + }, + op: func(ctx context.Context, ds Datastore) { + ds.PodUpdateOrAddIfNotExist(pod2) + }, + pool: inferencePoolMultiTarget, + }, + { + name: "Delete the pod, multiple target ports", + existingPods: []*corev1.Pod{pod1, pod2}, + wantPodInfos: []*datalayer.PodInfo{ + { + NamespacedName: types.NamespacedName{ + Name: pod1.Name + "-rank-0", + Namespace: pod1.Namespace, + }, + + PodName: pod1.Name, + Address: pod1.Status.PodIP, + Port: int32(inferencePoolMultiTarget.Spec.TargetPorts[0].Number), + MetricsPort: int32(inferencePoolMultiTarget.Spec.TargetPorts[0].Number), + Labels: map[string]string{}, + }, + { + NamespacedName: types.NamespacedName{ + Name: pod1.Name + "-rank-1", + Namespace: pod1.Namespace, + }, + + PodName: pod1.Name, + Address: pod1.Status.PodIP, + Port: int32(inferencePoolMultiTarget.Spec.TargetPorts[1].Number), + MetricsPort: int32(inferencePoolMultiTarget.Spec.TargetPorts[1].Number), + Labels: map[string]string{}, + }, + }, op: func(ctx context.Context, ds Datastore) { - ds.PodDelete(pod2NamespacedName) + ds.PodRemove(pod2.Name) }, + pool: inferencePoolMultiTarget, }, } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { ctx := context.Background() pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) - ds := NewDatastore(t.Context(), pmf) + ds := NewDatastore(t.Context(), pmf, 0) + fakeClient := fake.NewFakeClient() + if err := ds.PoolSet(ctx, fakeClient, test.pool); err != nil { + t.Error(err) + } for _, pod := range test.existingPods { ds.PodUpdateOrAddIfNotExist(pod) } test.op(ctx, ds) - var gotPods []*corev1.Pod + var gotPodInfos []*datalayer.PodInfo for _, pm := range ds.PodList(backendmetrics.AllPodsPredicate) { - pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().NamespacedName.Name, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().Address}} - gotPods = append(gotPods, pod) + gotPodInfos = append(gotPodInfos, pm.GetPod()) } - if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b *corev1.Pod) bool { return a.Name < b.Name })) { - t.Logf("got (%v) != want (%v);", gotPods, test.wantPods) + if diff := cmp.Diff(test.wantPodInfos, gotPodInfos, cmpopts.SortSlices(func(a, b *datalayer.PodInfo) bool { return a.NamespacedName.Name < b.NamespacedName.Name })); diff != "" { + t.Errorf("ConvertTo() mismatch (-want +got):\n%s", diff) } }) } diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 7f8122195..ba86fa580 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -17,7 +17,6 @@ limitations under the License. package handlers import ( - "fmt" "strconv" "time" @@ -42,14 +41,7 @@ func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extP if pod == nil { return errutil.Error{Code: errutil.Internal, Msg: "no pods available in datastore"} } - pool, err := s.datastore.PoolGet() - if err != nil { - return err - } - if len(pool.Spec.TargetPorts) != 1 { - return fmt.Errorf("expected 1 target port, got %d", len(pool.Spec.TargetPorts)) - } - reqCtx.TargetEndpoint = pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPorts[0].Number)) + reqCtx.TargetEndpoint = pod.GetIPAddress() + ":" + strconv.Itoa(int(pod.GetPort())) reqCtx.RequestSize = 0 reqCtx.reqHeaderResp = s.generateRequestHeaderResponse(reqCtx) return nil diff --git a/pkg/epp/metrics/collectors/inference_pool_test.go b/pkg/epp/metrics/collectors/inference_pool_test.go index dcac3b37d..af2923e50 100644 --- a/pkg/epp/metrics/collectors/inference_pool_test.go +++ b/pkg/epp/metrics/collectors/inference_pool_test.go @@ -40,7 +40,7 @@ var ( Name: "pod1", }, } - pod1NamespacedName = types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace} + pod1NamespacedName = types.NamespacedName{Name: pod1.Name + "-rank-0", Namespace: pod1.Namespace} pod1Metrics = &backendmetrics.MetricsState{ WaitingQueueSize: 100, KVCacheUsagePercent: 0.2, @@ -50,10 +50,10 @@ var ( func TestNoMetricsCollected(t *testing.T) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) - datastore := datastore.NewDatastore(context.Background(), pmf) + ds := datastore.NewDatastore(context.Background(), pmf, 0) collector := &inferencePoolMetricsCollector{ - ds: datastore, + ds: ds, } if err := testutil.CollectAndCompare(collector, strings.NewReader(""), ""); err != nil { @@ -68,7 +68,7 @@ func TestMetricsCollected(t *testing.T) { }, } pmf := backendmetrics.NewPodMetricsFactory(pmc, time.Millisecond) - ds := datastore.NewDatastore(context.Background(), pmf) + ds := datastore.NewDatastore(context.Background(), pmf, 0) scheme := runtime.NewScheme() fakeClient := fake.NewClientBuilder(). @@ -94,7 +94,7 @@ func TestMetricsCollected(t *testing.T) { err := testutil.CollectAndCompare(collector, strings.NewReader(` # HELP inference_pool_per_pod_queue_size [ALPHA] The total number of requests pending in the model server queue for each underlying pod. # TYPE inference_pool_per_pod_queue_size gauge - inference_pool_per_pod_queue_size{model_server_pod="pod1",name="test-pool"} 100 + inference_pool_per_pod_queue_size{model_server_pod="pod1-rank-0",name="test-pool"} 100 `), "inference_pool_per_pod_queue_size") if err != nil { t.Fatal(err) diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index a3e2d6d13..42e20f764 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -196,7 +196,7 @@ func (d *Director) getCandidatePodsForScheduling(ctx context.Context, requestMet podTotalCount := 0 podFilteredList := d.datastore.PodList(func(pm backendmetrics.PodMetrics) bool { podTotalCount++ - if _, found := endpoints[pm.GetPod().Address]; found { + if _, found := endpoints[pm.GetPod().GetIPAddress()]; found { return true } return false @@ -240,20 +240,12 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"} } // primary profile is used to set destination - pool, err := d.datastore.PoolGet() - if err != nil { - return reqCtx, err - } targetPods := []*backend.Pod{} - if len(pool.Spec.TargetPorts) != 1 { - return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "targetPorts should have length 1"} - } - targetPort := int(pool.Spec.TargetPorts[0].Number) targetEndpoints := []string{} for _, pod := range result.ProfileResults[result.PrimaryProfileName].TargetPods { curPod := pod.GetPod() - curEndpoint := net.JoinHostPort(curPod.Address, strconv.Itoa(targetPort)) + curEndpoint := net.JoinHostPort(curPod.GetIPAddress(), strconv.Itoa(int(curPod.GetPort()))) targetPods = append(targetPods, curPod) targetEndpoints = append(targetEndpoints, curEndpoint) } @@ -264,7 +256,7 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC reqCtx.TargetPod = targetPods[0] reqCtx.TargetEndpoint = multiEndpointString - d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result, targetPort) + d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result) return reqCtx, nil } @@ -302,12 +294,12 @@ func (d *Director) GetRandomPod() *backend.Pod { } func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, - schedulingResult *schedulingtypes.SchedulingResult, targetPort int) { + schedulingResult *schedulingtypes.SchedulingResult) { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) for _, plugin := range d.preRequestPlugins { loggerDebug.Info("Running pre-request plugin", "plugin", plugin.TypedName()) before := time.Now() - plugin.PreRequest(ctx, request, schedulingResult, targetPort) + plugin.PreRequest(ctx, request, schedulingResult) metrics.RecordPluginProcessingLatency(PreRequestExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before)) loggerDebug.Info("Completed running pre-request plugin successfully", "plugin", plugin.TypedName()) } diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index a0cb7c325..493391c72 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -111,7 +111,7 @@ func TestDirector_HandleRequest(t *testing.T) { // Datastore setup pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) - ds := datastore.NewDatastore(t.Context(), pmf) + ds := datastore.NewDatastore(t.Context(), pmf, 0) ds.ObjectiveSet(ioFoodReview) ds.ObjectiveSet(ioFoodReviewResolve) ds.ObjectiveSet(ioFoodReviewSheddable) @@ -160,6 +160,8 @@ func TestDirector_HandleRequest(t *testing.T) { Pod: &schedulingtypes.PodMetrics{ Pod: &backend.Pod{ Address: "192.168.1.100", + Port: 8000, + MetricsPort: 8000, NamespacedName: types.NamespacedName{Name: "pod1", Namespace: "default"}, }, }, @@ -168,6 +170,8 @@ func TestDirector_HandleRequest(t *testing.T) { Pod: &schedulingtypes.PodMetrics{ Pod: &backend.Pod{ Address: "192.168.2.100", + Port: 8000, + MetricsPort: 8000, NamespacedName: types.NamespacedName{Name: "pod2", Namespace: "default"}, }, }, @@ -176,6 +180,8 @@ func TestDirector_HandleRequest(t *testing.T) { Pod: &schedulingtypes.PodMetrics{ Pod: &backend.Pod{ Address: "192.168.4.100", + Port: 8000, + MetricsPort: 8000, NamespacedName: types.NamespacedName{Name: "pod4", Namespace: "default"}, }, }, @@ -213,6 +219,8 @@ func TestDirector_HandleRequest(t *testing.T) { TargetPod: &backend.Pod{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", + Port: 8000, + MetricsPort: 8000, }, TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000", }, @@ -240,6 +248,8 @@ func TestDirector_HandleRequest(t *testing.T) { TargetPod: &backend.Pod{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", + Port: 8000, + MetricsPort: 8000, }, TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000", }, @@ -270,6 +280,8 @@ func TestDirector_HandleRequest(t *testing.T) { TargetPod: &backend.Pod{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", + Port: 8000, + MetricsPort: 8000, }, TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000", }, @@ -293,6 +305,8 @@ func TestDirector_HandleRequest(t *testing.T) { TargetPod: &backend.Pod{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", + Port: 8000, + MetricsPort: 8000, }, TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000", }, @@ -316,6 +330,8 @@ func TestDirector_HandleRequest(t *testing.T) { TargetPod: &backend.Pod{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", + Port: 8000, + MetricsPort: 8000, }, TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000", }, @@ -334,6 +350,8 @@ func TestDirector_HandleRequest(t *testing.T) { TargetPod: &backend.Pod{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", + Port: 8000, + MetricsPort: 8000, }, TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000", }, @@ -572,10 +590,29 @@ func TestGetRandomPod(t *testing.T) { }, } + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + _ = v1alpha2.Install(scheme) + _ = v1.Install(scheme) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + Build() + pool := &v1.InferencePool{ + Spec: v1.InferencePoolSpec{ + TargetPorts: []v1.Port{ + {Number: 8000}, + }, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Millisecond) - ds := datastore.NewDatastore(t.Context(), pmf) + ds := datastore.NewDatastore(t.Context(), pmf, 0) + err := ds.PoolSet(t.Context(), fakeClient, pool) + if err != nil { + t.Errorf("unexpected error setting pool: %s", err) + } for _, pod := range test.storePods { ds.PodUpdateOrAddIfNotExist(pod) } @@ -596,7 +633,7 @@ func TestDirector_HandleResponse(t *testing.T) { pr1 := newTestPostResponse("pr1") ctx := logutil.NewTestLoggerIntoContext(context.Background()) - ds := datastore.NewDatastore(t.Context(), nil) + ds := datastore.NewDatastore(t.Context(), nil, 0) mockSched := &mockScheduler{} director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithPostResponsePlugins(pr1)) diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go index ca823a670..1b61b75d5 100644 --- a/pkg/epp/requestcontrol/plugins.go +++ b/pkg/epp/requestcontrol/plugins.go @@ -33,7 +33,7 @@ const ( // before a request is sent to the selected model server. type PreRequest interface { plugins.Plugin - PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, targetPort int) + PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) } // PostResponse is called by the director after a successful response was sent. diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index 6bc81a44a..c91a6f611 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -208,7 +208,7 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques } // PreRequest records in the plugin cache the result of the scheduling selection. -func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, _ int) { +func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) { primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName] targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index 54a00abc1..11ef393ef 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -73,7 +73,7 @@ func TestPrefixPluginCompletion(t *testing.T) { "default": {TargetPods: []types.Pod{pod1}}, }, } - plugin.PreRequest(context.Background(), req1, schedulingResult, 0) + plugin.PreRequest(context.Background(), req1, schedulingResult) plugin.wg.Wait() // Second request doesn't share any prefix with first one. It should be added to the cache but @@ -105,7 +105,7 @@ func TestPrefixPluginCompletion(t *testing.T) { "default": {TargetPods: []types.Pod{pod2}}, }, } - plugin.PreRequest(context.Background(), req2, schedulingResult, 0) + plugin.PreRequest(context.Background(), req2, schedulingResult) plugin.wg.Wait() // Third request shares partial prefix with first one. @@ -135,7 +135,7 @@ func TestPrefixPluginCompletion(t *testing.T) { "default": {TargetPods: []types.Pod{pod1}}, }, } - plugin.PreRequest(context.Background(), req3, schedulingResult, 0) + plugin.PreRequest(context.Background(), req3, schedulingResult) plugin.wg.Wait() // 4th request is same as req3 except the model is different, still no match. @@ -165,7 +165,7 @@ func TestPrefixPluginCompletion(t *testing.T) { "default": {TargetPods: []types.Pod{pod1}}, }, } - plugin.PreRequest(context.Background(), req4, schedulingResult, 0) + plugin.PreRequest(context.Background(), req4, schedulingResult) plugin.wg.Wait() // 5th request shares partial prefix with 3rd one. @@ -195,7 +195,7 @@ func TestPrefixPluginCompletion(t *testing.T) { "default": {TargetPods: []types.Pod{pod1}}, }, } - plugin.PreRequest(context.Background(), req5, schedulingResult, 0) + plugin.PreRequest(context.Background(), req5, schedulingResult) plugin.wg.Wait() } @@ -275,7 +275,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { "default": {TargetPods: []types.Pod{pod1}}, }, } - plugin.PreRequest(context.Background(), req1, schedulingResult, 0) + plugin.PreRequest(context.Background(), req1, schedulingResult) plugin.wg.Wait() // Second request adds assistant response and new user message (conversation grows) @@ -308,7 +308,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { assert.Equal(t, float64(0), scores[pod2], "pod2 should have no cache hit") // Simulate pod1 was picked again - plugin.PreRequest(context.Background(), req2, schedulingResult, 0) + plugin.PreRequest(context.Background(), req2, schedulingResult) plugin.wg.Wait() // Third request continues the conversation even further @@ -392,7 +392,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) { "default": {TargetPods: []types.Pod{pod}}, }, } - plugin.PreRequest(context.Background(), req, schedulingResult, 0) + plugin.PreRequest(context.Background(), req, schedulingResult) plugin.wg.Wait() // Second cycle: validate internal state diff --git a/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go b/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go index bf36d7782..2836755d4 100644 --- a/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go +++ b/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go @@ -73,7 +73,7 @@ func (f *HeaderBasedTestingFilter) Filter(_ context.Context, _ *types.CycleState podAddressMap := make(map[string]types.Pod, len(pods)) for _, pod := range pods { - podAddressMap[pod.GetPod().Address] = pod + podAddressMap[pod.GetPod().GetIPAddress()] = pod } endpoints := strings.Split(headerValue, ",") diff --git a/pkg/epp/server/server_test.go b/pkg/epp/server/server_test.go index aff6d4644..9dcd6def9 100644 --- a/pkg/epp/server/server_test.go +++ b/pkg/epp/server/server_test.go @@ -54,7 +54,6 @@ func TestServer(t *testing.T) { director := &testDirector{} ctx, cancel, ds, _ := utils.PrepareForTestStreamingServer([]*v1alpha2.InferenceObjective{model}, []*v1.Pod{{ObjectMeta: metav1.ObjectMeta{Name: podName}}}, "test-pool1", namespace, poolPort) - streamingServer := handlers.NewStreamingServer(ds, director) testListener, errChan := utils.SetupTestStreamingServer(t, ctx, ds, streamingServer) diff --git a/pkg/epp/util/testing/wrappers.go b/pkg/epp/util/testing/wrappers.go index 9e7f4a17b..7621bff96 100644 --- a/pkg/epp/util/testing/wrappers.go +++ b/pkg/epp/util/testing/wrappers.go @@ -179,7 +179,11 @@ func MakeInferencePool(name string) *InferencePoolWrapper { APIVersion: "inference.networking.k8s.io/v1", Kind: "InferencePool", }, - Spec: v1.InferencePoolSpec{}, + Spec: v1.InferencePoolSpec{ + TargetPorts: []v1.Port{ + {Number: 8000}, + }, + }, }, } } diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index 3dc42f8ba..1d4821dbf 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -970,7 +970,7 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { responses, err := integrationutils.StreamedRequest(t, client, test.requests, len(test.wantResponses)) if err != nil && !test.wantErr { - t.Errorf("Unexpected error, got: %v, want error: %v", err, test.wantErr) + t.Errorf("In test %s, unexpected error, got: %v, want error: %v", test.name, err, test.wantErr) } if diff := cmp.Diff(test.wantResponses, responses, protocmp.Transform(), @@ -978,13 +978,13 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { return a.GetHeader().GetKey() < b.GetHeader().GetKey() }), ); diff != "" { - t.Errorf("Unexpected response, (-want +got): %v", diff) + t.Errorf("In test %s, unexpected response, (-want +got): %v", test.name, diff) } if len(test.wantMetrics) != 0 { for metricName, value := range test.wantMetrics { if err := metricsutils.GatherAndCompare(crmetrics.Registry, strings.NewReader(value), metricName); err != nil { - t.Error(err) + t.Error(fmt.Errorf("In test %s, %v", test.name, err)) } } } @@ -1009,11 +1009,11 @@ func setUpHermeticServer(t *testing.T, podAndMetrics map[*backend.Pod]*backendme } for pod := range podAndMetrics { - pod := epptestutil.MakePod(pod.NamespacedName.Name). + pod := epptestutil.MakePod(pod.PodName). Namespace(pod.NamespacedName.Namespace). ReadyCondition(). Labels(podLabels). - IP(pod.Address). + IP(pod.GetIPAddress()). Complete(). ObjRef() @@ -1059,7 +1059,7 @@ func setUpHermeticServer(t *testing.T, podAndMetrics map[*backend.Pod]*backendme // clear created pods for pod := range podAndMetrics { - pod := epptestutil.MakePod(pod.NamespacedName.Name). + pod := epptestutil.MakePod(pod.PodName). Namespace(pod.NamespacedName.Namespace).Complete().ObjRef() if err := k8sClient.Delete(context.Background(), pod); err != nil { @@ -1071,8 +1071,9 @@ func setUpHermeticServer(t *testing.T, podAndMetrics map[*backend.Pod]*backendme func fakePod(index int) *backend.Pod { return &backend.Pod{ - NamespacedName: types.NamespacedName{Name: fmt.Sprintf("pod-%v", index), Namespace: testNamespace}, + NamespacedName: types.NamespacedName{Name: fmt.Sprintf("pod-%v-rank-0", index), Namespace: testNamespace}, Address: fmt.Sprintf("192.168.1.%d", index+1), + PodName: fmt.Sprintf("pod-%v", index), Labels: make(map[string]string, 0), } } @@ -1153,7 +1154,7 @@ func BeforeSuite() func() { NamespacedName: types.NamespacedName{Namespace: testNamespace, Name: testPoolName}, GroupKind: schema.GroupKind{Group: v1.GroupVersion.Group, Kind: "InferencePool"}, } - serverRunner.Datastore = datastore.NewDatastore(context.Background(), pmf) + serverRunner.Datastore = datastore.NewDatastore(context.Background(), pmf, 0) kvCacheUtilizationScorer := scorer.NewKVCacheUtilizationScorer() queueingScorer := scorer.NewQueueScorer() diff --git a/test/integration/util.go b/test/integration/util.go index d78b76e28..1ad631303 100644 --- a/test/integration/util.go +++ b/test/integration/util.go @@ -20,6 +20,7 @@ import ( "encoding/json" "io" "strconv" + "sync/atomic" "testing" "time" @@ -66,14 +67,14 @@ func StreamedRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessCli // Make an incredible simple timeout func in the case where // there is less than the expected amount of responses; bail and fail. - var simpleTimeout bool + var simpleTimeout atomic.Bool go func() { time.Sleep(10 * time.Second) - simpleTimeout = true + simpleTimeout.Store(true) }() for range expectedResponses { - if simpleTimeout { + if simpleTimeout.Load() { break } res, err := client.Recv() diff --git a/test/utils/server.go b/test/utils/server.go index 51eb33fa0..9cf907d29 100644 --- a/test/utils/server.go +++ b/test/utils/server.go @@ -50,7 +50,7 @@ func PrepareForTestStreamingServer(objectives []*v1alpha2.InferenceObjective, po pmc := &metrics.FakePodMetricsClient{} pmf := metrics.NewPodMetricsFactory(pmc, time.Second) - ds := datastore.NewDatastore(ctx, pmf) + ds := datastore.NewDatastore(ctx, pmf, 0) initObjs := []client.Object{} for _, objective := range objectives {