Skip to content

Commit b2f9cb8

Browse files
Add latency predictor plugins, deployment, and runner.go integration
1 parent 3836d3b commit b2f9cb8

File tree

12 files changed

+769
-23
lines changed

12 files changed

+769
-23
lines changed

cmd/epp/runner/runner.go

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package runner
1919
import (
2020
"context"
2121
"crypto/tls"
22+
"encoding/json"
2223
"errors"
2324
"flag"
2425
"fmt"
@@ -64,13 +65,15 @@ import (
6465
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector"
6566
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
6667
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
68+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router"
6769
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
6870
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile"
6971
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/scorer"
7072
testfilter "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/test/filter"
7173
runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server"
7274
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
7375
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
76+
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
7477
"sigs.k8s.io/gateway-api-inference-extension/version"
7578
)
7679

@@ -119,6 +122,7 @@ var (
119122
"then a self-signed certificate is used.")
120123
// metric flags
121124
totalQueuedRequestsMetric = flag.String("total-queued-requests-metric", runserver.DefaultTotalQueuedRequestsMetric, "Prometheus metric for the number of queued requests.")
125+
totalRunningRequestsMetric = flag.String("total-running-requests-metric", runserver.DefaultTotalRunningRequestsMetric, "Prometheus metric for the number of running requests.")
122126
kvCacheUsagePercentageMetric = flag.String("kv-cache-usage-percentage-metric", runserver.DefaultKvCacheUsagePercentageMetric, "Prometheus metric for the fraction of KV-cache blocks currently in use (from 0 to 1).")
123127
// LoRA metrics
124128
loraInfoMetric = flag.String("lora-info-metric", runserver.DefaultLoraInfoMetric, "Prometheus metric for the LoRA info metrics (must be in vLLM label format).")
@@ -138,7 +142,10 @@ var (
138142
modelServerMetricsScheme = flag.String("model-server-metrics-scheme", "http", "Scheme to scrape metrics from pods")
139143
modelServerMetricsHttpsInsecureSkipVerify = flag.Bool("model-server-metrics-https-insecure-skip-verify", true, "When using 'https' scheme for 'model-server-metrics-scheme', configure 'InsecureSkipVerify' (default to true)")
140144
haEnableLeaderElection = flag.Bool("ha-enable-leader-election", false, "Enables leader election for high availability. When enabled, readiness probes will only pass on the leader.")
141-
tracing = flag.Bool("tracing", true, "Enables emitting traces")
145+
146+
// Latency Predictor Flag
147+
enableLatencyPredictor = flag.Bool("enable-latency-predictor", false, "Enable the regression-based latency predictor and scheduler scorer.")
148+
tracing = flag.Bool("tracing", true, "Enables emitting traces")
142149

143150
setupLog = ctrl.Log.WithName("setup")
144151
)
@@ -315,6 +322,32 @@ func (r *Runner) Run(ctx context.Context) error {
315322
runtime.SetBlockProfileRate(1)
316323
}
317324

325+
// ===================================================================
326+
// == Latency Predictor Integration
327+
// ===================================================================
328+
var predictor latencypredictor.PredictorInterface // Use the interface type
329+
if *enableLatencyPredictor {
330+
setupLog.Info("Latency predictor is enabled. Initializing...")
331+
predictor = latencypredictor.New(latencypredictor.ConfigFromEnv(), ctrl.Log.WithName("latency-predictor"))
332+
333+
// The runnable requires the concrete *latencypredictor.Predictor, so type-assert the interface back to it.
334+
concretePredictor := predictor.(*latencypredictor.Predictor)
335+
if err := mgr.Add(runnable.NoLeaderElection(&predictorRunnable{predictor: concretePredictor})); err != nil {
336+
setupLog.Error(err, "Failed to register latency predictor runnable")
337+
return err
338+
}
339+
} else {
340+
setupLog.Info("Latency predictor is disabled.")
341+
predictor = nil // This will be a true nil interface
342+
}
343+
// ===================================================================
344+
345+
err = r.parsePluginsConfiguration(ctx, predictor, datastore)
346+
if err != nil {
347+
setupLog.Error(err, "Failed to parse the configuration")
348+
return err
349+
}
350+
318351
// --- Initialize Core EPP Components ---
319352
if r.schedulerConfig == nil {
320353
err := errors.New("scheduler config must be set either by config api or through code")
@@ -369,6 +402,7 @@ func (r *Runner) Run(ctx context.Context) error {
369402
Director: director,
370403
SaturationDetector: saturationDetector,
371404
UseExperimentalDatalayerV2: r.featureGates[datalayer.FeatureGate], // pluggable data layer feature flag
405+
LatencyPredictor: predictor,
372406
}
373407
if err := serverRunner.SetupWithManager(ctx, mgr); err != nil {
374408
setupLog.Error(err, "Failed to setup EPP controllers")
@@ -413,6 +447,13 @@ func (r *Runner) registerInTreePlugins() {
413447
plugins.Register(testresponsereceived.DestinationEndpointServedVerifierType, testresponsereceived.DestinationEndpointServedVerifierFactory)
414448
}
415449

450+
func (r *Runner) registerLatencyPredictorPlugins(predictor latencypredictor.PredictorInterface) {
451+
plugins.Register(slo_aware_router.SLOAwareRouterPluginType, func(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
452+
return slo_aware_router.NewSLOAwareRouter(predictor, slo_aware_router.HeadroomSelectionStrategy).WithName(name), nil
453+
})
454+
plugins.Register(profile.SLOAwareProfileHandlerType, profile.SLOAwareProfileHandlerFactory)
455+
}
456+
416457
func (r *Runner) parseConfigurationPhaseOne(ctx context.Context) (*configapi.EndpointPickerConfig, error) {
417458
if *configText == "" && *configFile == "" {
418459
return nil, nil // configuring through code, not through file
@@ -435,6 +476,12 @@ func (r *Runner) parseConfigurationPhaseOne(ctx context.Context) (*configapi.End
435476
loader.RegisterFeatureGate(flowcontrol.FeatureGate)
436477

437478
r.registerInTreePlugins()
479+
// If the latency predictor is enabled and the predictor is non-nil,
480+
// register the latency predictor plugins (currently just the SLO scorer).
481+
if *enableLatencyPredictor && predictor != nil {
482+
setupLog.Info("Registering latency predictor plugins")
483+
r.registerLatencyPredictorPlugins(predictor)
484+
}
438485

439486
rawConfig, featureGates, err := loader.LoadConfigPhaseOne(configBytes, logger)
440487
if err != nil {
@@ -519,6 +566,7 @@ func (r *Runner) setupMetricsCollection(setupLog logr.Logger, useExperimentalDat
519566
func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) {
520567
mapping, err := backendmetrics.NewMetricMapping(
521568
*totalQueuedRequestsMetric,
569+
*totalRunningRequestsMetric,
522570
*kvCacheUsagePercentageMetric,
523571
*loraInfoMetric,
524572
*cacheInfoMetric,
@@ -567,6 +615,7 @@ func setupDatalayer(logger logr.Logger) (datalayer.EndpointFactory, error) {
567615
*modelServerMetricsHttpsInsecureSkipVerify,
568616
nil)
569617
extractor, err := dlmetrics.NewExtractor(*totalQueuedRequestsMetric,
618+
*totalRunningRequestsMetric,
570619
*kvCacheUsagePercentageMetric,
571620
*loraInfoMetric, *cacheInfoMetric)
572621

@@ -683,3 +732,21 @@ func setupPprofHandlers(mgr ctrl.Manager) error {
683732
}
684733
return nil
685734
}
735+
736+
// ===================================================================
737+
// == Latency Predictor Plugin and Helpers
738+
// ===================================================================
739+
740+
// predictorRunnable implements controller-runtime's Runnable interface to manage the predictor's lifecycle.
741+
type predictorRunnable struct {
742+
predictor *latencypredictor.Predictor
743+
}
744+
745+
func (p *predictorRunnable) Start(ctx context.Context) error {
746+
setupLog.Info("Starting latency predictor...")
747+
p.predictor.Start(ctx)
748+
<-ctx.Done()
749+
setupLog.Info("Stopping latency predictor...")
750+
p.predictor.Stop()
751+
return nil
752+
}

0 commit comments

Comments
 (0)