Skip to content
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
5945cfb
partial draft
capri-xiyue Oct 31, 2025
370f1a3
refactor
capri-xiyue Nov 2, 2025
85e4622
fixed some ut
capri-xiyue Nov 6, 2025
579b17e
make epp controller ut pass
capri-xiyue Nov 6, 2025
e9d704d
make ut pass
capri-xiyue Nov 6, 2025
42cb218
fixed build
capri-xiyue Nov 7, 2025
943c53d
fixed build
capri-xiyue Nov 7, 2025
1e38821
fixed build failure
capri-xiyue Nov 10, 2025
30fd667
fixed lint
capri-xiyue Nov 10, 2025
9e85377
fix format
capri-xiyue Nov 10, 2025
43c87fa
merge conflicts
capri-xiyue Nov 14, 2025
0267569
fixed import format
capri-xiyue Nov 14, 2025
200dbf4
rename and refactor
capri-xiyue Nov 17, 2025
9d1514a
added epp name in env
capri-xiyue Nov 17, 2025
ba89d24
rename to endpointPool
capri-xiyue Nov 18, 2025
568e9ee
merge conflict
capri-xiyue Nov 18, 2025
ba90213
refactor in ut
capri-xiyue Nov 18, 2025
bc6a4c6
fixed format
capri-xiyue Nov 18, 2025
93ab791
fixed format
capri-xiyue Nov 18, 2025
c3ac6f2
changed error message
capri-xiyue Nov 18, 2025
c4b8c32
changed error message
capri-xiyue Nov 18, 2025
513590d
debug
capri-xiyue Nov 18, 2025
84b2275
remove debug logging
capri-xiyue Nov 18, 2025
0bd692f
fixed format
capri-xiyue Nov 18, 2025
9b7656d
merge conflicts
capri-xiyue Nov 18, 2025
72bb590
fixed import
capri-xiyue Nov 18, 2025
71f9fe5
updated to use epp name instead of pod name
capri-xiyue Nov 18, 2025
770f96c
fixed compiler
capri-xiyue Nov 18, 2025
a96edb0
verify
capri-xiyue Nov 18, 2025
1f15673
don't set endpointpool in datastore for inferencepool at start
capri-xiyue Nov 18, 2025
d079423
rename endpoints to endpointsmeta
capri-xiyue Nov 18, 2025
76b4eaa
rename import package
capri-xiyue Nov 18, 2025
2c67bf8
rename test utility
capri-xiyue Nov 18, 2025
7798ebd
added logging info
capri-xiyue Nov 19, 2025
1167364
resolve merge conflicts
capri-xiyue Nov 19, 2025
e33233b
change endpointpool struct
capri-xiyue Nov 20, 2025
86ac1f8
fixed variable naming
capri-xiyue Nov 20, 2025
11a8a68
fixed linter
capri-xiyue Nov 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 116 additions & 30 deletions cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"net/http/pprof"
"os"
"runtime"
"strconv"
"strings"
"sync/atomic"

"github.com/go-logr/logr"
Expand All @@ -44,6 +46,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/metrics/filters"
metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"

"k8s.io/apimachinery/pkg/labels"
"sigs.k8s.io/gateway-api-inference-extension/internal/runnable"
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
Expand Down Expand Up @@ -100,6 +103,8 @@ var (
poolName = flag.String("pool-name", runserver.DefaultPoolName, "Name of the InferencePool this Endpoint Picker is associated with.")
poolGroup = flag.String("pool-group", runserver.DefaultPoolGroup, "group of the InferencePool this Endpoint Picker is associated with.")
poolNamespace = flag.String("pool-namespace", "", "Namespace of the InferencePool this Endpoint Picker is associated with.")
endpointSelector = flag.String("endpoint-selector", "", "selector to filter model server pods on, only key value paris is supported. Format: a comma-separated list of key value paris, e.g., 'app:vllm-llama3-8b-instruct,env=prod'.")
endpointTargetPorts = flag.String("endpoint-target-ports", "", "target ports of model server pods. Format: a comma-separated list of numbers, e.g., '3000,3001,3002'")
logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity")
secureServing = flag.Bool("secure-serving", runserver.DefaultSecureServing, "Enables secure serving. Defaults to true.")
healthChecking = flag.Bool("health-checking", runserver.DefaultHealthChecking, "Enables health checking")
Expand Down Expand Up @@ -209,14 +214,20 @@ func (r *Runner) Run(ctx context.Context) error {
setupLog.Error(err, "Failed to get Kubernetes rest config")
return err
}
// Setup EndpointPool
endpointPool, err := setupEndpointPool(setupLog)
if err != nil {
setupLog.Error(err, "Failed to set up Endpoints Pool")
return err
}

// --- Setup Datastore ---
useDatalayerV2 := env.GetEnvBool(enableExperimentalDatalayerV2, false, setupLog)
epf, err := r.setupMetricsCollection(setupLog, useDatalayerV2)
if err != nil {
return err
}
datastore := datastore.NewDatastore(ctx, epf, int32(*modelServerMetricsPort))
datastore := datastore.NewDatastore(ctx, epf, int32(*modelServerMetricsPort), endpointPool)

// --- Setup Metrics Server ---
customCollectors := []prometheus.Collector{collectors.NewInferencePoolMetricsCollector(datastore)}
Expand All @@ -241,34 +252,10 @@ func (r *Runner) Run(ctx context.Context) error {
}(),
}

