Skip to content

Commit 3c8aba1

Browse files
authored
Enable EPP to support endpoint discovery using pod selector (#1833)
* partial draft * refactor Signed-off-by: Xiyue Yu <[email protected]> * fixed some ut * make epp controller ut pass * make ut pass * fixed build * fixed build * fixed build failure * fixed lint * fix format * fixed import format * rename and refactor Signed-off-by: Xiyue Yu <[email protected]> * added epp name in env * rename to endpointPool * refactor in ut * fixed format * fixed format * changed error message Signed-off-by: Xiyue Yu <[email protected]> * changed error message Signed-off-by: Xiyue Yu <[email protected]> * debug * remove debug logging * fixed format * fixed import * updated to use epp name instead of pod name * fixed compiler * verify * don't set endpointpool in datastore for inferencepool at start * rename endpoints to endpointsmeta * rename import package * rename test utility * added logging info * change endpointpool struct * fixed variable naming * fixed linter --------- Signed-off-by: Xiyue Yu <[email protected]>
1 parent 3836d3b commit 3c8aba1

File tree

26 files changed

+494
-236
lines changed

26 files changed

+494
-236
lines changed

cmd/epp/runner/runner.go

Lines changed: 152 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ import (
2525
"net/http"
2626
"net/http/pprof"
2727
"os"
28+
"regexp"
2829
"runtime"
30+
"strconv"
31+
"strings"
2932
"sync/atomic"
3033

3134
"github.com/go-logr/logr"
@@ -34,16 +37,18 @@ import (
3437
"go.uber.org/zap/zapcore"
3538
"google.golang.org/grpc"
3639
healthPb "google.golang.org/grpc/health/grpc_health_v1"
40+
"k8s.io/apimachinery/pkg/labels"
3741
"k8s.io/apimachinery/pkg/runtime/schema"
3842
"k8s.io/apimachinery/pkg/types"
43+
"k8s.io/apimachinery/pkg/util/sets"
3944
"k8s.io/client-go/rest"
45+
4046
ctrl "sigs.k8s.io/controller-runtime"
4147
"sigs.k8s.io/controller-runtime/pkg/log"
4248
"sigs.k8s.io/controller-runtime/pkg/log/zap"
4349
"sigs.k8s.io/controller-runtime/pkg/manager"
4450
"sigs.k8s.io/controller-runtime/pkg/metrics/filters"
4551
metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"
46-
4752
configapi "sigs.k8s.io/gateway-api-inference-extension/apix/config/v1alpha1"
4853
"sigs.k8s.io/gateway-api-inference-extension/internal/runnable"
4954
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
@@ -111,6 +116,8 @@ var (
111116
poolName = flag.String("pool-name", runserver.DefaultPoolName, "Name of the InferencePool this Endpoint Picker is associated with.")
112117
poolGroup = flag.String("pool-group", runserver.DefaultPoolGroup, "group of the InferencePool this Endpoint Picker is associated with.")
113118
poolNamespace = flag.String("pool-namespace", "", "Namespace of the InferencePool this Endpoint Picker is associated with.")
119+
endpointSelector = flag.String("endpoint-selector", "", "selector to filter model server pods on, only key=value paris is supported. Format: a comma-separated list of key value paris, e.g., 'app=vllm-llama3-8b-instruct,env=prod'.")
120+
endpointTargetPorts = flag.String("endpoint-target-ports", "", "target ports of model server pods. Format: a comma-separated list of numbers, e.g., '3000,3001,3002'")
114121
logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity")
115122
secureServing = flag.Bool("secure-serving", runserver.DefaultSecureServing, "Enables secure serving. Defaults to true.")
116123
healthChecking = flag.Bool("health-checking", runserver.DefaultHealthChecking, "Enables health checking")
@@ -231,16 +238,26 @@ func (r *Runner) Run(ctx context.Context) error {
231238
if err != nil {
232239
return err
233240
}
234-
datastore := datastore.NewDatastore(ctx, epf, int32(*modelServerMetricsPort))
235241

236-
eppConfig, err := r.parseConfigurationPhaseTwo(ctx, rawConfig, datastore)
242+
gknn, err := extractGKNN(*poolName, *poolGroup, *poolNamespace, *endpointSelector)
243+
if err != nil {
244+
setupLog.Error(err, "Failed to extract GKNN")
245+
return err
246+
}
247+
disableK8sCrdReconcile := *endpointSelector != ""
248+
ds, err := setupDatastore(setupLog, ctx, epf, int32(*modelServerMetricsPort), disableK8sCrdReconcile, *poolName, *poolNamespace, *endpointSelector, *endpointTargetPorts)
249+
if err != nil {
250+
setupLog.Error(err, "Failed to setup datastore")
251+
return err
252+
}
253+
eppConfig, err := r.parseConfigurationPhaseTwo(ctx, rawConfig, ds)
237254
if err != nil {
238255
setupLog.Error(err, "Failed to parse configuration")
239256
return err
240257
}
241258

242259
// --- Setup Metrics Server ---
243-
r.customCollectors = append(r.customCollectors, collectors.NewInferencePoolMetricsCollector(datastore))
260+
r.customCollectors = append(r.customCollectors, collectors.NewInferencePoolMetricsCollector(ds))
244261
metrics.Register(r.customCollectors...)
245262
metrics.RecordInferenceExtensionInfo(version.CommitSHA, version.BuildRef)
246263
// Register metrics handler.
@@ -259,34 +276,10 @@ func (r *Runner) Run(ctx context.Context) error {
259276
}(),
260277
}
261278

262-
// Determine pool namespace: if --pool-namespace is non-empty, use it; else NAMESPACE env var; else default
263-
resolvePoolNamespace := func() string {
264-
if *poolNamespace != "" {
265-
return *poolNamespace
266-
}
267-
if nsEnv := os.Getenv("NAMESPACE"); nsEnv != "" {
268-
return nsEnv
269-
}
270-
return runserver.DefaultPoolNamespace
271-
}
272-
resolvedPoolNamespace := resolvePoolNamespace()
273-
poolNamespacedName := types.NamespacedName{
274-
Name: *poolName,
275-
Namespace: resolvedPoolNamespace,
276-
}
277-
poolGroupKind := schema.GroupKind{
278-
Group: *poolGroup,
279-
Kind: "InferencePool",
280-
}
281-
poolGKNN := common.GKNN{
282-
NamespacedName: poolNamespacedName,
283-
GroupKind: poolGroupKind,
284-
}
285-
286279
isLeader := &atomic.Bool{}
287280
isLeader.Store(false)
288281

289-
mgr, err := runserver.NewDefaultManager(poolGKNN, cfg, metricsServerOptions, *haEnableLeaderElection)
282+
mgr, err := runserver.NewDefaultManager(disableK8sCrdReconcile, *gknn, cfg, metricsServerOptions, *haEnableLeaderElection)
290283
if err != nil {
291284
setupLog.Error(err, "Failed to create controller manager")
292285
return err
@@ -353,14 +346,18 @@ func (r *Runner) Run(ctx context.Context) error {
353346
admissionController = requestcontrol.NewLegacyAdmissionController(saturationDetector)
354347
}
355348

356-
director := requestcontrol.NewDirectorWithConfig(datastore, scheduler, admissionController, r.requestControlConfig)
349+
director := requestcontrol.NewDirectorWithConfig(
350+
ds,
351+
scheduler,
352+
admissionController,
353+
r.requestControlConfig)
357354

358355
// --- Setup ExtProc Server Runner ---
359356
serverRunner := &runserver.ExtProcServerRunner{
360357
GrpcPort: *grpcPort,
361-
PoolNamespacedName: poolNamespacedName,
362-
PoolGKNN: poolGKNN,
363-
Datastore: datastore,
358+
GKNN: *gknn,
359+
Datastore: ds,
360+
DisableK8sCrdReconcile: disableK8sCrdReconcile,
364361
SecureServing: *secureServing,
365362
HealthChecking: *healthChecking,
366363
CertPath: *certPath,
@@ -377,7 +374,7 @@ func (r *Runner) Run(ctx context.Context) error {
377374

378375
// --- Add Runnables to Manager ---
379376
// Register health server.
380-
if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), datastore, *grpcHealthPort, isLeader, *haEnableLeaderElection); err != nil {
377+
if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), ds, *grpcHealthPort, isLeader, *haEnableLeaderElection); err != nil {
381378
return err
382379
}
383380

@@ -397,6 +394,28 @@ func (r *Runner) Run(ctx context.Context) error {
397394
return nil
398395
}
399396

397+
func setupDatastore(setupLog logr.Logger, ctx context.Context, epFactory datalayer.EndpointFactory, modelServerMetricsPort int32, disableK8sCrdReconcile bool, namespace, name, endpointSelector, endpointTargetPorts string) (datastore.Datastore, error) {
398+
if !disableK8sCrdReconcile {
399+
return datastore.NewDatastore(ctx, epFactory, modelServerMetricsPort), nil
400+
} else {
401+
endpointPool := datalayer.NewEndpointPool(namespace, name)
402+
labelsMap, err := labels.ConvertSelectorToLabelsMap(endpointSelector)
403+
if err != nil {
404+
setupLog.Error(err, "Failed to parse flag %q with error: %w", "endpoint-selector", err)
405+
return nil, err
406+
}
407+
endpointPool.Selector = labelsMap
408+
endpointPool.TargetPorts, err = strToUniqueIntSlice(endpointTargetPorts)
409+
if err != nil {
410+
setupLog.Error(err, "Failed to parse flag %q with error: %w", "endpoint-target-ports", err)
411+
return nil, err
412+
}
413+
414+
endpointPoolOption := datastore.WithEndpointPool(endpointPool)
415+
return datastore.NewDatastore(ctx, epFactory, modelServerMetricsPort, endpointPoolOption), nil
416+
}
417+
}
418+
400419
// registerInTreePlugins registers the factory functions of all known plugins
401420
func (r *Runner) registerInTreePlugins() {
402421
plugins.Register(prefix.PrefixCachePluginType, prefix.PrefixCachePluginFactory)
@@ -635,9 +654,19 @@ func registerHealthServer(mgr manager.Manager, logger logr.Logger, ds datastore.
635654
}
636655

637656
func validateFlags() error {
638-
if *poolName == "" {
639-
return fmt.Errorf("required %q flag not set", "poolName")
657+
if (*poolName != "" && *endpointSelector != "") || (*poolName == "" && *endpointSelector == "") {
658+
return errors.New("either pool-name or endpoint-selector must be set")
640659
}
660+
if *endpointSelector != "" {
661+
targetPortsList, err := strToUniqueIntSlice(*endpointTargetPorts)
662+
if err != nil {
663+
return fmt.Errorf("unexpected value for %q flag with error %w", "endpoint-target-ports", err)
664+
}
665+
if len(targetPortsList) == 0 || len(targetPortsList) > 8 {
666+
return fmt.Errorf("flag %q should have length from 1 to 8", "endpoint-target-ports")
667+
}
668+
}
669+
641670
if *configText != "" && *configFile != "" {
642671
return fmt.Errorf("both the %q and %q flags can not be set at the same time", "configText", "configFile")
643672
}
@@ -648,6 +677,34 @@ func validateFlags() error {
648677
return nil
649678
}
650679

680+
func strToUniqueIntSlice(s string) ([]int, error) {
681+
seen := sets.NewInt()
682+
var intList []int
683+
684+
if s == "" {
685+
return intList, nil
686+
}
687+
688+
strList := strings.Split(s, ",")
689+
690+
for _, str := range strList {
691+
trimmedStr := strings.TrimSpace(str)
692+
if trimmedStr == "" {
693+
continue
694+
}
695+
portInt, err := strconv.Atoi(trimmedStr)
696+
if err != nil {
697+
return nil, fmt.Errorf("invalid number: '%s' is not an integer", trimmedStr)
698+
}
699+
700+
if _, ok := seen[portInt]; !ok {
701+
seen[portInt] = struct{}{}
702+
intList = append(intList, portInt)
703+
}
704+
}
705+
return intList, nil
706+
}
707+
651708
func verifyMetricMapping(mapping backendmetrics.MetricMapping, logger logr.Logger) {
652709
if mapping.TotalQueuedRequests == nil {
653710
logger.Info("Not scraping metric: TotalQueuedRequests")
@@ -683,3 +740,62 @@ func setupPprofHandlers(mgr ctrl.Manager) error {
683740
}
684741
return nil
685742
}
743+
744+
// deploymentPodNameRegex captures everything before the last two
// dash-separated segments of lowercase letters and digits in a pod name
// (presumably the ReplicaSet hash and pod suffix of a Deployment-managed
// pod — verify against how the pods are created). It is compiled once at
// package initialization rather than on every call.
var deploymentPodNameRegex = regexp.MustCompile(`^(.+)-[a-z0-9]+-[a-z0-9]+$`)

// extractDeploymentName derives the owning workload name from podName by
// stripping its last two dash-separated lowercase-alphanumeric segments,
// e.g. "my-epp-7d9f8c6b5-x2k4q" yields "my-epp".
//
// It returns an error when podName does not have that shape.
func extractDeploymentName(podName string) (string, error) {
	matches := deploymentPodNameRegex.FindStringSubmatch(podName)
	if len(matches) != 2 {
		return "", fmt.Errorf("failed to parse deployment name from pod name %s", podName)
	}
	return matches[1], nil
}
753+
754+
func extractGKNN(poolName, poolGroup, poolNamespace, endpointSelector string) (*common.GKNN, error) {
755+
if poolName != "" {
756+
// Determine pool namespace: if --pool-namespace is non-empty, use it; else NAMESPACE env var; else default
757+
resolvedPoolNamespace := resolvePoolNamespace(poolNamespace)
758+
poolNamespacedName := types.NamespacedName{
759+
Name: poolName,
760+
Namespace: resolvedPoolNamespace,
761+
}
762+
poolGroupKind := schema.GroupKind{
763+
Group: poolGroup,
764+
Kind: "InferencePool",
765+
}
766+
return &common.GKNN{
767+
NamespacedName: poolNamespacedName,
768+
GroupKind: poolGroupKind,
769+
}, nil
770+
}
771+
772+
if endpointSelector != "" {
773+
// Determine EPP namespace: NAMESPACE env var; else default
774+
resolvedPoolNamespace := resolvePoolNamespace(poolNamespace)
775+
// Determine EPP name: POD_NAME env var
776+
eppPodNameEnv := os.Getenv("POD_NAME")
777+
if eppPodNameEnv == "" {
778+
return nil, errors.New("failed to get environment variable POD_NAME")
779+
780+
}
781+
eppName, err := extractDeploymentName(eppPodNameEnv)
782+
if err != nil {
783+
return nil, err
784+
}
785+
return &common.GKNN{
786+
NamespacedName: types.NamespacedName{Namespace: resolvedPoolNamespace, Name: eppName},
787+
GroupKind: schema.GroupKind{Kind: "Deployment", Group: "apps"},
788+
}, nil
789+
}
790+
return nil, errors.New("can't construct gknn as both pool-name and endpoint-selector are missing")
791+
}
792+
793+
func resolvePoolNamespace(poolNamespace string) string {
794+
if poolNamespace != "" {
795+
return poolNamespace
796+
}
797+
if nsEnv := os.Getenv("NAMESPACE"); nsEnv != "" {
798+
return nsEnv
799+
}
800+
return runserver.DefaultPoolNamespace
801+
}

config/charts/inferencepool/templates/epp-deployment.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ spec:
147147
valueFrom:
148148
fieldRef:
149149
fieldPath: metadata.namespace
150+
- name: POD_NAME
151+
valueFrom:
152+
fieldRef:
153+
fieldPath: metadata.name
150154
{{- if .Values.inferenceExtension.tracing.enabled }}
151155
- name: OTEL_SERVICE_NAME
152156
value: "gateway-api-inference-extension"

pkg/epp/backend/metrics/pod_metrics_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
"github.com/stretchr/testify/assert"
2626
"k8s.io/apimachinery/pkg/types"
2727

28-
v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
2928
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3029
)
3130

@@ -86,8 +85,8 @@ func TestMetricsRefresh(t *testing.T) {
8685

8786
type fakeDataStore struct{}
8887

89-
func (f *fakeDataStore) PoolGet() (*v1.InferencePool, error) {
90-
return &v1.InferencePool{Spec: v1.InferencePoolSpec{TargetPorts: []v1.Port{{Number: 8000}}}}, nil
88+
func (f *fakeDataStore) PoolGet() (*datalayer.EndpointPool, error) {
89+
return &datalayer.EndpointPool{}, nil
9190
}
9291

9392
func (f *fakeDataStore) PodList(func(PodMetrics) bool) []PodMetrics {

pkg/epp/controller/inferenceobjective_reconciler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ import (
2121
"fmt"
2222

2323
"k8s.io/apimachinery/pkg/api/errors"
24+
2425
ctrl "sigs.k8s.io/controller-runtime"
2526
"sigs.k8s.io/controller-runtime/pkg/client"
2627
"sigs.k8s.io/controller-runtime/pkg/event"
2728
"sigs.k8s.io/controller-runtime/pkg/log"
2829
"sigs.k8s.io/controller-runtime/pkg/predicate"
29-
3030
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
3131
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
3232
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
@@ -55,7 +55,7 @@ func (c *InferenceObjectiveReconciler) Reconcile(ctx context.Context, req ctrl.R
5555
}
5656

5757
if notFound || !infObjective.DeletionTimestamp.IsZero() || infObjective.Spec.PoolRef.Name != v1alpha2.ObjectName(c.PoolGKNN.Name) || infObjective.Spec.PoolRef.Group != v1alpha2.Group(c.PoolGKNN.Group) {
58-
// InferenceObjective object got deleted or changed the referenced pool.
58+
// InferenceObjective object got deleted or changed the referenced inferencePool.
5959
c.Datastore.ObjectiveDelete(req.NamespacedName)
6060
return ctrl.Result{}, nil
6161
}

0 commit comments

Comments
 (0)