Skip to content

Commit 61e3d02

Browse files
shmuelkBenjaminBraunDev
authored andcommitted
Support for vLLM Data parallel (kubernetes-sigs#1663)
* Removed global inference port from Prerequest extension API Signed-off-by: Shmuel Kallner <[email protected]> * Inference port and metrics port now per pod Signed-off-by: Shmuel Kallner <[email protected]> * Differentiate between real pod delete and virtual pod delete Signed-off-by: Shmuel Kallner <[email protected]> * Pass default metrics port to datastore Signed-off-by: Shmuel Kallner <[email protected]> * Updates to reflect newer APIs Signed-off-by: Shmuel Kallner <[email protected]> * Updates to tests Signed-off-by: Shmuel Kallner <[email protected]> * Fail tests that have errors, don't just log the errors Signed-off-by: Shmuel Kallner <[email protected]> * Remove tests that are no longer applicable Signed-off-by: Shmuel Kallner <[email protected]> * Set an InferencePool into the datastore Signed-off-by: Shmuel Kallner <[email protected]> * Added tests with multiple TargetPorts Signed-off-by: Shmuel Kallner <[email protected]> * Fix lint issues Signed-off-by: Shmuel Kallner <[email protected]> * Updated a new test due to updated interface Signed-off-by: Shmuel Kallner <[email protected]> * Store inference port and metrics host as strings Signed-off-by: Shmuel Kallner <[email protected]> * Concatenate metrics URL parts together without fmt.Sprintf Signed-off-by: Shmuel Kallner <[email protected]> * Use already stored metrics host Signed-off-by: Shmuel Kallner <[email protected]> * No need to convert inference port to a string Signed-off-by: Shmuel Kallner <[email protected]> * Updates due to PodInfo changes Signed-off-by: Shmuel Kallner <[email protected]> * Test updates due to PodInfo changes Signed-off-by: Shmuel Kallner <[email protected]> * Merged PodRemove into PodDelete Signed-off-by: Shmuel Kallner <[email protected]> * Changes due to merging of PodRemove into PodDelete Signed-off-by: Shmuel Kallner <[email protected]> * Test changes due to merging of PodRemove into PodDelete Signed-off-by: Shmuel Kallner <[email protected]> --------- Signed-off-by: Shmuel Kallner <[email protected]>
1 parent 2141040 commit 61e3d02

File tree

31 files changed

+527
-499
lines changed

31 files changed

+527
-499
lines changed

cmd/epp/runner/runner.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ func (r *Runner) Run(ctx context.Context) error {
208208
if err != nil {
209209
return err
210210
}
211-
datastore := datastore.NewDatastore(ctx, epf)
211+
datastore := datastore.NewDatastore(ctx, epf, int32(*modelServerMetricsPort))
212212

213213
// --- Setup Metrics Server ---
214214
customCollectors := []prometheus.Collector{collectors.NewInferencePoolMetricsCollector(datastore)}
@@ -514,7 +514,6 @@ func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) {
514514

515515
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.PodMetricsClientImpl{
516516
MetricMapping: mapping,
517-
ModelServerMetricsPort: int32(*modelServerMetricsPort),
518517
ModelServerMetricsPath: *modelServerMetricsPath,
519518
ModelServerMetricsScheme: *modelServerMetricsScheme,
520519
Client: metricsHttpClient,
@@ -529,7 +528,6 @@ func setupDatalayer() (datalayer.EndpointFactory, error) {
529528
// this (and registering the sources with the endpoint factory) should
530529
// be moved accordingly.
531530
source := dlmetrics.NewDataSource(*modelServerMetricsScheme,
532-
int32(*modelServerMetricsPort), // start with (optional) command line port value
533531
*modelServerMetricsPath,
534532
*modelServerMetricsHttpsInsecureSkipVerify,
535533
nil)

pkg/epp/backend/metrics/fake.go

Lines changed: 3 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"sync"
2323
"time"
2424

25-
corev1 "k8s.io/api/core/v1"
2625
"k8s.io/apimachinery/pkg/types"
2726
"sigs.k8s.io/controller-runtime/pkg/log"
2827

@@ -52,100 +51,8 @@ func (fpm *FakePodMetrics) GetMetrics() *MetricsState {
5251
return fpm.Metrics
5352
}
5453

55-
func (fpm *FakePodMetrics) UpdatePod(pod *corev1.Pod) {
56-
fpm.Pod = toInternalPod(pod, nil)
57-
}
58-
59-
func (f *FakePodMetrics) StopRefreshLoop() {
60-
f.mu.Lock()
61-
defer f.mu.Unlock()
62-
f.stopped = true
63-
}
64-
65-
func (f *FakePodMetrics) GetRunningRequests() *datalayer.RequestPriorityQueue {
66-
f.mu.RLock()
67-
defer f.mu.RUnlock()
68-
if f.stopped {
69-
return nil // Return nil for stopped pod metrics
70-
}
71-
return f.runningRequests
72-
}
73-
74-
func (f *FakePodMetrics) AddRequest(requestID string, tpot float64) bool {
75-
f.mu.RLock()
76-
defer f.mu.RUnlock()
77-
if f.stopped {
78-
return false // Reject operations after stopped
79-
}
80-
return f.runningRequests.Add(requestID, tpot)
81-
}
82-
83-
func (f *FakePodMetrics) RemoveRequest(requestID string) bool {
84-
f.mu.RLock()
85-
defer f.mu.RUnlock()
86-
if f.stopped {
87-
return false // Reject operations after stopped
88-
}
89-
_, success := f.runningRequests.Remove(requestID)
90-
return success
91-
}
92-
93-
func (f *FakePodMetrics) UpdateRequest(requestID string, tpot float64) bool {
94-
f.mu.RLock()
95-
defer f.mu.RUnlock()
96-
if f.stopped {
97-
return false // Reject operations after stopped
98-
}
99-
return f.runningRequests.Update(requestID, tpot)
100-
}
101-
102-
func (f *FakePodMetrics) GetRequestCount() int {
103-
f.mu.RLock()
104-
defer f.mu.RUnlock()
105-
if f.stopped {
106-
return 0 // Return 0 after stopped
107-
}
108-
return f.runningRequests.GetSize()
109-
}
110-
111-
func (f *FakePodMetrics) ContainsRequest(requestID string) bool {
112-
pod := f.GetPod()
113-
if pod == nil || pod.RunningRequests == nil {
114-
return false
115-
}
116-
return pod.RunningRequests.Contains(requestID)
117-
}
118-
119-
func (srv *FakePodMetrics) PeekRequestPriorityQueue() *datalayer.Request {
120-
pod := srv.GetPod()
121-
if pod == nil || pod.RunningRequests == nil {
122-
return nil
123-
}
124-
return pod.RunningRequests.Peek()
125-
}
126-
127-
func NewFakePodMetrics(k8sPod *corev1.Pod) *FakePodMetrics {
128-
labels := make(map[string]string)
129-
for k, v := range k8sPod.Labels {
130-
labels[k] = v
131-
}
132-
133-
pod := &backend.Pod{
134-
NamespacedName: types.NamespacedName{
135-
Name: k8sPod.Name,
136-
Namespace: k8sPod.Namespace,
137-
},
138-
Address: k8sPod.Status.PodIP,
139-
Labels: labels,
140-
RunningRequests: datalayer.NewRequestPriorityQueue(),
141-
}
142-
143-
return &FakePodMetrics{
144-
Pod: pod,
145-
Metrics: &MetricsState{UpdateTime: time.Now()},
146-
runningRequests: datalayer.NewRequestPriorityQueue(),
147-
stopped: false,
148-
}
54+
func (fpm *FakePodMetrics) UpdatePod(pod *datalayer.PodInfo) {
55+
fpm.Pod = pod
14956
}
15057

15158
func (*FakePodMetrics) Put(string, datalayer.Cloneable) {}
@@ -164,7 +71,7 @@ type FakePodMetricsClient struct {
16471
Res map[types.NamespacedName]*MetricsState
16572
}
16673

167-
func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, _ int32) (*MetricsState, error) {
74+
func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) {
16875
f.errMu.RLock()
16976
err, ok := f.Err[pod.NamespacedName]
17077
f.errMu.RUnlock()

pkg/epp/backend/metrics/metrics.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,15 @@ const (
4242

4343
type PodMetricsClientImpl struct {
4444
MetricMapping *MetricMapping
45-
ModelServerMetricsPort int32
4645
ModelServerMetricsPath string
4746
ModelServerMetricsScheme string
4847

4948
Client *http.Client
5049
}
5150

5251
// FetchMetrics fetches metrics from a given pod, clones the existing metrics object and returns an updated one.
53-
func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, port int32) (*MetricsState, error) {
54-
url := p.getMetricEndpoint(pod, port)
52+
func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) {
53+
url := p.getMetricEndpoint(pod)
5554
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
5655
if err != nil {
5756
return nil, fmt.Errorf("failed to create request: %v", err)
@@ -76,11 +75,8 @@ func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Po
7675
return p.promToPodMetrics(metricFamilies, existing)
7776
}
7877

79-
func (p *PodMetricsClientImpl) getMetricEndpoint(pod *backend.Pod, targetPortNumber int32) string {
80-
if p.ModelServerMetricsPort == 0 {
81-
p.ModelServerMetricsPort = targetPortNumber
82-
}
83-
return fmt.Sprintf("%s://%s:%d%s", p.ModelServerMetricsScheme, pod.Address, p.ModelServerMetricsPort, p.ModelServerMetricsPath)
78+
func (p *PodMetricsClientImpl) getMetricEndpoint(pod *backend.Pod) string {
79+
return p.ModelServerMetricsScheme + "://" + pod.GetMetricsHost() + p.ModelServerMetricsPath
8480
}
8581

8682
// promToPodMetrics updates internal pod metrics with scraped Prometheus metrics.

pkg/epp/backend/metrics/metrics_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,9 @@ func TestPromToPodMetrics(t *testing.T) {
489489
func TestFetchMetrics(t *testing.T) {
490490
ctx := logutil.NewTestLoggerIntoContext(context.Background())
491491
pod := &backend.Pod{
492-
Address: "127.0.0.1",
492+
Address: "127.0.0.1",
493+
Port: "9999",
494+
MetricsHost: "127.0.0.1:9999",
493495
NamespacedName: types.NamespacedName{
494496
Namespace: "test",
495497
Name: "pod",
@@ -499,12 +501,11 @@ func TestFetchMetrics(t *testing.T) {
499501
// No MetricMapping needed for this basic test
500502
p := &PodMetricsClientImpl{
501503
ModelServerMetricsScheme: "http",
502-
ModelServerMetricsPort: 9999,
503504
ModelServerMetricsPath: "/metrics",
504505
Client: http.DefaultClient,
505506
}
506507

507-
_, err := p.FetchMetrics(ctx, pod, existing, 9999) // Use a port that's unlikely to be in use
508+
_, err := p.FetchMetrics(ctx, pod, existing) // Use a port that's unlikely to be in use
508509
if err == nil {
509510
t.Errorf("FetchMetrics() expected error, got nil")
510511
}

pkg/epp/backend/metrics/pod_metrics.go

Lines changed: 4 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ import (
2424
"time"
2525

2626
"github.com/go-logr/logr"
27-
corev1 "k8s.io/api/core/v1"
28-
"k8s.io/apimachinery/pkg/types"
2927

3028
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
3129
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
@@ -51,7 +49,7 @@ type podMetrics struct {
5149
}
5250

5351
type PodMetricsClient interface {
54-
FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, port int32) (*MetricsState, error)
52+
FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error)
5553
}
5654

5755
func (pm *podMetrics) String() string {
@@ -66,98 +64,8 @@ func (pm *podMetrics) GetMetrics() *MetricsState {
6664
return pm.metrics.Load()
6765
}
6866

69-
// New methods for priority queue integration
70-
func (pm *podMetrics) GetRunningRequests() *datalayer.RequestPriorityQueue {
71-
pod := pm.GetPod()
72-
if pod == nil {
73-
return nil
74-
}
75-
return pod.RunningRequests
76-
}
77-
78-
func (pm *podMetrics) AddRequest(requestID string, tpot float64) bool {
79-
pod := pm.GetPod()
80-
if pod == nil || pod.RunningRequests == nil {
81-
return false
82-
}
83-
success := pod.RunningRequests.Add(requestID, tpot)
84-
// No need to update metrics since we removed ActualRunningRequests
85-
return success
86-
}
87-
88-
func (pm *podMetrics) RemoveRequest(requestID string) bool {
89-
pod := pm.GetPod()
90-
if pod == nil || pod.RunningRequests == nil {
91-
return false
92-
}
93-
_, success := pod.RunningRequests.Remove(requestID)
94-
// No need to update metrics since we removed ActualRunningRequests
95-
return success
96-
}
97-
98-
func (pm *podMetrics) UpdateRequest(requestID string, tpot float64) bool {
99-
pod := pm.GetPod()
100-
if pod == nil || pod.RunningRequests == nil {
101-
return false
102-
}
103-
return pod.RunningRequests.Update(requestID, tpot)
104-
}
105-
106-
func (pm *podMetrics) GetRequestCount() int {
107-
pod := pm.GetPod()
108-
if pod == nil || pod.RunningRequests == nil {
109-
return 0
110-
}
111-
return pod.RunningRequests.GetSize()
112-
}
113-
114-
func (pm *podMetrics) ContainsRequest(requestID string) bool {
115-
pod := pm.GetPod()
116-
if pod == nil || pod.RunningRequests == nil {
117-
return false
118-
}
119-
return pod.RunningRequests.Contains(requestID)
120-
}
121-
122-
func (pm *podMetrics) PeekRequestPriorityQueue() *datalayer.Request {
123-
pod := pm.GetPod()
124-
if pod == nil || pod.RunningRequests == nil {
125-
return nil
126-
}
127-
return pod.RunningRequests.Peek()
128-
}
129-
130-
func (pm *podMetrics) UpdatePod(k8sPod *corev1.Pod) {
131-
currentPod := pm.GetPod()
132-
updatedPod := toInternalPod(k8sPod, currentPod.GetRunningRequests())
133-
134-
// Preserve the existing running requests queue if it exists
135-
if currentPod != nil && currentPod.GetRunningRequests() != nil {
136-
updatedPod.RunningRequests = currentPod.GetRunningRequests()
137-
}
138-
139-
pm.pod.Store(updatedPod)
140-
}
141-
func toInternalPod(pod *corev1.Pod, existingQueue *datalayer.RequestPriorityQueue) *backend.Pod {
142-
labels := make(map[string]string, len(pod.GetLabels()))
143-
for key, value := range pod.GetLabels() {
144-
labels[key] = value
145-
}
146-
147-
queue := existingQueue
148-
if queue == nil {
149-
queue = datalayer.NewRequestPriorityQueue()
150-
}
151-
152-
return &backend.Pod{
153-
NamespacedName: types.NamespacedName{
154-
Name: pod.Name,
155-
Namespace: pod.Namespace,
156-
},
157-
Address: pod.Status.PodIP,
158-
Labels: labels,
159-
RunningRequests: queue,
160-
}
67+
func (pm *podMetrics) UpdatePod(pod *datalayer.PodInfo) {
68+
pm.pod.Store(pod)
16169
}
16270

16371
// start starts a goroutine exactly once to periodically update metrics. The goroutine will be
@@ -185,17 +93,9 @@ func (pm *podMetrics) startRefreshLoop(ctx context.Context) {
18593
}
18694

18795
func (pm *podMetrics) refreshMetrics() error {
188-
pool, err := pm.ds.PoolGet()
189-
if err != nil {
190-
// No inference pool or not initialize.
191-
return err
192-
}
19396
ctx, cancel := context.WithTimeout(context.Background(), fetchMetricsTimeout)
19497
defer cancel()
195-
if len(pool.Spec.TargetPorts) != 1 {
196-
return fmt.Errorf("expected 1 target port, got %d", len(pool.Spec.TargetPorts))
197-
}
198-
updated, err := pm.pmc.FetchMetrics(ctx, pm.GetPod(), pm.GetMetrics(), int32(pool.Spec.TargetPorts[0].Number))
98+
updated, err := pm.pmc.FetchMetrics(ctx, pm.GetPod(), pm.GetMetrics())
19999
if err != nil {
200100
pm.logger.V(logutil.TRACE).Info("Failed to refreshed metrics:", "err", err)
201101
}

pkg/epp/backend/metrics/pod_metrics_test.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,23 @@ import (
2525
"github.com/google/go-cmp/cmp"
2626
"github.com/google/go-cmp/cmp/cmpopts"
2727
"github.com/stretchr/testify/assert"
28-
corev1 "k8s.io/api/core/v1"
29-
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3028
"k8s.io/apimachinery/pkg/types"
3129

3230
v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
31+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3332
)
3433

3534
var (
36-
pod1 = &corev1.Pod{
37-
ObjectMeta: metav1.ObjectMeta{
38-
Name: "pod1",
35+
pod1Info = &datalayer.PodInfo{
36+
NamespacedName: types.NamespacedName{
37+
Name: "pod1-rank-0",
3938
Namespace: "default",
4039
Labels: map[string]string{"app": "test"},
4140
},
4241
Status: corev1.PodStatus{
4342
PodIP: "192.168.1.1",
4443
},
44+
PodName: "pod1",
4545
}
4646
initial = &MetricsState{
4747
WaitingQueueSize: 0,
@@ -71,12 +71,11 @@ func TestMetricsRefresh(t *testing.T) {
7171
pmf := NewPodMetricsFactory(pmc, time.Millisecond)
7272

7373
// The refresher is initialized with empty metrics.
74-
pm := pmf.NewEndpoint(ctx, pod1, &fakeDataStore{})
74+
pm := pmf.NewEndpoint(ctx, pod1Info, &fakeDataStore{})
7575

76-
namespacedName := types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace}
7776
// Use SetRes to simulate an update of metrics from the pod.
7877
// Verify that the metrics are updated.
79-
pmc.SetRes(map[types.NamespacedName]*MetricsState{namespacedName: initial})
78+
pmc.SetRes(map[types.NamespacedName]*MetricsState{pod1Info.NamespacedName: initial})
8079
condition := func(collect *assert.CollectT) {
8180
assert.True(collect, cmp.Equal(pm.GetMetrics(), initial, cmpopts.IgnoreFields(MetricsState{}, "UpdateTime")))
8281
}
@@ -86,7 +85,7 @@ func TestMetricsRefresh(t *testing.T) {
8685
// new update.
8786
pmf.ReleaseEndpoint(pm)
8887
time.Sleep(pmf.refreshMetricsInterval * 2 /* small buffer for robustness */)
89-
pmc.SetRes(map[types.NamespacedName]*MetricsState{namespacedName: updated})
88+
pmc.SetRes(map[types.NamespacedName]*MetricsState{pod1Info.NamespacedName: updated})
9089
// Still expect the same condition (no metrics update).
9190
assert.EventuallyWithT(t, condition, time.Second, time.Millisecond)
9291
}

0 commit comments

Comments
 (0)