// Determine pool namespace: if --pool-namespace is non-empty, use it; else NAMESPACE env var; else default
resolvePoolNamespace := func() string {
if *poolNamespace != "" {
return *poolNamespace
}
if nsEnv := os.Getenv("NAMESPACE"); nsEnv != "" {
return nsEnv
}
return runserver.DefaultPoolNamespace
}
resolvedPoolNamespace := resolvePoolNamespace()
poolNamespacedName := types.NamespacedName{
Name: *poolName,
Namespace: resolvedPoolNamespace,
}
poolGroupKind := schema.GroupKind{
Group: *poolGroup,
Kind: "InferencePool",
}
poolGKNN := common.GKNN{
NamespacedName: poolNamespacedName,
GroupKind: poolGroupKind,
}

isLeader := &atomic.Bool{}
isLeader.Store(false)

mgr, err := runserver.NewDefaultManager(poolGKNN, cfg, metricsServerOptions, *haEnableLeaderElection)
mgr, err := runserver.NewDefaultManager(endpointPool, cfg, metricsServerOptions, *haEnableLeaderElection)
if err != nil {
setupLog.Error(err, "Failed to create controller manager")
return err
Expand Down Expand Up @@ -357,8 +344,7 @@ func (r *Runner) Run(ctx context.Context) error {
// --- Setup ExtProc Server Runner ---
serverRunner := &runserver.ExtProcServerRunner{
GrpcPort: *grpcPort,
PoolNamespacedName: poolNamespacedName,
PoolGKNN: poolGKNN,
EndpointPool: endpointPool,
Datastore: datastore,
SecureServing: *secureServing,
HealthChecking: *healthChecking,
Expand Down Expand Up @@ -396,6 +382,68 @@ func (r *Runner) Run(ctx context.Context) error {
return nil
}

func setupEndpointPool(setupLog logr.Logger) (*datalayer.EndpointPool, error) {
endpointPool := datalayer.NewEndpointPool(false, common.GKNN{})
if *poolName != "" {
// Determine pool namespace: if --pool-namespace is non-empty, use it; else NAMESPACE env var; else default
resolvePoolNamespace := func() string {
if *poolNamespace != "" {
return *poolNamespace
}
if nsEnv := os.Getenv("NAMESPACE"); nsEnv != "" {
return nsEnv
}
return runserver.DefaultPoolNamespace
}
resolvedPoolNamespace := resolvePoolNamespace()
poolNamespacedName := types.NamespacedName{
Name: *poolName,
Namespace: resolvedPoolNamespace,
}
poolGroupKind := schema.GroupKind{
Group: *poolGroup,
Kind: "InferencePool",
}
poolGKNN := common.GKNN{
NamespacedName: poolNamespacedName,
GroupKind: poolGroupKind,
}
endpointPool.GKNN = poolGKNN
}

if *endpointSelector != "" {
labelsMap, err := labels.ConvertSelectorToLabelsMap(*endpointSelector)
if err != nil {
setupLog.Error(err, "Failed to parse flag %q with error: %w", "endpoint-selector", err)
return nil, err
}
endpointPool.EndPoints.Selector = labelsMap
endpointPool.EndPoints.TargetPorts, err = strToUniqueIntSlice(*endpointTargetPorts)
if err != nil {
setupLog.Error(err, "Failed to parse flag %q with error: %w", "endpoint-target-ports", err)
}
endpointPool.DisableK8sCrd = true

// Determine EPP namespace: NAMESPACE env var; else default
eppNsEnv := os.Getenv("NAMESPACE")
if eppNsEnv == "" {
setupLog.Error(err, "Failed to get environment variable EPP_NAMESPACE")
}
// Determine EPP name: EPP_NAME env var
eppNameEnv := os.Getenv("EPP_NAME")
if eppNameEnv == "" {
setupLog.Error(err, "Failed to get environment variable EPP_NAME")

}
endpointPool.GKNN = common.GKNN{
NamespacedName: types.NamespacedName{Namespace: eppNsEnv, Name: eppNameEnv},
GroupKind: schema.GroupKind{Kind: "apps", Group: "Deployment"},
}

}
return endpointPool, nil
}

// registerInTreePlugins registers the factory functions of all known plugins
func (r *Runner) registerInTreePlugins() {
plugins.Register(prefix.PrefixCachePluginType, prefix.PrefixCachePluginFactory)
Expand Down Expand Up @@ -575,9 +623,19 @@ func registerHealthServer(mgr manager.Manager, logger logr.Logger, ds datastore.
}

func validateFlags() error {
if *poolName == "" {
return fmt.Errorf("required %q flag not set", "poolName")
if (*poolName != "" && *endpointSelector != "") || (*poolName == "" && *endpointSelector == "") {
return errors.New("either poolName or endpoint-selector must be set")
}
if *endpointSelector != "" {
targetPortsList, err := strToUniqueIntSlice(*endpointTargetPorts)
if err != nil {
return fmt.Errorf("unexpected value for %q flag with error %w", "endpoint-target-ports", err)
}
if len(targetPortsList) == 0 || len(targetPortsList) > 8 {
return fmt.Errorf("flag %q should have length from 1 to 8", "endpoint-target-ports")
}
}

if *configText != "" && *configFile != "" {
return fmt.Errorf("both the %q and %q flags can not be set at the same time", "configText", "configFile")
}
Expand All @@ -588,6 +646,34 @@ func validateFlags() error {
return nil
}

func strToUniqueIntSlice(s string) ([]int, error) {
seen := make(map[int]struct{})
var intList []int

if s == "" {
return intList, nil
}

strList := strings.Split(s, ",")

for _, str := range strList {
trimmedStr := strings.TrimSpace(str)
if trimmedStr == "" {
continue
}
portInt, err := strconv.Atoi(trimmedStr)
if err != nil {
return nil, fmt.Errorf("invalid number: '%s' is not an integer", trimmedStr)
}

if _, ok := seen[portInt]; !ok {
seen[portInt] = struct{}{}
intList = append(intList, portInt)
}
}
return intList, nil
}

func verifyMetricMapping(mapping backendmetrics.MetricMapping, logger logr.Logger) {
if mapping.TotalQueuedRequests == nil {
logger.Info("Not scraping metric: TotalQueuedRequests")
Expand Down
4 changes: 4 additions & 0 deletions config/charts/inferencepool/templates/epp-deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ spec:
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: EPP_NAME
valueFrom:
fieldRef:
fieldPath: metadata.name
{{- if .Values.inferenceExtension.tracing.enabled }}
- name: OTEL_SERVICE_NAME
value: "gateway-api-inference-extension"
Expand Down
6 changes: 3 additions & 3 deletions pkg/epp/backend/metrics/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func refreshPrometheusMetrics(logger logr.Logger, datastore datalayer.PoolInfo,
}

podTotalCount := len(podMetrics)
metrics.RecordInferencePoolAvgKVCache(pool.Name, kvCacheTotal/float64(podTotalCount))
metrics.RecordInferencePoolAvgQueueSize(pool.Name, float64(queueTotal/podTotalCount))
metrics.RecordInferencePoolReadyPods(pool.Name, float64(podTotalCount))
metrics.RecordInferencePoolAvgKVCache(pool.GKNN.Name, kvCacheTotal/float64(podTotalCount))
metrics.RecordInferencePoolAvgQueueSize(pool.GKNN.Name, float64(queueTotal/podTotalCount))
metrics.RecordInferencePoolReadyPods(pool.GKNN.Name, float64(podTotalCount))
}
5 changes: 2 additions & 3 deletions pkg/epp/backend/metrics/pod_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/types"

v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)

Expand Down Expand Up @@ -86,8 +85,8 @@ func TestMetricsRefresh(t *testing.T) {

type fakeDataStore struct{}

func (f *fakeDataStore) PoolGet() (*v1.InferencePool, error) {
return &v1.InferencePool{Spec: v1.InferencePoolSpec{TargetPorts: []v1.Port{{Number: 8000}}}}, nil
func (f *fakeDataStore) PoolGet() (*datalayer.EndpointPool, error) {
return &datalayer.EndpointPool{}, nil
}

func (f *fakeDataStore) PodList(func(PodMetrics) bool) []PodMetrics {
Expand Down
34 changes: 18 additions & 16 deletions pkg/epp/controller/inferenceobjective_reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,24 @@ import (
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"

v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool"
utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
)

var (
pool = utiltest.MakeInferencePool("test-pool1").Namespace("ns1").ObjRef()
inferencePool = utiltest.MakeInferencePool("test-pool1").Namespace("ns1").ObjRef()
infObjective1 = utiltest.MakeInferenceObjective("model1").
Namespace(pool.Namespace).
Namespace(inferencePool.Namespace).
Priority(1).
CreationTimestamp(metav1.Unix(1000, 0)).
PoolName(pool.Name).
PoolName(inferencePool.Name).
PoolGroup("inference.networking.k8s.io").ObjRef()
infObjective1Pool2 = utiltest.MakeInferenceObjective(infObjective1.Name).
Namespace(infObjective1.Namespace).
Expand All @@ -57,24 +59,24 @@ var (
Namespace(infObjective1.Namespace).
Priority(2).
CreationTimestamp(metav1.Unix(1003, 0)).
PoolName(pool.Name).
PoolName(inferencePool.Name).
PoolGroup("inference.networking.k8s.io").ObjRef()
infObjective1Deleted = utiltest.MakeInferenceObjective(infObjective1.Name).
Namespace(infObjective1.Namespace).
CreationTimestamp(metav1.Unix(1004, 0)).
DeletionTimestamp().
PoolName(pool.Name).
PoolName(inferencePool.Name).
PoolGroup("inference.networking.k8s.io").ObjRef()
infObjective1DiffGroup = utiltest.MakeInferenceObjective(infObjective1.Name).
Namespace(pool.Namespace).
Namespace(inferencePool.Namespace).
Priority(1).
CreationTimestamp(metav1.Unix(1005, 0)).
PoolName(pool.Name).
PoolName(inferencePool.Name).
PoolGroup("inference.networking.x-k8s.io").ObjRef()
infObjective2 = utiltest.MakeInferenceObjective("model2").
Namespace(pool.Namespace).
Namespace(inferencePool.Namespace).
CreationTimestamp(metav1.Unix(1000, 0)).
PoolName(pool.Name).
PoolName(inferencePool.Name).
PoolGroup("inference.networking.k8s.io").ObjRef()
)

Expand Down Expand Up @@ -120,7 +122,7 @@ func TestInferenceObjectiveReconciler(t *testing.T) {
{
name: "Objective not found, no matching existing objective to delete",
objectivessInStore: []*v1alpha2.InferenceObjective{infObjective1},
incomingReq: &types.NamespacedName{Name: "non-existent-objective", Namespace: pool.Namespace},
incomingReq: &types.NamespacedName{Name: "non-existent-objective", Namespace: inferencePool.Namespace},
wantObjectives: []*v1alpha2.InferenceObjective{infObjective1},
},
{
Expand Down Expand Up @@ -160,17 +162,18 @@ func TestInferenceObjectiveReconciler(t *testing.T) {
WithObjects(initObjs...).
Build()
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
ds := datastore.NewDatastore(t.Context(), pmf, 0)
ds := datastore.NewDatastore(t.Context(), pmf, 0, datalayer.NewEndpointPool(false, pool.ToGKNN(inferencePool)))
for _, m := range test.objectivessInStore {
ds.ObjectiveSet(m)
}
_ = ds.PoolSet(context.Background(), fakeClient, pool)
endpointPool := pool.InferencePoolToEndpointPool(inferencePool)
_ = ds.PoolSet(context.Background(), fakeClient, endpointPool)
reconciler := &InferenceObjectiveReconciler{
Reader: fakeClient,
Datastore: ds,
PoolGKNN: common.GKNN{
NamespacedName: types.NamespacedName{Name: pool.Name, Namespace: pool.Namespace},
GroupKind: schema.GroupKind{Group: pool.GroupVersionKind().Group, Kind: pool.GroupVersionKind().Kind},
NamespacedName: types.NamespacedName{Name: inferencePool.Name, Namespace: inferencePool.Namespace},
GroupKind: schema.GroupKind{Group: inferencePool.GroupVersionKind().Group, Kind: inferencePool.GroupVersionKind().Kind},
},
}
if test.incomingReq == nil {
Expand All @@ -190,8 +193,7 @@ func TestInferenceObjectiveReconciler(t *testing.T) {
if len(test.wantObjectives) != len(ds.ObjectiveGetAll()) {
t.Errorf("Unexpected; want: %d, got:%d", len(test.wantObjectives), len(ds.ObjectiveGetAll()))
}

if diff := diffStore(ds, diffStoreParams{wantPool: pool, wantObjectives: test.wantObjectives}); diff != "" {
if diff := diffStore(ds, diffStoreParams{wantPool: endpointPool, wantObjectives: test.wantObjectives}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}

Expand Down
Loading