Skip to content
4 changes: 1 addition & 3 deletions cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (r *Runner) Run(ctx context.Context) error {
if err != nil {
return err
}
datastore := datastore.NewDatastore(ctx, epf)
datastore := datastore.NewDatastore(ctx, epf, int32(*modelServerMetricsPort))

// --- Setup Metrics Server ---
customCollectors := []prometheus.Collector{collectors.NewInferencePoolMetricsCollector(datastore)}
Expand Down Expand Up @@ -393,7 +393,6 @@ func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) {

pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.PodMetricsClientImpl{
MetricMapping: mapping,
ModelServerMetricsPort: int32(*modelServerMetricsPort),
ModelServerMetricsPath: *modelServerMetricsPath,
ModelServerMetricsScheme: *modelServerMetricsScheme,
Client: metricsHttpClient,
Expand All @@ -408,7 +407,6 @@ func setupDatalayer() (datalayer.EndpointFactory, error) {
// this (and registering the sources with the endpoint factory) should
// be moved accordingly.
source := dlmetrics.NewDataSource(*modelServerMetricsScheme,
int32(*modelServerMetricsPort), // start with (optional) command line port value
*modelServerMetricsPath,
*modelServerMetricsHttpsInsecureSkipVerify,
nil)
Expand Down
7 changes: 3 additions & 4 deletions pkg/epp/backend/metrics/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"sync"
"time"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/log"

Expand All @@ -49,8 +48,8 @@ func (fpm *FakePodMetrics) GetMetrics() *MetricsState {
return fpm.Metrics
}

func (fpm *FakePodMetrics) UpdatePod(pod *corev1.Pod) {
fpm.Pod = toInternalPod(pod)
// UpdatePod replaces the pod held by the fake with the given pod info.
// The pointer is stored as-is; no copy is made.
func (fpm *FakePodMetrics) UpdatePod(pod *datalayer.PodInfo) {
fpm.Pod = pod
}

// Put is a no-op on the fake: both arguments are ignored and nothing is stored.
func (*FakePodMetrics) Put(string, datalayer.Cloneable) {}
Expand All @@ -69,7 +68,7 @@ type FakePodMetricsClient struct {
Res map[types.NamespacedName]*MetricsState
}

func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, _ int32) (*MetricsState, error) {
func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) {
f.errMu.RLock()
err, ok := f.Err[pod.NamespacedName]
f.errMu.RUnlock()
Expand Down
12 changes: 4 additions & 8 deletions pkg/epp/backend/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,15 @@ const (

type PodMetricsClientImpl struct {
MetricMapping *MetricMapping
ModelServerMetricsPort int32
ModelServerMetricsPath string
ModelServerMetricsScheme string

Client *http.Client
}

// FetchMetrics fetches metrics from a given pod, clones the existing metrics object and returns an updated one.
func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, port int32) (*MetricsState, error) {
url := p.getMetricEndpoint(pod, port)
func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) {
url := p.getMetricEndpoint(pod)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
Expand All @@ -74,11 +73,8 @@ func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Po
return p.promToPodMetrics(metricFamilies, existing)
}

func (p *PodMetricsClientImpl) getMetricEndpoint(pod *backend.Pod, targetPortNumber int32) string {
if p.ModelServerMetricsPort == 0 {
p.ModelServerMetricsPort = targetPortNumber
}
Comment on lines -78 to -80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that GetMetricsPort() does not implement this default behavior, which makes sense now that targetPortNumber is a list. We need to note this as a breaking change in the PR description.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no breaking change. See the code in pkg/epp/datastore/datastore.go, lines 242-244. If there is only one targetPort in the InferencePool and the ModelServerMetricsPort from the command line is not zero, the command-line value is used to fill the metricsPort in the PodInfo struct. The function GetMetricsPort() simply returns what was placed in the struct earlier.

return fmt.Sprintf("%s://%s:%d%s", p.ModelServerMetricsScheme, pod.Address, p.ModelServerMetricsPort, p.ModelServerMetricsPath)
// getMetricEndpoint builds the metrics URL for the given pod in the form
// "<scheme>://<ip>:<port><path>". The scheme and path come from the client's
// configuration (ModelServerMetricsScheme, ModelServerMetricsPath); the IP
// address and metrics port are read from the pod itself.
func (p *PodMetricsClientImpl) getMetricEndpoint(pod *backend.Pod) string {
return fmt.Sprintf("%s://%s:%d%s", p.ModelServerMetricsScheme, pod.GetIPAddress(), pod.GetMetricsPort(), p.ModelServerMetricsPath)
}

// promToPodMetrics updates internal pod metrics with scraped Prometheus metrics.
Expand Down
7 changes: 4 additions & 3 deletions pkg/epp/backend/metrics/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,9 @@ func TestPromToPodMetrics(t *testing.T) {
func TestFetchMetrics(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())
pod := &backend.Pod{
Address: "127.0.0.1",
Address: "127.0.0.1",
Port: 9999,
MetricsPort: 9999,
NamespacedName: types.NamespacedName{
Namespace: "test",
Name: "pod",
Expand All @@ -499,12 +501,11 @@ func TestFetchMetrics(t *testing.T) {
// No MetricMapping needed for this basic test
p := &PodMetricsClientImpl{
ModelServerMetricsScheme: "http",
ModelServerMetricsPort: 9999,
ModelServerMetricsPath: "/metrics",
Client: http.DefaultClient,
}

_, err := p.FetchMetrics(ctx, pod, existing, 9999) // Use a port that's unlikely to be in use
_, err := p.FetchMetrics(ctx, pod, existing) // Use a port that's unlikely to be in use
if err == nil {
t.Errorf("FetchMetrics() expected error, got nil")
}
Expand Down
33 changes: 4 additions & 29 deletions pkg/epp/backend/metrics/pod_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ import (
"time"

"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
Expand All @@ -51,7 +49,7 @@ type podMetrics struct {
}

type PodMetricsClient interface {
FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, port int32) (*MetricsState, error)
FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error)
}

func (pm *podMetrics) String() string {
Expand All @@ -66,23 +64,8 @@ func (pm *podMetrics) GetMetrics() *MetricsState {
return pm.metrics.Load()
}

func (pm *podMetrics) UpdatePod(pod *corev1.Pod) {
pm.pod.Store(toInternalPod(pod))
}

func toInternalPod(pod *corev1.Pod) *backend.Pod {
labels := make(map[string]string, len(pod.GetLabels()))
for key, value := range pod.GetLabels() {
labels[key] = value
}
return &backend.Pod{
NamespacedName: types.NamespacedName{
Name: pod.Name,
Namespace: pod.Namespace,
},
Address: pod.Status.PodIP,
Labels: labels,
}
// UpdatePod records the given pod info as the current pod for this
// podMetrics by storing the pointer via pm.pod.Store; no conversion or copy
// is performed here.
func (pm *podMetrics) UpdatePod(pod *datalayer.PodInfo) {
pm.pod.Store(pod)
}

// start starts a goroutine exactly once to periodically update metrics. The goroutine will be
Expand Down Expand Up @@ -110,17 +93,9 @@ func (pm *podMetrics) startRefreshLoop(ctx context.Context) {
}

func (pm *podMetrics) refreshMetrics() error {
pool, err := pm.ds.PoolGet()
if err != nil {
// No inference pool or not initialize.
return err
}
ctx, cancel := context.WithTimeout(context.Background(), fetchMetricsTimeout)
defer cancel()
if len(pool.Spec.TargetPorts) != 1 {
return fmt.Errorf("expected 1 target port, got %d", len(pool.Spec.TargetPorts))
}
updated, err := pm.pmc.FetchMetrics(ctx, pm.GetPod(), pm.GetMetrics(), int32(pool.Spec.TargetPorts[0].Number))
updated, err := pm.pmc.FetchMetrics(ctx, pm.GetPod(), pm.GetMetrics())
if err != nil {
pm.logger.V(logutil.TRACE).Info("Failed to refreshed metrics:", "err", err)
}
Expand Down
17 changes: 8 additions & 9 deletions pkg/epp/backend/metrics/pod_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"

v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)

var (
pod1 = &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "pod1",
pod1Info = &datalayer.PodInfo{
NamespacedName: types.NamespacedName{
Name: "pod1-rank-0",
Namespace: "default",
},
PodName: "pod1",
}
initial = &MetricsState{
WaitingQueueSize: 0,
Expand Down Expand Up @@ -65,12 +65,11 @@ func TestMetricsRefresh(t *testing.T) {
pmf := NewPodMetricsFactory(pmc, time.Millisecond)

// The refresher is initialized with empty metrics.
pm := pmf.NewEndpoint(ctx, pod1, &fakeDataStore{})
pm := pmf.NewEndpoint(ctx, pod1Info, &fakeDataStore{})

namespacedName := types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace}
// Use SetRes to simulate an update of metrics from the pod.
// Verify that the metrics are updated.
pmc.SetRes(map[types.NamespacedName]*MetricsState{namespacedName: initial})
pmc.SetRes(map[types.NamespacedName]*MetricsState{pod1Info.NamespacedName: initial})
condition := func(collect *assert.CollectT) {
assert.True(collect, cmp.Equal(pm.GetMetrics(), initial, cmpopts.IgnoreFields(MetricsState{}, "UpdateTime")))
}
Expand All @@ -80,7 +79,7 @@ func TestMetricsRefresh(t *testing.T) {
// new update.
pmf.ReleaseEndpoint(pm)
time.Sleep(pmf.refreshMetricsInterval * 2 /* small buffer for robustness */)
pmc.SetRes(map[types.NamespacedName]*MetricsState{namespacedName: updated})
pmc.SetRes(map[types.NamespacedName]*MetricsState{pod1Info.NamespacedName: updated})
// Still expect the same condition (no metrics update).
assert.EventuallyWithT(t, condition, time.Second, time.Millisecond)
}
Expand Down
4 changes: 1 addition & 3 deletions pkg/epp/backend/metrics/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"sync"
"time"

corev1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/log"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
Expand Down Expand Up @@ -53,8 +52,7 @@ type PodMetricsFactory struct {
refreshMetricsInterval time.Duration
}

func (f *PodMetricsFactory) NewEndpoint(parentCtx context.Context, in *corev1.Pod, ds datalayer.PoolInfo) PodMetrics {
pod := toInternalPod(in)
func (f *PodMetricsFactory) NewEndpoint(parentCtx context.Context, pod *datalayer.PodInfo, ds datalayer.PoolInfo) PodMetrics {
pm := &podMetrics{
pmc: f.pmc,
ds: ds,
Expand Down
2 changes: 1 addition & 1 deletion pkg/epp/controller/inferenceobjective_reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func TestInferenceObjectiveReconciler(t *testing.T) {
WithObjects(initObjs...).
Build()
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
ds := datastore.NewDatastore(t.Context(), pmf)
ds := datastore.NewDatastore(t.Context(), pmf, 0)
for _, m := range test.objectivessInStore {
ds.ObjectiveSet(m)
}
Expand Down
24 changes: 12 additions & 12 deletions pkg/epp/controller/inferencepool_reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,14 @@ func TestInferencePoolReconciler(t *testing.T) {
ctx := context.Background()

pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
datastore := datastore.NewDatastore(ctx, pmf)
inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: datastore, PoolGKNN: gknn}
ds := datastore.NewDatastore(ctx, pmf, 0)
inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: ds, PoolGKNN: gknn}

// Step 1: Inception, only ready pods matching pool1 are added to the store.
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
if diff := diffStore(datastore, diffStoreParams{wantPool: pool1, wantPods: []string{"pod1", "pod2"}}); diff != "" {
if diff := diffStore(ds, diffStoreParams{wantPool: pool1, wantPods: []string{"pod1-rank-0", "pod2-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}

Expand All @@ -138,7 +138,7 @@ func TestInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
if diff := diffStore(datastore, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" {
if diff := diffStore(ds, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}

Expand All @@ -153,7 +153,7 @@ func TestInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
if diff := diffStore(datastore, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" {
if diff := diffStore(ds, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}

Expand All @@ -167,7 +167,7 @@ func TestInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
if diff := diffStore(datastore, diffStoreParams{wantPods: []string{}}); diff != "" {
if diff := diffStore(ds, diffStoreParams{wantPods: []string{}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}
}
Expand Down Expand Up @@ -258,14 +258,14 @@ func TestXInferencePoolReconciler(t *testing.T) {
ctx := context.Background()

pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
datastore := datastore.NewDatastore(ctx, pmf)
inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: datastore, PoolGKNN: gknn}
ds := datastore.NewDatastore(ctx, pmf, 0)
inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: ds, PoolGKNN: gknn}

// Step 1: Inception, only ready pods matching pool1 are added to the store.
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPool: pool1, wantPods: []string{"pod1", "pod2"}}); diff != "" {
if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: pool1, wantPods: []string{"pod1-rank-0", "pod2-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}

Expand All @@ -281,7 +281,7 @@ func TestXInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" {
if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}

Expand All @@ -296,7 +296,7 @@ func TestXInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" {
if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}

Expand All @@ -310,7 +310,7 @@ func TestXInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPods: []string{}}); diff != "" {
if diff := xDiffStore(t, ds, xDiffStoreParams{wantPods: []string{}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}
}
Expand Down
6 changes: 2 additions & 4 deletions pkg/epp/controller/pod_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/types"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/event"
Expand Down Expand Up @@ -53,7 +52,7 @@ func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R
pod := &corev1.Pod{}
if err := c.Get(ctx, req.NamespacedName, pod); err != nil {
if apierrors.IsNotFound(err) {
c.Datastore.PodDelete(req.NamespacedName)
c.Datastore.PodRemove(req.Name)
return ctrl.Result{}, nil
}
return ctrl.Result{}, fmt.Errorf("unable to get pod - %w", err)
Expand Down Expand Up @@ -90,10 +89,9 @@ func (c *PodReconciler) SetupWithManager(mgr ctrl.Manager) error {
}

func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod) {
namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace}
if !podutil.IsPodReady(pod) || !c.Datastore.PoolLabelsMatch(pod.Labels) {
logger.V(logutil.DEBUG).Info("Pod removed or not added")
c.Datastore.PodDelete(namespacedName)
c.Datastore.PodRemove(pod.Name)
} else {
if c.Datastore.PodUpdateOrAddIfNotExist(pod) {
logger.V(logutil.DEFAULT).Info("Pod added")
Expand Down
4 changes: 2 additions & 2 deletions pkg/epp/controller/pod_reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func TestPodReconciler(t *testing.T) {
Build()

// Configure the initial state of the datastore.
store := datastore.NewDatastore(t.Context(), pmf)
store := datastore.NewDatastore(t.Context(), pmf, 0)
_ = store.PoolSet(t.Context(), fakeClient, test.pool)
for _, pod := range test.existingPods {
store.PodUpdateOrAddIfNotExist(pod)
Expand All @@ -213,7 +213,7 @@ func TestPodReconciler(t *testing.T) {

var gotPods []*corev1.Pod
for _, pm := range store.PodList(backendmetrics.AllPodsPredicate) {
pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().NamespacedName.Name, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().Address}}
pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().PodName, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().GetIPAddress()}}
gotPods = append(gotPods, pod)
}
if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b *corev1.Pod) bool { return a.Name < b.Name })) {
Expand Down
Loading