111 changes: 80 additions & 31 deletions cmd/epp/runner/runner.go
@@ -46,6 +46,8 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/config/loader"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
dlmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics/collectors"
@@ -245,40 +247,11 @@ func (r *Runner) Run(ctx context.Context) error {
}

// --- Setup Datastore ---
-	mapping, err := backendmetrics.NewMetricMapping(
-		*totalQueuedRequestsMetric,
-		*kvCacheUsagePercentageMetric,
-		*loraInfoMetric,
-	)
+	epf, err := r.setupMetricsCollection(setupLog)
	if err != nil {
-		setupLog.Error(err, "Failed to create metric mapping from flags.")
		return err
	}
-	verifyMetricMapping(*mapping, setupLog)

-	var metricsHttpClient *http.Client
-	if *modelServerMetricsScheme == "https" {
-		metricsHttpClient = &http.Client{
-			Transport: &http.Transport{
-				TLSClientConfig: &tls.Config{
-					InsecureSkipVerify: *modelServerMetricsHttpsInsecureSkipVerify,
-				},
-			},
-		}
-	} else {
-		metricsHttpClient = http.DefaultClient
-	}
-
-	pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.PodMetricsClientImpl{
-		MetricMapping:            mapping,
-		ModelServerMetricsPort:   int32(*modelServerMetricsPort),
-		ModelServerMetricsPath:   *modelServerMetricsPath,
-		ModelServerMetricsScheme: *modelServerMetricsScheme,
-		Client:                   metricsHttpClient,
-	},
-		*refreshMetricsInterval)
-
-	datastore := datastore.NewDatastore(ctx, pmf)
+	datastore := datastore.NewDatastore(ctx, epf)

// --- Setup Metrics Server ---
customCollectors := []prometheus.Collector{collectors.NewInferencePoolMetricsCollector(datastore)}
@@ -446,6 +419,82 @@ func (r *Runner) parsePluginsConfiguration(ctx context.Context) error {
return nil
}

func (r *Runner) setupMetricsCollection(setupLog logr.Logger) (datalayer.EndpointFactory, error) {
if datalayer.Enabled(setupLog) {
return setupDatalayer(setupLog)
}

if len(datalayer.GetSources()) != 0 {
setupLog.Info("data sources registered but pluggable datalayer is disabled")
}
return setupMetricsV1(setupLog)
}

func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) {
mapping, err := backendmetrics.NewMetricMapping(
*totalQueuedRequestsMetric,
*kvCacheUsagePercentageMetric,
*loraInfoMetric,
)
if err != nil {
setupLog.Error(err, "Failed to create metric mapping from flags.")
return nil, err
}
verifyMetricMapping(*mapping, setupLog)

var metricsHttpClient *http.Client
if *modelServerMetricsScheme == "https" {
metricsHttpClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: *modelServerMetricsHttpsInsecureSkipVerify,
},
},
}
} else {
metricsHttpClient = http.DefaultClient
}

pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.PodMetricsClientImpl{
MetricMapping: mapping,
ModelServerMetricsPort: int32(*modelServerMetricsPort),
ModelServerMetricsPath: *modelServerMetricsPath,
ModelServerMetricsScheme: *modelServerMetricsScheme,
Client: metricsHttpClient,
},
*refreshMetricsInterval)
return pmf, nil
}

func setupDatalayer(setupLog logr.Logger) (datalayer.EndpointFactory, error) {
	// Create and register a metrics data source and extractor. In the future,
	// data sources and extractors might be configured via a file; once that
	// lands, this code (and the registration of sources with the endpoint
	// factory) should move accordingly.
source := dlmetrics.NewDataSource(*modelServerMetricsScheme,
int32(*modelServerMetricsPort),
*modelServerMetricsPath,
*modelServerMetricsHttpsInsecureSkipVerify,
nil)
	extractor, err := dlmetrics.NewExtractor(*totalQueuedRequestsMetric,
		*kvCacheUsagePercentageMetric,
		*loraInfoMetric)
	if err != nil {
return nil, err
}
if err := source.AddExtractor(extractor); err != nil {
return nil, err
}
if err := datalayer.RegisterSource(source); err != nil {
return nil, err
}

factory := datalayer.NewEndpointFactory(setupLog, *refreshMetricsInterval)
factory.SetSources(datalayer.GetSources())
return factory, nil
}

func initLogging(opts *zap.Options) {
// Unless -zap-log-level is explicitly set, use -v
useV := true
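The net effect in runner.go: both construction paths now yield a datalayer.EndpointFactory, so the datastore no longer cares which collection stack is active. A standalone sketch of the gate (illustrative only; logr.Discard() stands in for the runner's setupLog):

package main

import (
	"fmt"

	"github.com/go-logr/logr"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)

func main() {
	setupLog := logr.Discard() // stands in for the runner's logger
	if datalayer.Enabled(setupLog) {
		fmt.Println("v2: setupDatalayer builds the pluggable factory")
	} else {
		fmt.Println("v1: setupMetricsV1 builds the PodMetricsFactory")
	}
}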
6 changes: 6 additions & 0 deletions pkg/epp/datalayer/collector.go
@@ -22,6 +22,7 @@ import (
"sync"
"time"

"github.com/go-logr/logr"
"sigs.k8s.io/controller-runtime/pkg/log"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
Expand Down Expand Up @@ -73,6 +74,8 @@ type Collector struct {
startOnce sync.Once
stopOnce sync.Once

logger logr.Logger

// TODO: optional metrics tracking collection (e.g., errors, invocations, ...)
}

@@ -82,11 +85,13 @@ func NewCollector() *Collector {
}

// Start initiates data source collection for the endpoint.
// TODO: pass PoolInfo for backward compatibility
func (c *Collector) Start(ctx context.Context, ticker Ticker, ep Endpoint, sources []DataSource) error {
var ready chan struct{}
started := false

c.startOnce.Do(func() {
c.logger = log.FromContext(ctx)
c.ctx, c.cancel = context.WithCancel(ctx)
started = true
ready = make(chan struct{})
@@ -107,6 +112,7 @@
case <-c.ctx.Done(): // per endpoint context cancelled
return
case <-ticker.Channel():
// TODO: do not collect if there's no pool specified?
for _, src := range sources {
ctx, cancel := context.WithTimeout(c.ctx, defaultCollectionTimeout)
_ = src.Collect(ctx, endpoint) // TODO: track errors per collector?
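The loop above drives each registered DataSource once per tick under a bounded context. For orientation, here is a no-op source shaped like what the loop calls; the method set (Name and Collect) is inferred from the calls in Collector.Start, and the real DataSource interface, defined elsewhere in the package, may require more methods:

package main

import (
	"context"
	"fmt"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)

// noopSource mirrors only the calls made by the collection loop above.
type noopSource struct{}

func (noopSource) Name() string { return "noop-source" }

func (noopSource) Collect(ctx context.Context, ep datalayer.Endpoint) error {
	select {
	case <-ctx.Done(): // honor the per-collection timeout set by the loop
		return ctx.Err()
	default:
		return nil // a real source would scrape and update ep here
	}
}

func main() {
	fmt.Println(noopSource{}.Name())
}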
31 changes: 31 additions & 0 deletions pkg/epp/datalayer/enabled.go
@@ -0,0 +1,31 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package datalayer

import (
"github.com/go-logr/logr"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
)

const (
	// EnableExperimentalDatalayerV2 is the environment variable gating the
	// experimental pluggable datalayer.
	EnableExperimentalDatalayerV2 = "ENABLE_EXPERIMENTAL_DATALAYER_V2"
)

// Enabled reports whether the experimental datalayer V2 is enabled in the
// environment.
func Enabled(logger logr.Logger) bool {
	return env.GetEnvBool(EnableExperimentalDatalayerV2, false, logger)
}
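A quick usage sketch, assuming env.GetEnvBool follows strconv.ParseBool semantics (the usual convention for this helper; treat that as an assumption):

package main

import (
	"fmt"
	"os"

	"github.com/go-logr/logr"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)

func main() {
	os.Setenv(datalayer.EnableExperimentalDatalayerV2, "true") // "1" or "t" would also parse
	fmt.Println(datalayer.Enabled(logr.Discard()))             // true

	os.Unsetenv(datalayer.EnableExperimentalDatalayerV2)
	fmt.Println(datalayer.Enabled(logr.Discard()))             // false: the default applies
}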
85 changes: 85 additions & 0 deletions pkg/epp/datalayer/factory.go
@@ -18,9 +18,14 @@ package datalayer

import (
"context"
"sync"
"time"

"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"

"sigs.k8s.io/controller-runtime/pkg/log"
v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
)

@@ -44,3 +49,83 @@ type EndpointFactory interface {
NewEndpoint(parent context.Context, inpod *corev1.Pod, poolinfo PoolInfo) Endpoint
ReleaseEndpoint(ep Endpoint)
}

// EndpointLifecycle manages the life cycle (creation and termination) of
// endpoints.
type EndpointLifecycle struct {
sources []DataSource // data sources for collectors
collectors sync.Map // collectors map. key: Pod namespaced name, value: *Collector
refreshInterval time.Duration // metrics refresh interval
}

// NewEndpointFactory returns a new endpoint factory that manages collectors
// for its endpoints.
// TODO: consider a config object? The only caveat is that sources might not
// be known at creation time (e.g., when loaded from a configuration file).
func NewEndpointFactory(logger logr.Logger, refreshMetricsInterval time.Duration) *EndpointLifecycle {
return &EndpointLifecycle{
sources: []DataSource{},
collectors: sync.Map{},
refreshInterval: refreshMetricsInterval,
}
}

// NewEndpoint implements EndpointFactory.NewEndpoint.
// Creates a new endpoint and starts its associated collector with its own ticker.
// Guards against multiple concurrent calls for the same endpoint.
func (lc *EndpointLifecycle) NewEndpoint(parent context.Context, inpod *corev1.Pod, _ PoolInfo) Endpoint {
key := types.NamespacedName{Namespace: inpod.Namespace, Name: inpod.Name}
logger := log.FromContext(parent).WithValues("pod", key)

if _, ok := lc.collectors.Load(key); ok {
logger.Info("collector already running for endpoint", "endpoint", key)
return nil
}

endpoint := NewEndpoint()
endpoint.UpdatePod(inpod)
collector := NewCollector() // for full backward compatibility, set the logger and poolinfo

if _, loaded := lc.collectors.LoadOrStore(key, collector); loaded {
// another goroutine already created and stored a collector for this endpoint.
// No need to start the new collector.
logger.Info("collector already running for endpoint", "endpoint", key)
return nil
}

ticker := NewTimeTicker(lc.refreshInterval)
if err := collector.Start(parent, ticker, endpoint, lc.sources); err != nil {
logger.Error(err, "failed to start collector for endpoint", "endpoint", key)
lc.collectors.Delete(key)
}

return endpoint
}

// ReleaseEndpoint implements EndpointFactory.ReleaseEndpoint.
// It stops the collector and cleans up resources for the endpoint.
func (lc *EndpointLifecycle) ReleaseEndpoint(ep Endpoint) {
key := ep.GetPod().GetNamespacedName()

if value, ok := lc.collectors.LoadAndDelete(key); ok {
collector := value.(*Collector)
_ = collector.Stop()
}
}

// Shutdown gracefully stops all collectors and cleans up all resources.
func (lc *EndpointLifecycle) Shutdown() {
lc.collectors.Range(func(key, value interface{}) bool {
collector := value.(*Collector)
_ = collector.Stop()
lc.collectors.Delete(key)
return true
})
}

// SetSources configures the data sources available for collection.
// This should be called after all data sources have been configured and before
// any endpoint is created.
func (lc *EndpointLifecycle) SetSources(sources []DataSource) {
	// Copy into a correctly sized slice; copying into the empty slice from
	// the constructor would silently drop all sources.
	lc.sources = make([]DataSource, len(sources))
	copy(lc.sources, sources)
}
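Putting the lifecycle together, the intended call order is: register sources, SetSources, then create endpoints, each of which spawns its own collector. A hedged sketch; the pod, pool info, and refresh interval are illustrative:

package main

import (
	"context"
	"time"

	"github.com/go-logr/logr"
	corev1 "k8s.io/api/core/v1"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)

func wireLifecycle(ctx context.Context, pod *corev1.Pod, pool datalayer.PoolInfo) {
	factory := datalayer.NewEndpointFactory(logr.Discard(), 50*time.Millisecond)
	factory.SetSources(datalayer.GetSources()) // after all RegisterSource calls

	ep := factory.NewEndpoint(ctx, pod, pool) // starts a per-endpoint collector
	factory.ReleaseEndpoint(ep)               // stops that endpoint's collector
	factory.Shutdown()                        // stops any collectors still running
}

func main() {}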
13 changes: 7 additions & 6 deletions pkg/epp/datalayer/metrics/client.go
@@ -41,14 +41,15 @@ const (
)

var (
+	baseTransport = &http.Transport{
+		MaxIdleConns:        maxIdleConnections,
+		MaxIdleConnsPerHost: 4, // host is defined as scheme://host:port
+		// TODO: set additional timeouts, transport options, etc.
+	}
	defaultClient = &client{
		Client: http.Client{
-			Timeout: timeout,
-			Transport: &http.Transport{
-				MaxIdleConns:        maxIdleConnections,
-				MaxIdleConnsPerHost: 4, // host is defined as scheme://host:port
-			},
-			// TODO: set additional timeouts, transport options, etc.
+			Timeout:   timeout,
+			Transport: baseTransport,
},
}
)
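Hoisting the transport out of the client literal lets HTTPS callers derive a TLS-configured copy without restating the pooling limits; (*http.Transport).Clone (Go 1.13+) deep-copies the exported fields. A sketch of the derivation (httpsVariant is a hypothetical in-package helper assuming crypto/tls is imported; the real call site is in datasource.go below):

// httpsVariant derives a TLS-configured transport from the shared base.
// Clone copies MaxIdleConns, MaxIdleConnsPerHost, and the other exported
// fields, so the pooling limits stay defined in one place.
func httpsVariant(skipVerify bool) *http.Transport {
	t := baseTransport.Clone()
	t.TLSClientConfig = &tls.Config{InsecureSkipVerify: skipVerify}
	return t
}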
15 changes: 12 additions & 3 deletions pkg/epp/datalayer/metrics/datasource.go
@@ -18,6 +18,7 @@ package metrics

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
@@ -30,7 +31,7 @@ import (
)

const (
-	dataSourceName = "metrics-data-source"
+	DataSourceName = "metrics-data-source"
)

// DataSource is a Model Server Protocol (MSP) compliant metrics data source,
@@ -46,7 +47,15 @@ type DataSource struct {

// NewDataSource returns a new MSP compliant metrics data source, configured with the provided
// client factory. If ClientFactory is nil, a default factory is used.
-func NewDataSource(metricsScheme string, metricsPort int32, metricsPath string, cl Client) *DataSource {
+func NewDataSource(metricsScheme string, metricsPort int32, metricsPath string, skipCertVerification bool, cl Client) *DataSource {
+	if metricsScheme == "https" {
+		httpsTransport := baseTransport.Clone()
+		httpsTransport.TLSClientConfig = &tls.Config{
+			InsecureSkipVerify: skipCertVerification,
+		}
+		defaultClient.Transport = httpsTransport
+	}

if cl == nil {
cl = defaultClient
}
@@ -68,7 +77,7 @@ func (dataSrc *DataSource) SetPort(metricsPort int32) {

// Name returns the metrics data source name.
func (dataSrc *DataSource) Name() string {
-	return dataSourceName
+	return DataSourceName
}

// AddExtractor adds an extractor to the data source, validating it can process
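End-to-end, the new signature is consumed as below, mirroring setupDatalayer in runner.go; the metric specs are illustrative stand-ins for the flag values. One design wrinkle visible above: when the scheme is https, the constructor swaps the shared defaultClient's transport even if the caller passes its own Client, so scheme selection is effectively process-global.

package main

import (
	"fmt"

	dlmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/metrics"
)

func main() {
	// true disables TLS certificate verification, as wired from the
	// modelServerMetricsHttpsInsecureSkipVerify flag in the runner.
	src := dlmetrics.NewDataSource("https", 8443, "/metrics", true, nil)

	extractor, err := dlmetrics.NewExtractor(
		"vllm:num_requests_waiting", // queued-requests spec (illustrative)
		"vllm:gpu_cache_usage_perc", // KV-cache utilization spec (illustrative)
		"vllm:lora_requests_info",   // LoRA info spec (illustrative)
	)
	if err != nil {
		panic(err)
	}
	if err := src.AddExtractor(extractor); err != nil {
		panic(err)
	}
	fmt.Println("extractor registered on", src.Name()) // "metrics-data-source"
}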